Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ async def main():
elicitation_callback: ElicitationFnT | None = None
"""Callback for handling elicitation requests."""

protocol_version: str | None = None
"""The protocol version to request during initialization. Defaults to the latest version."""

_session: ClientSession | None = field(init=False, default=None)
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
_transport: Transport = field(init=False)
Expand Down Expand Up @@ -129,7 +132,10 @@ async def __aenter__(self) -> Client:
)
)

await self._session.initialize()
if self.protocol_version is not None:
await self._session.initialize(protocol_version=self.protocol_version)
else:
await self._session.initialize()

# Transfer ownership to self for __aexit__ to handle
self._exit_stack = exit_stack.pop_all()
Expand Down
4 changes: 2 additions & 2 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]:
def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]:
return types.server_notification_adapter

async def initialize(self) -> types.InitializeResult:
async def initialize(self, protocol_version: str = types.LATEST_PROTOCOL_VERSION) -> types.InitializeResult:
sampling = (
(self._sampling_capabilities or types.SamplingCapability())
if self._sampling_callback is not _default_sampling_callback
Expand All @@ -168,7 +168,7 @@ async def initialize(self) -> types.InitializeResult:
result = await self.send_request(
types.InitializeRequest(
params=types.InitializeRequestParams(
protocol_version=types.LATEST_PROTOCOL_VERSION,
protocol_version=protocol_version,
capabilities=types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
Expand Down
6 changes: 5 additions & 1 deletion src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class ClientSessionParameters:
logging_callback: LoggingFnT | None = None
message_handler: MessageHandlerFnT | None = None
client_info: types.Implementation | None = None
protocol_version: str | None = None


class ClientSessionGroup:
Expand Down Expand Up @@ -313,7 +314,10 @@ async def _establish_session(
)
)

result = await session.initialize()
if session_params.protocol_version is not None:
result = await session.initialize(protocol_version=session_params.protocol_version)
else:
result = await session.initialize()

# Session successfully initialized.
# Store its stack and register the stack with the main group stack.
Expand Down
7 changes: 7 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ async def test_client_is_initialized(app: MCPServer):
assert client.initialize_result.server_info.name == "test"


async def test_client_custom_protocol_version(app: MCPServer):
"""Test that the client negotiates a custom protocol version when configured."""
async with Client(app, protocol_version="2024-11-05") as client:
assert client.initialize_result.protocol_version == "2024-11-05"
assert client.initialize_result.server_info.name == "test"


async def test_client_with_simple_server(simple_server: Server):
"""Test that from_server works with a basic Server instance."""
async with Client(simple_server) as client:
Expand Down
84 changes: 84 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,90 @@ async def message_handler( # pragma: no cover
assert isinstance(initialized_notification, InitializedNotification)


@pytest.mark.anyio
async def test_client_session_initialize_custom_protocol_version():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

initialized_notification = None
result = None

async def mock_server():
nonlocal initialized_notification

session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request, JSONRPCRequest)
request = client_request_adapter.validate_python(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request, InitializeRequest)
assert request.params.protocol_version == "2024-11-05"

result = InitializeResult(
protocol_version="2024-11-05",
capabilities=ServerCapabilities(
logging=None,
resources=None,
tools=None,
experimental=None,
prompts=None,
),
server_info=Implementation(name="mock-server", version="0.1.0"),
instructions="The server instructions.",
)

async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.id,
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
session_notification = await client_to_server_receive.receive()
jsonrpc_notification = session_notification.message
assert isinstance(jsonrpc_notification, JSONRPCNotification)
initialized_notification = client_notification_adapter.validate_python(
jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
)

# Create a message handler to catch exceptions
async def message_handler( # pragma: no cover
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
result = await session.initialize(protocol_version="2024-11-05")

# Assert the result
assert isinstance(result, InitializeResult)
assert result.protocol_version == "2024-11-05"
assert isinstance(result.capabilities, ServerCapabilities)
assert result.server_info == Implementation(name="mock-server", version="0.1.0")
assert result.instructions == "The server instructions."

# Check that the client sent the initialized notification
assert initialized_notification
assert isinstance(initialized_notification, InitializedNotification)


@pytest.mark.anyio
async def test_client_session_custom_client_info():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
Expand Down
34 changes: 34 additions & 0 deletions tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,37 @@ async def test_client_session_group_establish_session_parameterized(
# 3. Assert returned values
assert returned_server_info is mock_initialize_result.server_info
assert returned_session is mock_entered_session


@pytest.mark.anyio
async def test_client_session_group_establish_session_custom_protocol_version():
with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
with mock.patch("mcp.client.session_group.mcp.stdio_client") as mock_stdio_client:
mock_client_cm_instance = mock.AsyncMock(name="stdioClientCM")
mock_read_stream = mock.AsyncMock(name="stdioRead")
mock_write_stream = mock.AsyncMock(name="stdioWrite")

mock_client_cm_instance.__aenter__.return_value = (mock_read_stream, mock_write_stream)
mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None)
mock_stdio_client.return_value = mock_client_cm_instance

mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM")
mock_ClientSession_class.return_value = mock_raw_session_cm

mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance")
mock_raw_session_cm.__aenter__.return_value = mock_entered_session
mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None)

mock_initialize_result = mock.AsyncMock(name="InitializeResult")
mock_initialize_result.server_info = types.Implementation(name="foo", version="1")
mock_entered_session.initialize.return_value = mock_initialize_result

group = ClientSessionGroup()
server_params = StdioServerParameters(command="test_stdio_cmd")
session_params = ClientSessionParameters(protocol_version="2024-11-05")

async with contextlib.AsyncExitStack() as stack:
group._exit_stack = stack
await group._establish_session(server_params, session_params)

mock_entered_session.initialize.assert_awaited_once_with(protocol_version="2024-11-05")
Loading