Feat: Add get_user tool

This commit is contained in:
vimukthiRajapaksha
2025-08-06 14:53:25 +05:30
parent dec904fbbe
commit fa6cd65285
7 changed files with 249 additions and 51 deletions

View File

@@ -12,6 +12,6 @@ FHIR_MCP_PORT="8000"
FHIR_SERVER_BASE_URL=""
FHIR_SERVER_CLIENT_ID=""
FHIR_SERVER_CLIENT_SECRET=""
FHIR_SERVER_SCOPES=""
FHIR_SERVER_SCOPES="fhirUser openid"
# (Optional) If set, the authorization flow will be skipped and this access token will be used directly
# FHIR_SERVER_ACCESS_TOKEN=""

View File

@@ -14,14 +14,15 @@
# specific language governing permissions and limitations
# under the License.
from .common import handle_successful_authentication, handle_failed_authentication
from .common import handle_failed_authentication
from .server_provider import OAuthServerProvider
from .types import ServerConfigs, OAuthToken
from .types import ServerConfigs, OAuthToken, IDToken, decode_jws
__all__ = [
"handle_successful_authentication",
"handle_failed_authentication",
"OAuthServerProvider",
"OAuthToken",
"ServerConfigs",
"IDToken",
"decode_jws",
]

View File

@@ -69,26 +69,6 @@ def get_endpoint(metadata, endpoint: str) -> str:
return str(value)
def handle_successful_authentication() -> HTMLResponse:
return HTMLResponse(
f"""
<!DOCTYPE html>
<html>
<head>
<title>FHIR MCP Server | Authentication Complete</title>
</head>
<body style="font-family:Arial,sans-serif;display:flex;justify-content:center;align-items:center;height:100vh;margin:0;background:#F5F5F5;">
<div style="text-align:center;padding:20px;background:#E5F5E0;border-radius:8px;box-shadow:0 2px 4px rgba(0,0,0,0.1);width:400px;">
<h2 style="color:#000000;margin:0 0 16px;">Authentication Successful!</h2>
<p style="color:#000000;margin:0 0 20px;">You can close this window and return to the application.</p>
</div>
<script>setTimeout(() => window.close(), 2000);</script>
</body>
</html>
"""
)
def handle_failed_authentication(error_desc: str = "") -> HTMLResponse:
return HTMLResponse(
f"""
@@ -142,6 +122,9 @@ async def perform_token_flow(
headers=headers,
timeout=timeout,
)
logger.debug(
f"Token endpoint response: {response.status_code} - {response.text}"
)
if response.status_code != 200:
logger.error(

View File

@@ -56,7 +56,7 @@ class OAuthServerProvider(OAuthAuthorizationServerProvider):
self.auth_code_mapping: Dict[str, AuthorizationCode] = {}
self.token_mapping: Dict[str, AccessToken | RefreshToken] = {}
self.state_mapping: Dict[str, Dict[str, str]] = {}
self.token_metadata_mapping: Dict[str, OAuth2Token] = {}
self._metadata: OAuthMetadata | None = None
async def initialize(self) -> None:
@@ -69,6 +69,7 @@ class OAuthServerProvider(OAuthAuthorizationServerProvider):
async def register_client(self, client_info: OAuthClientInformationFull):
"""Register a new OAuth client."""
client_info.client_id = f"fhir_mcp_{client_info.client_id}"
self.clients[client_info.client_id] = client_info
async def authorize(
@@ -203,6 +204,8 @@ class OAuthServerProvider(OAuthAuthorizationServerProvider):
expires_at=int(token.expires_at or 3600),
)
self.token_metadata_mapping[token.access_token] = token
return OAuthToken(
access_token=mcp_access_token,
refresh_token=mcp_refresh_token,
@@ -280,6 +283,8 @@ class OAuthServerProvider(OAuthAuthorizationServerProvider):
client_id=client.client_id,
)
self.token_metadata_mapping[new_token.access_token] = new_token
return OAuthToken(
access_token=mcp_access_token,
refresh_token=mcp_refresh_token,

View File

@@ -14,9 +14,16 @@
# specific language governing permissions and limitations
# under the License.
import base64
import json
import logging
from typing import Any, Dict
from pydantic import AnyHttpUrl, BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict
logger = logging.getLogger(__name__)
class ServerConfigs(BaseSettings):
"""Contains environment configurations of the MCP server."""
@@ -39,7 +46,7 @@ class ServerConfigs(BaseSettings):
server_client_id: str = ""
server_client_secret: str = ""
server_scopes: str = ""
server_base_url: str = ""
server_base_url: str
server_access_token: str | None = None
def callback_url(
@@ -59,7 +66,11 @@ class ServerConfigs(BaseSettings):
def scopes(self) -> list[str]:
# If the raw value is a string, split on empty spaces
if isinstance(self.server_scopes, str):
return [scope.strip() for scope in self.server_scopes.split(" ") if scope.strip()]
return [
scope.strip()
for scope in self.server_scopes.split(" ")
if scope.strip()
]
return [self.server_scopes]
@property
@@ -110,11 +121,28 @@ class OAuthToken(BaseModel):
scope: str | None = None
refresh_token: str | None = None
expires_at: float | None = None
id_token: str | None = None
client_id: str | None = None
@property
def scopes(self) -> list[str]:
return self.scope.split(" ") if self.scope else []
def get_id_token(self) -> "IDToken | None":
"""
Parse the id_token and return an IDToken object.
Returns:
An IDToken instance populated from the JWT payload or None if parsing fails.
"""
payload: Dict[str, Any] | None = (
decode_jws(self.id_token) if self.id_token else None
)
if not payload:
return None
return IDToken.model_validate(payload)
class AuthorizationCode(BaseModel):
code: str
@@ -125,3 +153,65 @@ class AuthorizationCode(BaseModel):
code_challenge: str
redirect_uri: AnyHttpUrl
redirect_uri_provided_explicitly: bool
class IDToken(BaseModel):
fhirUser: str | None = None
def parse_fhir_user(self) -> tuple[str, str] | None:
"""
Parse the fhirUser URL to extract resource type and resource ID.
The fhirUser URL MAY be absolute (e.g., https://ehr.example.org/Practitioner/123),
or it MAY be relative to the FHIR server base URL (e.g., Practitioner/123).
Returns:
A tuple of (resource_type, resource_id) if fhirUser is valid,
None otherwise.
"""
if not self.fhirUser:
return None
logger.debug(f"Parsing fhirUser: {self.fhirUser}")
parts: list[str] = self.fhirUser.rstrip('/').split("/")
if len(parts) < 2:
return None
return parts[len(parts) - 2], parts[len(parts) - 1]
@property
def resource_type(self) -> str | None:
"""Get the FHIR resource type from fhirUser URL."""
parsed = self.parse_fhir_user()
return parsed[0] if parsed else None
@property
def resource_id(self) -> str | None:
"""Get the FHIR resource ID from fhirUser URL."""
parsed = self.parse_fhir_user()
return parsed[1] if parsed else None
def decode_jws(jws: str) -> Dict[str, Any] | None:
"""
Decode the provided JWS payload.
Returns:
The decoded JWS payload as a dictionary.
"""
try:
parts: list[str] = jws.split(".")
if len(parts) != 3:
logger.debug(
f"Decoding JWS failed: Invalid JWS format, expected 3 parts but got {len(parts)}: {jws}"
)
return None
padded: str = parts[1] + "=" * (4 - len(parts[1]) % 4)
decoded: bytes = base64.urlsafe_b64decode(padded)
return json.loads(decoded)
except Exception as e:
logger.exception("Error decoding JWS token. Caused by, ", exc_info=e)
return None

View File

@@ -18,6 +18,7 @@ import click
import logging
from fhir_mcp_server.utils import (
build_user_profile,
create_async_fhir_client,
get_bundle_entries,
get_default_headers,
@@ -62,17 +63,27 @@ async def get_user_access_token(click_ctx: click.Context) -> OAuthToken | None:
"""
if configs.server_access_token:
logger.debug("Using configured FHIR access token for user.")
return OAuthToken(access_token=configs.server_access_token, token_type="Bearer")
user_token: AccessToken | None = get_access_token()
if not user_token:
logger.error("Failed to obtain client access token.")
raise ValueError("Failed to obtain client access token.")
return OAuthToken(
access_token=configs.server_access_token,
token_type="Bearer",
client_id=configs.server_client_id,
)
user_token: AccessToken | None = get_access_token()
logger.debug("Obtained client access token from context.")
# Return the FHIR access token
return user_token
return (
OAuthToken(
access_token=user_token.token,
client_id=configs.server_client_id,
token_type="Bearer",
expires_at=user_token.expires_at,
scope=" ".join(user_token.scopes),
)
if user_token
else None
)
@click.pass_context
@@ -86,17 +97,15 @@ async def get_async_fhir_client(click_ctx: click.Context) -> AsyncFHIRClient:
"extra_headers": get_default_headers(),
}
disable_auth: bool = (
click_ctx.obj.get("disable_auth") if click_ctx.obj else False
)
if not disable_auth:
user_token: AccessToken | None = await get_user_access_token()
if not user_token:
user_token: OAuthToken | None = await get_user_access_token()
disable_auth: bool = click_ctx.obj.get("disable_auth") if click_ctx.obj else False
if not user_token:
if not disable_auth:
logger.error("User is not authenticated.")
raise ValueError("User is not authenticated.")
client_kwargs["access_token"] = user_token.token
else:
logger.debug("FHIR authentication is disabled.")
client_kwargs["access_token"] = user_token.access_token
return await create_async_fhir_client(**client_kwargs)
@@ -205,9 +214,7 @@ def register_mcp_tools(mcp: FastMCP) -> None:
]:
try:
logger.debug(f"Invoked with resource_type='{type}'")
data: Dict[str, Any] = await get_capability_statement(
configs.metadata_url
)
data: Dict[str, Any] = await get_capability_statement(configs.metadata_url)
for resource in data["rest"][0]["resource"]:
if resource.get("type") == type:
logger.info(
@@ -690,6 +697,89 @@ def register_mcp_tools(mcp: FastMCP) -> None:
)
return await get_operation_outcome_exception()
@mcp.tool(
description=(
"Retrieves the authenticated user's FHIR profile. "
"Use this tool when you need to access the current user's demographic and contact details."
)
)
async def get_user() -> Annotated[
list[Dict[str, Any]] | Dict[str, Any],
Field(
description="A dictionary containing the authenticated user's demographic information such as 'id', 'name', and 'birthDate'."
),
]:
try:
logger.debug("Retrieving authenticated user's profile.")
# Validate user authentication
user_token = await get_user_access_token()
if not user_token:
logger.debug("Unauthorized access attempt to get_me endpoint.")
return {}
# Retrieve token metadata
token_metadata = server_provider.token_metadata_mapping.get(
user_token.access_token
)
if not token_metadata:
logger.debug("Token metadata not found for authenticated user.")
return {}
# Extract ID token information
id_token = token_metadata.get_id_token()
if not id_token:
logger.debug("ID token not found in token metadata.")
return {}
# Validate resource identifiers
resource_id = id_token.resource_id
resource_type = id_token.resource_type
if not resource_id or not resource_type:
logger.debug("Resource ID or type missing from ID token.")
return {}
logger.debug(f"Fetching FHIR resource: {resource_type}/{resource_id}")
# Fetch user's FHIR resource
client: AsyncFHIRClient = await get_async_fhir_client()
resource: Dict[str, Any] = await client.get(
resource_type_or_resource_or_ref=resource_type, id_or_ref=resource_id
)
# Build response with only available fields
profile: Dict[str, Any] = build_user_profile(resource)
logger.debug(
f"Successfully retrieved profile for user: {resource_type}/{resource_id}"
)
return profile
except ValueError as ex:
logger.exception(
"Authorization error occurred while reading user resource. Caused by, ",
exc_info=ex,
)
return await get_operation_outcome(
code="forbidden",
diagnostics="The user does not have the rights to perform read operations.",
)
except OperationOutcome as ex:
logger.exception(
f"FHIR server error occurred while reading user resource. Caused by, ",
exc_info=ex,
)
return ex.resource.get("issue") or await get_operation_outcome_exception()
except Exception as ex:
logger.exception(
"Unexpected error occurred while reading user resource. Caused by, ",
exc_info=ex,
)
return await get_operation_outcome_exception()
@click.command()
@click.option(
@@ -714,9 +804,7 @@ def register_mcp_tools(mcp: FastMCP) -> None:
help="Disable authorization between MCP client and MCP server. [default: False]",
)
@click.pass_context
def main(
click_ctx: click.Context, transport, log_level, disable_auth
) -> int:
def main(click_ctx: click.Context, transport, log_level, disable_auth) -> int:
"""
FHIR MCP Server - helping you expose any FHIR Server or API as a MCP Server.
"""
@@ -731,9 +819,7 @@ def main(
try:
mcp: FastMCP = configure_mcp_server(disable_auth)
register_mcp_tools(mcp=mcp)
register_mcp_routes(
mcp=mcp, server_provider=server_provider
)
register_mcp_routes(mcp=mcp, server_provider=server_provider)
logger.info(f"Starting FHIR MCP server with {transport} transport")
mcp.run(transport=transport)
except Exception as ex:

View File

@@ -111,6 +111,7 @@ async def get_capability_statement(metadata_url: str) -> Dict[str, Any]:
Discover CapabilityStatement from server's metadata endpoint.
"""
try:
logger.debug(f"Fetching CapabilityStatement from {metadata_url}")
async with create_mcp_http_client() as client:
response = await client.get(url=metadata_url, headers=get_default_headers())
response.raise_for_status()
@@ -126,3 +127,35 @@ async def get_capability_statement(metadata_url: str) -> Dict[str, Any]:
def get_default_headers() -> Dict[str, str]:
return {"Accept": "application/fhir+json", "Content-Type": "application/fhir+json"}
def build_user_profile(resource: Dict[str, Any]) -> Dict[str, Any]:
"""
Build user profile dictionary from FHIR resource.
Args:
resource: The FHIR resource dictionary of the user.
Returns:
Dict containing only mandatory user fields
"""
# Define fields to extract from the resource
fields_to_extract = [
"id",
"resourceType",
"name",
"gender",
"birthDate",
"telecom",
"address",
]
profile: Dict[str, Any] = {}
# Add fields only if they exist and have values
for field in fields_to_extract:
value = resource.get(field)
if value is not None:
profile[field] = value
return profile