refactor: Improve input schema of the tool parameters

This commit is contained in:
vimukthiRajapaksha
2025-07-02 13:27:50 +05:30
parent 6e8295fbb7
commit 3692a7448a
6 changed files with 59 additions and 48 deletions

View File

@@ -1,3 +1,14 @@
# ----------------------------------------------------------------------------------------
#
# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). All Rights Reserved.
#
# This software is the property of WSO2 LLC. and its suppliers, if any.
# Dissemination of any information or reproduction of any material contained
# herein in any form is strictly forbidden, unless permitted by WSO2 expressly.
# You may not alter or remove any copyright or other notice from copies of this content.
#
# ----------------------------------------------------------------------------------------
FROM python:3.11-slim
# Install uv (for fast dependency management)
@@ -25,4 +36,4 @@ ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app/src
# Default command to run the server (can be overridden)
CMD ["python", "-m", "fhir_mcp_server", "--disable-mcp-auth"]
CMD ["python", "-m", "fhir_mcp_server"]

View File

@@ -140,11 +140,11 @@ class FHIRClientProvider(httpx.Auth):
return
# Try refreshing existing token
# if await self._refresh_access_token(token_id):
# return
if await self._refresh_access_token(token_id):
return
# Fall back to full OAuth token flow
await self._perform_token_flow(token_id)
# Fall back to full OAuth flow
await self._perform_oauth_flow(token_id)
async def _perform_oauth_flow(self, token_id: str) -> None:
"""Execute OAuth2 authorization code flow with PKCE."""
@@ -188,30 +188,6 @@ class FHIRClientProvider(httpx.Auth):
# Redirect user for authorization
await self.redirect_handler(auth_url)
async def _perform_token_flow(self, token_id: str) -> None:
"""Execute OAuth2 client credentials flow."""
logger.debug("Starting client credentials flow.")
access_token_payload: dict = {
"grant_type": "client_credentials",
"scope": self.configs.scope,
"client_id": self.configs.client_id,
"client_secret": self.configs.client_secret,
"resource": "https://ohfhirrepositorypoc-ohfhirrepositorypoc.fhir.azurehealthcareapis.com",
}
try:
token: OAuthToken = await perform_token_flow(
url=self._get_token_endpoint(),
data=access_token_payload,
timeout=self.configs.timeout,
)
self.token_mapping[token_id] = token
except Exception as ex:
logger.exception("Access token request failed. Caused by, ", exc_info=ex)
raise ValueError("Access token request failed")
async def handle_fhir_oauth_callback(self, code: str, state: str) -> None:
state_mapping: Dict[str, str] | None = self.state_mapping.get(state)
@@ -259,8 +235,7 @@ class FHIRClientProvider(httpx.Auth):
def _get_token_endpoint(self) -> str:
"""Get token endpoint."""
return "https://login.microsoftonline.com/da76d684-740f-4d94-8717-9d5fb21dd1f9/oauth2/token"
# return get_endpoint(self._metadata, "token_endpoint")
return get_endpoint(self._metadata, "token_endpoint")
async def _refresh_access_token(self, token_id: str) -> None:
"""Refresh access token using refresh token."""

View File

@@ -128,7 +128,10 @@ def generate_code_challenge(code_verifier: str) -> str:
async def perform_token_flow(
url: str,
data: Dict[str, str],
headers: Dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"},
headers: Dict[str, str] = {
"Content-Type": "application/x-www-form-urlencoded",
"Accept": "application/json",
},
timeout: float = 30.0,
) -> OAuthToken:
try:

View File

@@ -41,7 +41,7 @@ class MCPOAuthConfigs(BaseOAuthConfigs):
class FHIROAuthConfigs(BaseOAuthConfigs):
base_url: str = "https://hapi.fhir.org/baseR5"
base_url: str = ""
timeout: int = 30 # in secs
access_token: str | None = None

View File

@@ -25,7 +25,7 @@ from fhir_mcp_server.utils import (
get_operation_outcome_exception,
get_operation_outcome_required_error,
get_capability_statement,
trim_resource,
trim_resource_capabilities,
)
from fhir_mcp_server.oauth import (
handle_failed_authentication,
@@ -223,11 +223,14 @@ def register_mcp_tools(mcp: FastMCP) -> None:
Field(
description=(
"A dictionary containing: "
"- 'type': The requested resource type (if supported by the system) or empty. "
"- 'searchParam' : A map of FHIR search-parameter names. Each key is the parameter name (e.g., `family`, `_id`, `_lastUpdated`),"
"and each value is the FHIR-provided description of that parameter's meaning and usage constraints. "
"- 'operation' : A map of custom FHIR operation names to their descriptions, "
"each key is the operation name (e.g., `$validate`), and each value explains the operation's purpose."
"'type': The requested resource type (if supported by the system) or empty. "
"'searchParam': A mapping of FHIR search parameter names to their descriptions. Each key is a parameter name "
"(e.g., family, _id, _lastUpdated), and each value is a string describing the parameter's meaning and usage constraints. "
"'operation': A mapping of custom FHIR operation names to their descriptions. Each key is an operation name "
"(e.g., $validate), and each value is a string explaining the operation's purpose and usage. "
"'interaction': A list of supported interactions for the resource type (e.g., read, search-type, create). "
"'searchInclude': A list of supported _include parameters for the resource type, indicating which related resources can be included. "
"'searchRevInclude': A list of supported _revinclude parameters for the resource type, indicating which reverse-included resources can be included."
)
),
]:
@@ -243,8 +246,15 @@ def register_mcp_tools(mcp: FastMCP) -> None:
)
return {
"type": resource.get("type"),
"searchParam": trim_resource(resource.get("searchParam", [])),
"operation": trim_resource(resource.get("operation", [])),
"searchParam": trim_resource_capabilities(
resource.get("searchParam", [])
),
"operation": trim_resource_capabilities(
resource.get("operation", [])
),
"interaction": resource.get("interaction", []),
"searchInclude": resource.get("searchInclude", []),
"searchRevInclude": resource.get("searchRevInclude", []),
}
logger.info(f"Resource type '{type}' not found in the CapabilityStatement.")
return await get_operation_outcome(

View File

@@ -59,14 +59,23 @@ async def get_bundle_entries(bundle: Dict[str, Any]) -> Dict[str, Any]:
return bundle
def trim_resource(operations: List[Dict[str, Any]]) -> List[Dict[str, Optional[str]]]:
logger.debug(f"trim_resource called with {len(operations)} operations.")
def trim_resource_capabilities(
capabilities: List[Dict[str, Any]],
) -> List[Dict[str, Optional[str]]]:
logger.debug(
f"trim_resource_capabilities called with {len(capabilities)} capabilities."
)
trimmed = [
{"name": operation.get("name"), "documentation": operation.get("documentation")}
for operation in operations
if "name" in operation or "documentation" in operation
{
"name": capability.get("name"),
"documentation": capability.get("documentation"),
}
for capability in capabilities
if "name" in capability or "documentation" in capability
]
logger.debug(f"trim_resource returning {len(trimmed)} trimmed operations.")
logger.debug(
f"trim_resource_capabilities returning {len(trimmed)} trimmed capabilities."
)
return trimmed
@@ -82,7 +91,9 @@ async def get_operation_outcome_required_error(element: str = "") -> dict:
)
async def get_operation_outcome(code: str, diagnostics: str, severity: str = "error") -> dict:
async def get_operation_outcome(
code: str, diagnostics: str, severity: str = "error"
) -> dict:
return {
"resourceType": "OperationOutcome",
"issue": [
@@ -94,6 +105,7 @@ async def get_operation_outcome(code: str, diagnostics: str, severity: str = "er
],
}
async def get_capability_statement(metadata_url: str) -> Dict[str, Any]:
"""
Discover CapabilityStatement from server's metadata endpoint.