mirror of
https://github.com/wso2/fhir-mcp-server.git
synced 2025-11-14 22:18:14 +03:00
Feat: Add get_user tool
This commit is contained in:
@@ -12,6 +12,6 @@ FHIR_MCP_PORT="8000"
|
||||
FHIR_SERVER_BASE_URL=""
|
||||
FHIR_SERVER_CLIENT_ID=""
|
||||
FHIR_SERVER_CLIENT_SECRET=""
|
||||
FHIR_SERVER_SCOPES=""
|
||||
FHIR_SERVER_SCOPES="fhirUser openid"
|
||||
# (Optional) If set, the authorization flow will be skipped and this access token will be used directly
|
||||
# FHIR_SERVER_ACCESS_TOKEN=""
|
||||
|
||||
@@ -14,14 +14,15 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .common import handle_successful_authentication, handle_failed_authentication
|
||||
from .common import handle_failed_authentication
|
||||
from .server_provider import OAuthServerProvider
|
||||
from .types import ServerConfigs, OAuthToken
|
||||
from .types import ServerConfigs, OAuthToken, IDToken, decode_jws
|
||||
|
||||
__all__ = [
|
||||
"handle_successful_authentication",
|
||||
"handle_failed_authentication",
|
||||
"OAuthServerProvider",
|
||||
"OAuthToken",
|
||||
"ServerConfigs",
|
||||
"IDToken",
|
||||
"decode_jws",
|
||||
]
|
||||
|
||||
@@ -69,26 +69,6 @@ def get_endpoint(metadata, endpoint: str) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
def handle_successful_authentication() -> HTMLResponse:
|
||||
return HTMLResponse(
|
||||
f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>FHIR MCP Server | Authentication Complete</title>
|
||||
</head>
|
||||
<body style="font-family:Arial,sans-serif;display:flex;justify-content:center;align-items:center;height:100vh;margin:0;background:#F5F5F5;">
|
||||
<div style="text-align:center;padding:20px;background:#E5F5E0;border-radius:8px;box-shadow:0 2px 4px rgba(0,0,0,0.1);width:400px;">
|
||||
<h2 style="color:#000000;margin:0 0 16px;">Authentication Successful!</h2>
|
||||
<p style="color:#000000;margin:0 0 20px;">You can close this window and return to the application.</p>
|
||||
</div>
|
||||
<script>setTimeout(() => window.close(), 2000);</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def handle_failed_authentication(error_desc: str = "") -> HTMLResponse:
|
||||
return HTMLResponse(
|
||||
f"""
|
||||
@@ -142,6 +122,9 @@ async def perform_token_flow(
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
logger.debug(
|
||||
f"Token endpoint response: {response.status_code} - {response.text}"
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
|
||||
@@ -56,7 +56,7 @@ class OAuthServerProvider(OAuthAuthorizationServerProvider):
|
||||
self.auth_code_mapping: Dict[str, AuthorizationCode] = {}
|
||||
self.token_mapping: Dict[str, AccessToken | RefreshToken] = {}
|
||||
self.state_mapping: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
self.token_metadata_mapping: Dict[str, OAuth2Token] = {}
|
||||
self._metadata: OAuthMetadata | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@@ -69,6 +69,7 @@ class OAuthServerProvider(OAuthAuthorizationServerProvider):
|
||||
|
||||
async def register_client(self, client_info: OAuthClientInformationFull):
|
||||
"""Register a new OAuth client."""
|
||||
client_info.client_id = f"fhir_mcp_{client_info.client_id}"
|
||||
self.clients[client_info.client_id] = client_info
|
||||
|
||||
async def authorize(
|
||||
@@ -203,6 +204,8 @@ class OAuthServerProvider(OAuthAuthorizationServerProvider):
|
||||
expires_at=int(token.expires_at or 3600),
|
||||
)
|
||||
|
||||
self.token_metadata_mapping[token.access_token] = token
|
||||
|
||||
return OAuthToken(
|
||||
access_token=mcp_access_token,
|
||||
refresh_token=mcp_refresh_token,
|
||||
@@ -280,6 +283,8 @@ class OAuthServerProvider(OAuthAuthorizationServerProvider):
|
||||
client_id=client.client_id,
|
||||
)
|
||||
|
||||
self.token_metadata_mapping[new_token.access_token] = new_token
|
||||
|
||||
return OAuthToken(
|
||||
access_token=mcp_access_token,
|
||||
refresh_token=mcp_refresh_token,
|
||||
|
||||
@@ -14,9 +14,16 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Any, Dict
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServerConfigs(BaseSettings):
|
||||
"""Contains environment configurations of the MCP server."""
|
||||
@@ -39,7 +46,7 @@ class ServerConfigs(BaseSettings):
|
||||
server_client_id: str = ""
|
||||
server_client_secret: str = ""
|
||||
server_scopes: str = ""
|
||||
server_base_url: str = ""
|
||||
server_base_url: str
|
||||
server_access_token: str | None = None
|
||||
|
||||
def callback_url(
|
||||
@@ -59,7 +66,11 @@ class ServerConfigs(BaseSettings):
|
||||
def scopes(self) -> list[str]:
|
||||
# If the raw value is a string, split on empty spaces
|
||||
if isinstance(self.server_scopes, str):
|
||||
return [scope.strip() for scope in self.server_scopes.split(" ") if scope.strip()]
|
||||
return [
|
||||
scope.strip()
|
||||
for scope in self.server_scopes.split(" ")
|
||||
if scope.strip()
|
||||
]
|
||||
return [self.server_scopes]
|
||||
|
||||
@property
|
||||
@@ -110,11 +121,28 @@ class OAuthToken(BaseModel):
|
||||
scope: str | None = None
|
||||
refresh_token: str | None = None
|
||||
expires_at: float | None = None
|
||||
id_token: str | None = None
|
||||
client_id: str | None = None
|
||||
|
||||
@property
|
||||
def scopes(self) -> list[str]:
|
||||
return self.scope.split(" ") if self.scope else []
|
||||
|
||||
def get_id_token(self) -> "IDToken | None":
|
||||
"""
|
||||
Parse the id_token and return an IDToken object.
|
||||
|
||||
Returns:
|
||||
An IDToken instance populated from the JWT payload or None if parsing fails.
|
||||
"""
|
||||
payload: Dict[str, Any] | None = (
|
||||
decode_jws(self.id_token) if self.id_token else None
|
||||
)
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
return IDToken.model_validate(payload)
|
||||
|
||||
|
||||
class AuthorizationCode(BaseModel):
|
||||
code: str
|
||||
@@ -125,3 +153,65 @@ class AuthorizationCode(BaseModel):
|
||||
code_challenge: str
|
||||
redirect_uri: AnyHttpUrl
|
||||
redirect_uri_provided_explicitly: bool
|
||||
|
||||
|
||||
class IDToken(BaseModel):
|
||||
fhirUser: str | None = None
|
||||
|
||||
def parse_fhir_user(self) -> tuple[str, str] | None:
|
||||
"""
|
||||
Parse the fhirUser URL to extract resource type and resource ID.
|
||||
|
||||
The fhirUser URL MAY be absolute (e.g., https://ehr.example.org/Practitioner/123),
|
||||
or it MAY be relative to the FHIR server base URL (e.g., Practitioner/123).
|
||||
|
||||
Returns:
|
||||
A tuple of (resource_type, resource_id) if fhirUser is valid,
|
||||
None otherwise.
|
||||
"""
|
||||
if not self.fhirUser:
|
||||
return None
|
||||
|
||||
logger.debug(f"Parsing fhirUser: {self.fhirUser}")
|
||||
parts: list[str] = self.fhirUser.rstrip('/').split("/")
|
||||
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
|
||||
return parts[len(parts) - 2], parts[len(parts) - 1]
|
||||
|
||||
@property
|
||||
def resource_type(self) -> str | None:
|
||||
"""Get the FHIR resource type from fhirUser URL."""
|
||||
parsed = self.parse_fhir_user()
|
||||
return parsed[0] if parsed else None
|
||||
|
||||
@property
|
||||
def resource_id(self) -> str | None:
|
||||
"""Get the FHIR resource ID from fhirUser URL."""
|
||||
parsed = self.parse_fhir_user()
|
||||
return parsed[1] if parsed else None
|
||||
|
||||
|
||||
def decode_jws(jws: str) -> Dict[str, Any] | None:
|
||||
"""
|
||||
Decode the provided JWS payload.
|
||||
|
||||
Returns:
|
||||
The decoded JWS payload as a dictionary.
|
||||
"""
|
||||
try:
|
||||
parts: list[str] = jws.split(".")
|
||||
if len(parts) != 3:
|
||||
logger.debug(
|
||||
f"Decoding JWS failed: Invalid JWS format, expected 3 parts but got {len(parts)}: {jws}"
|
||||
)
|
||||
return None
|
||||
|
||||
padded: str = parts[1] + "=" * (4 - len(parts[1]) % 4)
|
||||
decoded: bytes = base64.urlsafe_b64decode(padded)
|
||||
return json.loads(decoded)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error decoding JWS token. Caused by, ", exc_info=e)
|
||||
return None
|
||||
|
||||
@@ -18,6 +18,7 @@ import click
|
||||
import logging
|
||||
|
||||
from fhir_mcp_server.utils import (
|
||||
build_user_profile,
|
||||
create_async_fhir_client,
|
||||
get_bundle_entries,
|
||||
get_default_headers,
|
||||
@@ -62,17 +63,27 @@ async def get_user_access_token(click_ctx: click.Context) -> OAuthToken | None:
|
||||
"""
|
||||
if configs.server_access_token:
|
||||
logger.debug("Using configured FHIR access token for user.")
|
||||
return OAuthToken(access_token=configs.server_access_token, token_type="Bearer")
|
||||
|
||||
user_token: AccessToken | None = get_access_token()
|
||||
if not user_token:
|
||||
logger.error("Failed to obtain client access token.")
|
||||
raise ValueError("Failed to obtain client access token.")
|
||||
return OAuthToken(
|
||||
access_token=configs.server_access_token,
|
||||
token_type="Bearer",
|
||||
client_id=configs.server_client_id,
|
||||
)
|
||||
|
||||
user_token: AccessToken | None = get_access_token()
|
||||
logger.debug("Obtained client access token from context.")
|
||||
|
||||
# Return the FHIR access token
|
||||
return user_token
|
||||
return (
|
||||
OAuthToken(
|
||||
access_token=user_token.token,
|
||||
client_id=configs.server_client_id,
|
||||
token_type="Bearer",
|
||||
expires_at=user_token.expires_at,
|
||||
scope=" ".join(user_token.scopes),
|
||||
)
|
||||
if user_token
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
@click.pass_context
|
||||
@@ -86,17 +97,15 @@ async def get_async_fhir_client(click_ctx: click.Context) -> AsyncFHIRClient:
|
||||
"extra_headers": get_default_headers(),
|
||||
}
|
||||
|
||||
disable_auth: bool = (
|
||||
click_ctx.obj.get("disable_auth") if click_ctx.obj else False
|
||||
)
|
||||
if not disable_auth:
|
||||
user_token: AccessToken | None = await get_user_access_token()
|
||||
if not user_token:
|
||||
user_token: OAuthToken | None = await get_user_access_token()
|
||||
disable_auth: bool = click_ctx.obj.get("disable_auth") if click_ctx.obj else False
|
||||
if not user_token:
|
||||
if not disable_auth:
|
||||
logger.error("User is not authenticated.")
|
||||
raise ValueError("User is not authenticated.")
|
||||
client_kwargs["access_token"] = user_token.token
|
||||
else:
|
||||
logger.debug("FHIR authentication is disabled.")
|
||||
client_kwargs["access_token"] = user_token.access_token
|
||||
|
||||
return await create_async_fhir_client(**client_kwargs)
|
||||
|
||||
|
||||
@@ -205,9 +214,7 @@ def register_mcp_tools(mcp: FastMCP) -> None:
|
||||
]:
|
||||
try:
|
||||
logger.debug(f"Invoked with resource_type='{type}'")
|
||||
data: Dict[str, Any] = await get_capability_statement(
|
||||
configs.metadata_url
|
||||
)
|
||||
data: Dict[str, Any] = await get_capability_statement(configs.metadata_url)
|
||||
for resource in data["rest"][0]["resource"]:
|
||||
if resource.get("type") == type:
|
||||
logger.info(
|
||||
@@ -690,6 +697,89 @@ def register_mcp_tools(mcp: FastMCP) -> None:
|
||||
)
|
||||
return await get_operation_outcome_exception()
|
||||
|
||||
@mcp.tool(
|
||||
description=(
|
||||
"Retrieves the authenticated user's FHIR profile. "
|
||||
"Use this tool when you need to access the current user's demographic and contact details."
|
||||
)
|
||||
)
|
||||
async def get_user() -> Annotated[
|
||||
list[Dict[str, Any]] | Dict[str, Any],
|
||||
Field(
|
||||
description="A dictionary containing the authenticated user's demographic information such as 'id', 'name', and 'birthDate'."
|
||||
),
|
||||
]:
|
||||
try:
|
||||
logger.debug("Retrieving authenticated user's profile.")
|
||||
|
||||
# Validate user authentication
|
||||
user_token = await get_user_access_token()
|
||||
if not user_token:
|
||||
logger.debug("Unauthorized access attempt to get_me endpoint.")
|
||||
return {}
|
||||
|
||||
# Retrieve token metadata
|
||||
token_metadata = server_provider.token_metadata_mapping.get(
|
||||
user_token.access_token
|
||||
)
|
||||
if not token_metadata:
|
||||
logger.debug("Token metadata not found for authenticated user.")
|
||||
return {}
|
||||
|
||||
# Extract ID token information
|
||||
id_token = token_metadata.get_id_token()
|
||||
if not id_token:
|
||||
logger.debug("ID token not found in token metadata.")
|
||||
return {}
|
||||
|
||||
# Validate resource identifiers
|
||||
resource_id = id_token.resource_id
|
||||
resource_type = id_token.resource_type
|
||||
|
||||
if not resource_id or not resource_type:
|
||||
logger.debug("Resource ID or type missing from ID token.")
|
||||
return {}
|
||||
|
||||
logger.debug(f"Fetching FHIR resource: {resource_type}/{resource_id}")
|
||||
|
||||
# Fetch user's FHIR resource
|
||||
client: AsyncFHIRClient = await get_async_fhir_client()
|
||||
resource: Dict[str, Any] = await client.get(
|
||||
resource_type_or_resource_or_ref=resource_type, id_or_ref=resource_id
|
||||
)
|
||||
|
||||
# Build response with only available fields
|
||||
profile: Dict[str, Any] = build_user_profile(resource)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully retrieved profile for user: {resource_type}/{resource_id}"
|
||||
)
|
||||
return profile
|
||||
|
||||
except ValueError as ex:
|
||||
logger.exception(
|
||||
"Authorization error occurred while reading user resource. Caused by, ",
|
||||
exc_info=ex,
|
||||
)
|
||||
return await get_operation_outcome(
|
||||
code="forbidden",
|
||||
diagnostics="The user does not have the rights to perform read operations.",
|
||||
)
|
||||
|
||||
except OperationOutcome as ex:
|
||||
logger.exception(
|
||||
f"FHIR server error occurred while reading user resource. Caused by, ",
|
||||
exc_info=ex,
|
||||
)
|
||||
return ex.resource.get("issue") or await get_operation_outcome_exception()
|
||||
|
||||
except Exception as ex:
|
||||
logger.exception(
|
||||
"Unexpected error occurred while reading user resource. Caused by, ",
|
||||
exc_info=ex,
|
||||
)
|
||||
return await get_operation_outcome_exception()
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
@@ -714,9 +804,7 @@ def register_mcp_tools(mcp: FastMCP) -> None:
|
||||
help="Disable authorization between MCP client and MCP server. [default: False]",
|
||||
)
|
||||
@click.pass_context
|
||||
def main(
|
||||
click_ctx: click.Context, transport, log_level, disable_auth
|
||||
) -> int:
|
||||
def main(click_ctx: click.Context, transport, log_level, disable_auth) -> int:
|
||||
"""
|
||||
FHIR MCP Server - helping you expose any FHIR Server or API as a MCP Server.
|
||||
"""
|
||||
@@ -731,9 +819,7 @@ def main(
|
||||
try:
|
||||
mcp: FastMCP = configure_mcp_server(disable_auth)
|
||||
register_mcp_tools(mcp=mcp)
|
||||
register_mcp_routes(
|
||||
mcp=mcp, server_provider=server_provider
|
||||
)
|
||||
register_mcp_routes(mcp=mcp, server_provider=server_provider)
|
||||
logger.info(f"Starting FHIR MCP server with {transport} transport")
|
||||
mcp.run(transport=transport)
|
||||
except Exception as ex:
|
||||
|
||||
@@ -111,6 +111,7 @@ async def get_capability_statement(metadata_url: str) -> Dict[str, Any]:
|
||||
Discover CapabilityStatement from server's metadata endpoint.
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Fetching CapabilityStatement from {metadata_url}")
|
||||
async with create_mcp_http_client() as client:
|
||||
response = await client.get(url=metadata_url, headers=get_default_headers())
|
||||
response.raise_for_status()
|
||||
@@ -126,3 +127,35 @@ async def get_capability_statement(metadata_url: str) -> Dict[str, Any]:
|
||||
|
||||
def get_default_headers() -> Dict[str, str]:
|
||||
return {"Accept": "application/fhir+json", "Content-Type": "application/fhir+json"}
|
||||
|
||||
|
||||
def build_user_profile(resource: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Build user profile dictionary from FHIR resource.
|
||||
|
||||
Args:
|
||||
resource: The FHIR resource dictionary of the user.
|
||||
|
||||
Returns:
|
||||
Dict containing only mandatory user fields
|
||||
"""
|
||||
|
||||
# Define fields to extract from the resource
|
||||
fields_to_extract = [
|
||||
"id",
|
||||
"resourceType",
|
||||
"name",
|
||||
"gender",
|
||||
"birthDate",
|
||||
"telecom",
|
||||
"address",
|
||||
]
|
||||
|
||||
profile: Dict[str, Any] = {}
|
||||
# Add fields only if they exist and have values
|
||||
for field in fields_to_extract:
|
||||
value = resource.get(field)
|
||||
if value is not None:
|
||||
profile[field] = value
|
||||
|
||||
return profile
|
||||
|
||||
Reference in New Issue
Block a user