Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 132 additions & 1 deletion cecli/mcp/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import logging
import os
import random
import webbrowser
from contextlib import AsyncExitStack
from enum import Enum, auto
from urllib.parse import urlparse

import httpx
Expand All @@ -20,6 +22,16 @@
save_mcp_oauth_token,
)

MIN_KEEPALIVE_INTERVAL = 5
MAX_KEEPALIVE_INTERVAL = 300
FAILED_PING_THRESHOLD = 3


class ConnectionState(Enum):
CONNECTED = auto()
UNHEALTHY = auto()
DISCONNECTED = auto()


class McpServer:
"""
Expand Down Expand Up @@ -120,6 +132,13 @@ async def disconnect(self):
class HttpBasedMcpServer(McpServer):
"""Base class for HTTP-based MCP servers (HTTP streaming and SSE)."""

def __init__(self, server_config, io=None, verbose=False):
super().__init__(server_config, io, verbose)
self._state: ConnectionState = ConnectionState.CONNECTED
self._failed_pings: int = 0
self._keepalive_task: asyncio.Task | None = None
self._http_client: httpx.AsyncClient | None = None

async def _create_oauth_provider(self):
"""Create an OAuthClientProvider using the MCP SDK."""
parsed = urlparse(self.config.get("url"))
Expand Down Expand Up @@ -228,6 +247,7 @@ async def connect(self):
timeout=30,
)
)
self._http_client = http_client

transport = await self.exit_stack.enter_async_context(
self._create_transport(url, http_client=http_client)
Expand All @@ -238,6 +258,7 @@ async def connect(self):
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await session.initialize()
self.session = session
await self.start_keepalive()
self._connection_loop = current_loop

if oauth_provider.context.oauth_metadata:
Expand All @@ -256,10 +277,119 @@ async def connect(self):
await self.disconnect()
raise

async def disconnect(self):
async def start_keepalive(self):
"""Start the background keepalive loop if configured."""
interval = self.config.get("keepalive_interval", 300)

try:
interval = int(interval)
if not (MIN_KEEPALIVE_INTERVAL <= interval <= MAX_KEEPALIVE_INTERVAL):
if self.verbose and self.io:
self.io.tool_warning(
f"Keepalive interval {interval} out of range ({MIN_KEEPALIVE_INTERVAL}-"
f"{MAX_KEEPALIVE_INTERVAL}). Ignoring."
)
return
except (ValueError, TypeError):
if self.verbose and self.io:
self.io.tool_warning(f"Invalid keepalive interval {interval}. Must be an integer.")
return

if self._keepalive_task and not self._keepalive_task.done():
self._keepalive_task.cancel()

self._keepalive_task = asyncio.create_task(self._keepalive_loop(interval))
if self.verbose and self.io:
self.io.tool_output(f"Started keepalive loop for {self.name} (interval: {interval}s)")

async def _keepalive_loop(self, interval: int):
"""Background loop that sends periodic heartbeats to the MCP server."""
try:
while True:
# Jitter: ±10% to prevent timing analysis
jitter = interval * 0.1 * (2 * random.random() - 1)
await asyncio.sleep(interval + jitter)

if not self._http_client:
continue

try:
url = self.config.get("url")
headers = self.config.get("headers", {})

# Use OPTIONS request as a lightweight heartbeat
response = await self._http_client.options(url, headers=headers)
if response.status_code == 200:
self._state = ConnectionState.CONNECTED
self._failed_pings = 0
else:
raise httpx.HTTPStatusError(
f"Unexpected status {response.status_code}",
request=response.request,
response=response,
)
except Exception:
self._failed_pings += 1
if self._failed_pings >= FAILED_PING_THRESHOLD:
self._state = ConnectionState.DISCONNECTED
if self.verbose and self.io:
self.io.tool_warning(
f"MCP server {self.name} disconnected after {self._failed_pings} failed"
" pings. Attempting reconnect..."
)
await self.reconnect()
else:
self._state = ConnectionState.UNHEALTHY
if self.verbose and self.io:
self.io.tool_output(
f"MCP server {self.name} unhealthy (ping {self._failed_pings}/{FAILED_PING_THRESHOLD})"
)
except asyncio.CancelledError:
pass
except Exception as e:
logging.error(f"Keepalive loop for {self.name} crashed: {e}")

async def reconnect(self):
"""Attempt to reconnect to the server using exponential backoff."""
initial_delay = 1
multiplier = 2
max_delay = 300

attempt = 0
while self._state == ConnectionState.DISCONNECTED:
delay = min(initial_delay * (multiplier**attempt), max_delay)
# Jitter: ±20%
jitter = delay * 0.2 * (2 * random.random() - 1)
await asyncio.sleep(delay + jitter)

try:
if self.verbose and self.io:
self.io.tool_output(
f"Attempting to reconnect to {self.name} (attempt {attempt + 1})..."
)

# Clean up old session/client without cancelling the keepalive task
await self.disconnect(cancel_keepalive=False)
await self.connect()

self._state = ConnectionState.CONNECTED
self._failed_pings = 0
if self.verbose and self.io:
self.io.tool_output(f"Successfully reconnected to {self.name}")
break
except Exception as e:
attempt += 1
if self.verbose and self.io:
self.io.tool_warning(
f"Reconnection attempt {attempt} failed for {self.name}: {e}"
)

async def disconnect(self, cancel_keepalive: bool = True):
"""Disconnect from the MCP server and clean up resources."""
async with self._cleanup_lock:
try:
if cancel_keepalive and self._keepalive_task:
self._keepalive_task.cancel()
if hasattr(self, "_oauth_shutdown"):
self._oauth_shutdown()
await self.exit_stack.aclose()
Expand All @@ -271,6 +401,7 @@ async def disconnect(self):
logging.error(f"Error during cleanup of server {self.name}: {e}")
finally:
self.session = None
self._http_client = None


class HttpStreamingServer(HttpBasedMcpServer):
Expand Down
20 changes: 19 additions & 1 deletion cecli/website/docs/config/mcp.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,32 @@ description: Configure Model Control Protocol (MCP) servers for enhanced AI capa

# Model Control Protocol (MCP)

Model Control Protocol (MCP) servers extend cecli's capabilities by providing additional tools and functionality to the AI models. MCP servers can add features like git operations, context retrieval, and other specialized tools.
Model Control Protocol (MCP) servers extend the capabilities of cecli by providing additional tools and functionality to the AI models. MCP servers can add features like git operations, context retrieval, and other specialized tools.

## Configuring MCP Servers

cecli supports configuring MCP servers using the MCP Server Configuration schema. Please
see the [Model Context Protocol documentation](https://modelcontextprotocol.io/introduction)
for more information.

### Keepalive Mechanism

For HTTP-based servers, you can enable a keepalive mechanism to prevent connections from dropping during long idle periods. This is done by adding the `keepalive_interval` property to your server configuration.

- `keepalive_interval`: (Optional) An integer specifying the interval in seconds for sending a heartbeat (an `OPTIONS` request) to the server.
- If not provided, it defaults to **300** seconds.
- The value must be between **5** and **300** seconds.

Example with keepalive enabled:
```yaml
mcp-servers:
mcpServers:
context7:
transport: http
url: https://mcp.context7.com/mcp
keepalive_interval: 60 # Send a heartbeat every 60 seconds
```

You have two ways of sharing your MCP server configuration with cecli.

{: .note }
Expand Down
102 changes: 102 additions & 0 deletions tests/mcp/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import Any, AsyncGenerator, Dict
from unittest.mock import MagicMock

import pytest

from cecli.mcp.server import HttpBasedMcpServer, HttpStreamingServer
from tests.mcp.mock_server import MockMcpServer


@pytest.fixture
def mock_mcp_server() -> MockMcpServer:
"""Fixture providing a mock MCP server instance."""
server = MockMcpServer()
return server


@pytest.fixture
async def running_mock_server(mock_mcp_server) -> AsyncGenerator[MockMcpServer, None]:
"""Fixture providing a running mock MCP server."""
await mock_mcp_server.start()
yield mock_mcp_server
await mock_mcp_server.stop()


@pytest.fixture
def http_server_config(running_mock_server) -> Dict[str, Any]:
"""Fixture providing a basic HTTP server configuration."""
return {
"name": "test-server",
"url": running_mock_server,
"type": "http",
"keepalive_interval": 1, # 1 second for fast tests
"headers": {},
"enabled": True,
}


@pytest.fixture
def http_streaming_server_config(running_mock_server) -> Dict[str, Any]:
"""Fixture providing an HTTP streaming server configuration."""
return {
"name": "test-streaming-server",
"url": running_mock_server,
"type": "streamable_http",
"keepalive_interval": 1,
"headers": {},
"enabled": True,
}


@pytest.fixture
def mock_io():
"""Fixture providing a mock IO object."""
io = MagicMock()
io.tool_output = MagicMock()
io.tool_error = MagicMock()
io.tool_warning = MagicMock()
return io


@pytest.fixture
def http_based_server(http_server_config, mock_io) -> HttpBasedMcpServer:
"""Fixture providing an HttpBasedMcpServer instance."""
return HttpBasedMcpServer(http_server_config, io=mock_io)


@pytest.fixture
def http_streaming_server(http_streaming_server_config, mock_io) -> HttpStreamingServer:
"""Fixture providing an HttpStreamingServer instance."""
return HttpStreamingServer(http_streaming_server_config, io=mock_io)


# Test utilities for inspecting internal state
class ServerStateInspector:
"""Utility class to inspect internal state of HttpBasedMcpServer for testing."""

@staticmethod
def get_state(server: HttpBasedMcpServer):
"""Get the connection state of the server."""
return server._state

@staticmethod
def get_failed_pings(server: HttpBasedMcpServer):
"""Get the number of failed pings."""
return server._failed_pings

@staticmethod
def get_keepalive_task(server: HttpBasedMcpServer):
"""Get the keepalive task."""
return server._keepalive_task

@staticmethod
def is_keepalive_running(server: HttpBasedMcpServer):
"""Check if the keepalive task is running."""
task = server._keepalive_task
return task is not None and not task.done()


@pytest.fixture
def server_inspector():
"""Fixture providing a server state inspector."""
return ServerStateInspector()
Loading
Loading