diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index 2f558dcb8fe..ac217190196 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -32,7 +32,7 @@ from .base_coder import Coder -from cecli.helpers.coroutines import interruptible # isort:skip +from cecli.helpers import coroutines # isort:skip logger = logging.getLogger(__name__) @@ -110,7 +110,7 @@ def __init__(self, *args, **kwargs): def post_init(self): super().post_init() - + self.coroutines = coroutines if not self._inherited_tools: # Populate per-instance tool and server filters from config self.registered_tools["included"] = set( @@ -325,7 +325,9 @@ async def _exec_async(): self.io.tool_warning(f"Executing {tool_name} on {server.name} failed:\nError: {e}") return f"Error executing tool call {tool_name}: {e}" - result, interrupted = await interruptible(_exec_async(), self.interrupt_event) + result, interrupted = await self.coroutines.interruptible( + _exec_async(), self.interrupt_event + ) if interrupted: return "Tool execution interrupted by user." @@ -625,7 +627,8 @@ def get_context_summary(self): if percentage > 80: result += "\n\n⚠ **Context is getting full!**\n" result += "- Remove non-essential files via the `ContextManager` tool.\n" - result += "- Keep only essential files in context for best performance" + result += "- Remove unused MCP servers via the `RemoveMcp` tool to free context space.\n" + result += "- Keep only essential files and MCP servers in context for best performance" result += "\n" if not hasattr(self, "context_blocks_cache"): self.context_blocks_cache = {} @@ -659,7 +662,9 @@ def get_environment_info(self): result += f"- Git repository: {rel_repo_dir} with {num_files:,} files\n" except Exception: result += "- Git repository: active but details unavailable\n" - else: + if self.mcp_manager and self.mcp_manager.connected_servers: + num_mcp_servers = len(self.mcp_manager.connected_servers) + result += f"- Connected MCP servers: {num_mcp_servers}\n" result += "- Git repository: none\n" result += "" return result @@ -766,7 +771,7 @@ async def _execute_local_tools(self, tool_calls_list): async def gather_and_await(): return await asyncio.gather(*tasks, return_exceptions=True) - task_results, interrupted = await interruptible( + task_results, interrupted = await self.coroutines.interruptible( gather_and_await(), self.interrupt_event ) diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index 1738092fb6e..8694686082f 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -423,9 +423,11 @@ def __init__( registered_servers=None, uuid: str = "", parent_uuid: str = "", + **kwargs, ): from cecli.helpers.agents.service import AgentService + self.original_kwargs = kwargs # initialize from args.map_cache_dir self.coroutines = coroutines # Per-instance tool and server filtering dictionaries diff --git a/cecli/commands/__init__.py b/cecli/commands/__init__.py index ae4e83a84cd..8095f367e9e 100644 --- a/cecli/commands/__init__.py +++ b/cecli/commands/__init__.py @@ -35,6 +35,7 @@ from .hot_reload import HotReloadCommand from .include_skill import IncludeSkillCommand from .lint import LintCommand +from .list_mcp import ListMcpCommand from .list_sessions import ListSessionsCommand from .list_skills import ListSkillsCommand from .load import LoadCommand @@ -123,6 +124,7 @@ CommandRegistry.register(SwitchAgentCommand) CommandRegistry.register(IncludeSkillCommand) CommandRegistry.register(LintCommand) +CommandRegistry.register(ListMcpCommand) CommandRegistry.register(ListSessionsCommand) CommandRegistry.register(ListSkillsCommand) CommandRegistry.register(LoadCommand) @@ -210,6 +212,7 @@ "LoadCommand", "LoadHookCommand", "LoadMcpCommand", + "ListMcpCommand", "LoadSessionCommand", "LoadSkillCommand", "LsCommand", diff --git a/cecli/commands/list_mcp.py b/cecli/commands/list_mcp.py new file mode 100644 index 00000000000..20457590e90 --- /dev/null +++ b/cecli/commands/list_mcp.py @@ -0,0 +1,48 @@ +from cecli.commands.utils.base_command import BaseCommand +from cecli.commands.utils.helpers import format_command_result + + +class ListMcpCommand(BaseCommand): + NORM_NAME = "list-mcp" + DESCRIPTION = "List all loaded and configured MCP servers." + + @classmethod + async def execute(cls, io, coder, args, **kwargs): + """Execute the list-mcp command.""" + if not coder.mcp_manager: + return format_command_result(io, cls.NORM_NAME, "MCP manager is not configured.") + + all_servers = coder.mcp_manager.servers + connected_servers = coder.mcp_manager.connected_servers + + loaded_server_names = {server.name for server in connected_servers} + configured_servers = [ + server for server in all_servers if server.name not in loaded_server_names + ] + + result = [] + if loaded_server_names: + result.append("Loaded MCP Servers:") + for name in sorted(list(loaded_server_names)): + result.append(f"- {name}") + else: + result.append("No MCP servers are currently loaded.") + + result.append("") + + if configured_servers: + result.append("Configured MCP Servers:") + for server in sorted(configured_servers, key=lambda s: s.name): + result.append(f"- {server.name}") + else: + result.append("No other MCP servers are configured.") + + return format_command_result(io, cls.NORM_NAME, "\n".join(result)) + + @classmethod + def get_help(cls) -> str: + """Get help text for the list-mcp command.""" + help_text = super().get_help() + help_text += "\nUsage:\n" + help_text += " /list-mcp # Lists all loaded and configured MCP servers\n" + return help_text diff --git a/cecli/commands/remove_mcp.py b/cecli/commands/remove_mcp.py index 6a08ef33a5f..2b32307ecd1 100644 --- a/cecli/commands/remove_mcp.py +++ b/cecli/commands/remove_mcp.py @@ -10,7 +10,7 @@ class RemoveMcpCommand(BaseCommand): NORM_NAME = "remove-mcp" - DESCRIPTION = "Remove a MCP server by name, or use '*' to remove all" + DESCRIPTION = "Remove (unload) a MCP server by name, or use '*' to remove all" @classmethod async def execute(cls, io, coder, args, **kwargs): diff --git a/cecli/tools/__init__.py b/cecli/tools/__init__.py index 44e527cff37..54dd6622227 100644 --- a/cecli/tools/__init__.py +++ b/cecli/tools/__init__.py @@ -17,9 +17,12 @@ git_show, git_status, grep, + list_mcp, + load_mcp, load_skill, ls, read_range, + remove_mcp, remove_skill, thinking, undo_change, @@ -42,9 +45,12 @@ git_show, git_status, grep, + list_mcp, + load_mcp, load_skill, ls, read_range, + remove_mcp, remove_skill, thinking, undo_change, diff --git a/cecli/tools/list_mcp.py b/cecli/tools/list_mcp.py new file mode 100644 index 00000000000..6df6daca038 --- /dev/null +++ b/cecli/tools/list_mcp.py @@ -0,0 +1,50 @@ +from cecli.tools.utils.base_tool import BaseTool + + +class Tool(BaseTool): + NORM_NAME = "list-mcp" + SCHEMA = { + "type": "function", + "function": { + "name": "ListMcp", + "description": "List all loaded and configured MCP servers.", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + } + + @classmethod + def execute(cls, coder, **kwargs): + """List all loaded and configured MCP servers.""" + if not coder.mcp_manager: + return "MCP manager is not configured." + + all_servers = coder.mcp_manager.servers + connected_servers = coder.mcp_manager.connected_servers + + loaded_server_names = {server.name for server in connected_servers} + configured_servers = [ + server for server in all_servers if server.name not in loaded_server_names + ] + + result = [] + if loaded_server_names: + result.append("Loaded MCP Servers:") + for name in sorted(list(loaded_server_names)): + result.append(f"- {name}") + else: + result.append("No MCP servers are currently loaded.") + + result.append("") + + if configured_servers: + result.append("Configured MCP Servers:") + for server in sorted(configured_servers, key=lambda s: s.name): + result.append(f"- {server.name}") + else: + result.append("No other MCP servers are configured.") + + return "\n".join(result) diff --git a/cecli/tools/load_mcp.py b/cecli/tools/load_mcp.py new file mode 100644 index 00000000000..fb98f55811d --- /dev/null +++ b/cecli/tools/load_mcp.py @@ -0,0 +1,78 @@ +from typing import List + +from cecli.tools.utils.base_tool import BaseTool + + +class Tool(BaseTool): + NORM_NAME = "load-mcp" + SCHEMA = { + "type": "function", + "function": { + "name": "LoadMCP", + "description": "Load MCP server(s) by name, or use '*' to load all enabled servers.", + "parameters": { + "type": "object", + "properties": { + "servers": { + "type": "array", + "items": {"type": "string"}, + "description": "A list of MCP server names to load. Use '*' to load all.", + } + }, + "required": ["servers"], + }, + }, + } + + @classmethod + async def execute(cls, coder, servers: List[str]): + """Execute the load-mcp tool with given parameters.""" + if not coder.mcp_manager or not coder.mcp_manager.servers: + return "No MCP servers found, nothing to load." + + results = [] + servers_to_load = [] + + if servers == ["*"]: + for server in coder.mcp_manager.servers: + if server.name in coder.mcp_manager.connected_servers: + results.append(f"Server already loaded: {server.name}") + continue + auto_connect = server.config.get("enabled", True) + if not auto_connect: + results.append(f"Skipping server (not enabled by default): {server.name}") + continue + servers_to_load.append(server) + else: + for server_name in servers: + server = coder.mcp_manager.get_server(server_name) + if server is None: + results.append(f"MCP server {server_name} does not exist.") + else: + servers_to_load.append(server) + + if not servers_to_load and results: + return "\n".join(results) + + # Process the loading + for server in servers_to_load: + server_name = server.name + if server_name in coder.mcp_manager.connected_servers: + results.append(f"Server already loaded: {server_name}") + continue + + coder.interrupt_event.clear() + did_connect, interrupted = await coder.coroutines.interruptible( + coder.mcp_manager.connect_server(server_name), + coder.interrupt_event, + ) + + if interrupted: + results.append(f"Interrupted: {server_name}") + continue + if did_connect: + results.append(f"Loaded server: {server_name}") + else: + results.append(f"Unable to load server: {server_name}") + + return "\n".join(results) diff --git a/cecli/tools/remove_mcp.py b/cecli/tools/remove_mcp.py new file mode 100644 index 00000000000..788647de8c6 --- /dev/null +++ b/cecli/tools/remove_mcp.py @@ -0,0 +1,77 @@ +from typing import List + +from cecli.tools.utils.base_tool import BaseTool + + +class Tool(BaseTool): + NORM_NAME = "remove-mcp" + SCHEMA = { + "type": "function", + "function": { + "name": "RemoveMCP", + "description": ( + "Remove (unload) MCP server(s) by name, or use '*' to remove all connected servers." + ), + "parameters": { + "type": "object", + "properties": { + "servers": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "A list of MCP server names to remove. Use '*' to remove all." + ), + } + }, + "required": ["servers"], + }, + }, + } + + @classmethod + async def execute(cls, coder, servers: List[str]): + """Execute the remove-mcp tool with given parameters.""" + if not coder.mcp_manager or not coder.mcp_manager.servers: + return "No MCP servers are configured." + + results = [] + servers_to_action = [] + + # Determine which servers to act on + if servers == ["*"]: + servers_to_action.extend(coder.mcp_manager.connected_servers.keys()) + else: + for server_name in servers: + server = coder.mcp_manager.get_server(server_name) + if not server: + results.append(f"MCP server {server_name} does not exist.") + elif server.name not in coder.mcp_manager.connected_servers: + results.append(f"Server {server_name} is not currently connected.") + else: + servers_to_action.append(server.name) + + # If there are no servers to act on but we have preliminary results (like errors), return them + if not servers_to_action and results: + return "\n".join(results) + + # If there are no servers to remove at all + if not servers_to_action: + return "No servers to remove." + + # Process the removal + for server_name in servers_to_action: + coder.interrupt_event.clear() + did_disconnect, interrupted = await coder.coroutines.interruptible( + coder.mcp_manager.disconnect_server(server_name), + coder.interrupt_event, + ) + + if interrupted: + results.append(f"Interrupted: {server_name}") + continue + if did_disconnect: + results.append(f"Removed server: {server_name}") + else: + results.append(f"Unable to remove server: {server_name}") + + return "\n".join(results) diff --git a/cecli/tui/app.py b/cecli/tui/app.py index 1a2d3822792..eaa3ab8f40a 100644 --- a/cecli/tui/app.py +++ b/cecli/tui/app.py @@ -750,7 +750,6 @@ def update_spinner(self, msg, agent_name: str | None = None): def show_error(self, message, agent_name: str | None = None): """Show an error message in the status bar.""" status_bar = self.query_one("#status-bar", StatusBar) - status_bar.show_notification(message, severity="error", timeout=5, agent_name=agent_name) def on_resize(self) -> None: diff --git a/cecli/website/docs/config/agent-mode.md b/cecli/website/docs/config/agent-mode.md index d66ac7c14e7..0b50f7c80ca 100644 --- a/cecli/website/docs/config/agent-mode.md +++ b/cecli/website/docs/config/agent-mode.md @@ -309,8 +309,26 @@ agent-config: For complete documentation on creating and using skills, including skill directory structure, SKILL.md format, and best practices, see the [Skills documentation](https://github.com/dwash96/cecli/blob/main/cecli/website/docs/config/skills.md). +### MCP Server Management + +MCP (Model Context Protocol) servers provide external tools to the agent, but each connected server and its tools consume context tokens. To maintain optimal performance: + +- **Remove unused servers**: If an MCP server is no longer needed for the current task, remove it using the `RemoveMcp` tool to free up context space. +- **Load servers on demand**: Only load MCP servers when their tools are actually required. Use the `LoadMcp` tool to add servers as needed. +- **Monitor context usage**: The context summary block shows total token usage. Removing unnecessary MCP servers can significantly reduce context overhead. +- **List active servers**: Use the `ListMcp` tool to see which servers are currently connected and consuming context. + ### Benefits +### MCP Server Management +MCP (Model Context Protocol) servers provide external tools to the agent, but each connected server and its tools consume context tokens. To maintain optimal performance: + +- **Remove unused servers**: If an MCP server is no longer needed for the current task, remove it using the `RemoveMcp` tool to free up context space. +- **Load servers on demand**: Only load MCP servers when their tools are actually required. Use the `LoadMcp` tool to add servers as needed. +- **Monitor context usage**: The context summary block shows total token usage. Removing unnecessary MCP servers can significantly reduce context overhead. +- **List active servers**: Use the `ListMcp` tool to see which servers are currently connected and consuming context. + +### Benefits - **Autonomous operation**: Reduces need for manual file management - **Context awareness**: Real-time project information improves decision making - **Precision editing**: Granular tools reduce errors compared to SEARCH/REPLACE diff --git a/tests/integration/test_agent_mcp_management.py b/tests/integration/test_agent_mcp_management.py new file mode 100644 index 00000000000..39d20e90101 --- /dev/null +++ b/tests/integration/test_agent_mcp_management.py @@ -0,0 +1,137 @@ +"""Integration tests for agent and subagent MCP management.""" + +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +from cecli.coders.agent_coder import AgentCoder +from cecli.coders.sub_agent_coder import SubAgentCoder +from cecli.tools.load_mcp_tool import LoadMcpTool +from cecli.tools.remove_mcp_tool import RemoveMcpTool + + +@pytest.fixture +def mock_mcp_manager(): + """Fixture for a mocked McpServerManager.""" + manager = MagicMock() + manager.connected_servers = {} + + # Mock servers + server1 = MagicMock() + server1.name = "test_server" + server1.config = {"enabled": True} + + server2 = MagicMock() + server2.name = "sub_test_server" + server2.config = {"enabled": True} + + manager.servers = [server1, server2] + + def get_server_side_effect(name): + if name == "test_server": + return server1 + if name == "sub_test_server": + return server2 + return None + + manager.get_server.side_effect = get_server_side_effect + + async def connect(server_name): + manager.connected_servers[server_name] = "connected" + return True, False # (did_connect, interrupted) + + async def disconnect(server_name): + if server_name in manager.connected_servers: + del manager.connected_servers[server_name] + return True, False + return False, False + + manager.connect_server = AsyncMock(side_effect=connect) + manager.disconnect_server = AsyncMock(side_effect=disconnect) + manager.add_server = AsyncMock() + manager.add_server = AsyncMock() + manager.add_server = AsyncMock() + manager.add_server = AsyncMock() + + return manager + + +@pytest.fixture +def agent_coder(mock_mcp_manager): + """Fixture for an AgentCoder with a mocked MCP manager.""" + with patch("cecli.coders.agent_coder.McpServerManager", return_value=mock_mcp_manager): + coder = AgentCoder( + main_model=MagicMock(), + io=MagicMock(), + ) + coder.mcp_manager = mock_mcp_manager + coder.original_kwargs = {} + coder.coroutines = Mock() + + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible.side_effect = mock_interruptible + coder.interrupt_event = Mock() + return coder + + +@pytest.fixture +async def sub_agent_coder(agent_coder): + """Fixture for a SubAgentCoder.""" + # Fix: Use create() class method instead of direct instantiation + sub_agent = await SubAgentCoder.create(from_coder=agent_coder) + # Ensure sub_agent has the required mocks for tools + sub_agent.coroutines = agent_coder.coroutines + sub_agent.interrupt_event = agent_coder.interrupt_event + return sub_agent + + +@pytest.mark.asyncio +async def test_agent_can_load_mcp_server(agent_coder, mock_mcp_manager): + """Verify an agent can load an MCP server.""" + tool = LoadMcpTool() + server_name = "test_server" + + await tool.execute(agent_coder, servers=[server_name]) + + mock_mcp_manager.connect_server.assert_called_once_with(server_name) + assert server_name in mock_mcp_manager.connected_servers + + +@pytest.mark.asyncio +async def test_agent_can_remove_mcp_server(agent_coder, mock_mcp_manager): + """Verify an agent can remove an MCP server.""" + tool = RemoveMcpTool() + server_name = "test_server" + mock_mcp_manager.connected_servers[server_name] = "connected" + + await tool.execute(agent_coder, servers=[server_name]) + + mock_mcp_manager.disconnect_server.assert_called_once_with(server_name) + assert server_name not in mock_mcp_manager.connected_servers + + +@pytest.mark.asyncio +async def test_sub_agent_can_load_mcp_server(sub_agent_coder, mock_mcp_manager): + """Verify a subagent can load an MCP server.""" + tool = LoadMcpTool() + server_name = "sub_test_server" + + await tool.execute(sub_agent_coder, servers=[server_name]) + + mock_mcp_manager.connect_server.assert_called_once_with(server_name) + assert server_name in mock_mcp_manager.connected_servers + + +@pytest.mark.asyncio +async def test_sub_agent_can_remove_mcp_server(sub_agent_coder, mock_mcp_manager): + """Verify a subagent can remove an MCP server.""" + tool = RemoveMcpTool() + server_name = "sub_test_server" + mock_mcp_manager.connected_servers[server_name] = "connected" + + await tool.execute(sub_agent_coder, servers=[server_name]) + + mock_mcp_manager.disconnect_server.assert_called_once_with(server_name) + assert server_name not in mock_mcp_manager.connected_servers diff --git a/tests/integration/test_mcp_management.py b/tests/integration/test_mcp_management.py new file mode 100644 index 00000000000..0f7f9f38bf1 --- /dev/null +++ b/tests/integration/test_mcp_management.py @@ -0,0 +1,93 @@ +"""Integration tests for MCP management tools.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from cecli.tools.load_mcp_tool import LoadMcpTool +from cecli.tools.remove_mcp_tool import RemoveMcpTool + + +class CoderMock: + """Mock Coder object for integration testing.""" + + def __init__(self): + self.mcp_manager = MagicMock() + self.mcp_manager.servers = [] + self.mcp_manager.connected_servers = {} + self.mcp_manager.get_server.return_value = None + self.mcp_manager.connect_server = AsyncMock(return_value=(True, False)) + self.mcp_manager.disconnect_server = AsyncMock(return_value=(True, False)) + self.coroutines = MagicMock() + self.interrupt_event = MagicMock() + + async def mock_interruptible(self, coro, event): + """Mock interruptible that just executes the coroutine.""" + return await coro, False + + def add_server(self, name, enabled=True): + """Add a mock server to the manager.""" + server = MagicMock() + server.name = name + server.config = {"enabled": enabled} + self.mcp_manager.servers.append(server) + original_get_server = self.mcp_manager.get_server.side_effect + + def get_server_side_effect(server_name): + if server_name == name: + return server + if original_get_server: + return original_get_server(server_name) + return None + + self.mcp_manager.get_server.side_effect = get_server_side_effect + + +@pytest.fixture +def coder(): + """Provide a mock coder for integration testing.""" + return CoderMock() + + +@pytest.mark.asyncio +async def test_integration_load_and_remove_server(coder): + """Test loading and then removing a server.""" + coder.add_server("integration-test-server") + coder.coroutines.interruptible = coder.mock_interruptible + + # Load the server + load_result = await LoadMcpTool.execute(coder, ["integration-test-server"]) + assert "Loaded server: integration-test-server" in load_result + + # Mock the connected server for the remove tool + coder.mcp_manager.connected_servers = {"integration-test-server": coder.mcp_manager.servers[0]} + + # Remove the server + remove_result = await RemoveMcpTool.execute(coder, ["integration-test-server"]) + assert "Removed server: integration-test-server" in remove_result + + +@pytest.mark.asyncio +async def test_integration_wildcard_load_and_remove(coder): + """Test loading and removing all servers with a wildcard.""" + coder.add_server("server1") + coder.add_server("server2") + coder.add_server("server3", enabled=False) + coder.coroutines.interruptible = coder.mock_interruptible + + # Load all enabled servers + load_result = await LoadMcpTool.execute(coder, ["*"]) + assert "Loaded server: server1" in load_result + assert "Loaded server: server2" in load_result + assert "Skipping server (not enabled by default): server3" in load_result + + # Mock the connected servers for the remove tool + coder.mcp_manager.connected_servers = { + "server1": coder.mcp_manager.servers[0], + "server2": coder.mcp_manager.servers[1], + } + + # Remove all connected servers + remove_result = await RemoveMcpTool.execute(coder, ["*"]) + assert "Removed server: server1" in remove_result + assert "Removed server: server2" in remove_result diff --git a/tests/tools/test_remove_mcp_tool.py b/tests/tools/test_remove_mcp_tool.py new file mode 100644 index 00000000000..3d800de5218 --- /dev/null +++ b/tests/tools/test_remove_mcp_tool.py @@ -0,0 +1,146 @@ +"""Unit tests for RemoveMcpTool.execute.""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +from cecli.tools.remove_mcp_tool import RemoveMcpTool + + +class DummyIO: + """Mock IO object for testing.""" + + def __init__(self): + self.tool_error = Mock() + self.tool_warning = Mock() + self.tool_output = Mock() + self.interrupt_event = Mock() + + +class DummyCoder: + """Mock Coder object for testing.""" + + def __init__(self): + self.io = DummyIO() + self.mcp_manager = Mock() + self.mcp_manager.servers = [] + self.mcp_manager.connected_servers = {} + self.coroutines = Mock() + self.coroutines.interruptible = AsyncMock() + + self.interrupt_event = Mock() + + +@pytest.fixture +def coder(): + """Provide a dummy coder for testing.""" + return DummyCoder() + + +@pytest.fixture +def mock_server(): + """Provide a mock MCP server.""" + server = Mock() + server.name = "test-server" + return server + + +class TestRemoveMcpTool: + """Test cases for RemoveMcpTool.""" + + @pytest.mark.asyncio + async def test_no_configured_servers(self, coder): + """Test when no MCP servers are configured at all.""" + coder.mcp_manager.servers = [] + result = await RemoveMcpTool.execute(coder, servers=["test"]) + assert result == "No MCP servers are configured." + + @pytest.mark.asyncio + async def test_server_not_found(self, coder, mock_server): + """Test when requested server doesn't exist.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {"existing": "server"} + coder.mcp_manager.get_server.return_value = None + result = await RemoveMcpTool.execute(coder, servers=["nonexistent"]) + assert "MCP server nonexistent does not exist." in result + + @pytest.mark.asyncio + async def test_all_servers_not_loaded(self, coder, mock_server): + """Test when multiple servers exist but are not loaded.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.return_value = mock_server + result = await RemoveMcpTool.execute(coder, servers=["test-server"]) + assert "Server test-server is not currently connected." in result + + @pytest.mark.asyncio + async def test_successful_removal(self, coder, mock_server): + """Test successful server removal.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {"test-server": mock_server} + coder.mcp_manager.get_server.return_value = mock_server + coder.coroutines.interruptible.return_value = (True, False) + result = await RemoveMcpTool.execute(coder, servers=["test-server"]) + assert "Removed server: test-server" in result + + @pytest.mark.asyncio + async def test_removal_interrupted(self, coder, mock_server): + """Test when removal is interrupted.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {"test-server": mock_server} + coder.mcp_manager.get_server.return_value = mock_server + coder.coroutines.interruptible.return_value = (False, True) + result = await RemoveMcpTool.execute(coder, servers=["test-server"]) + assert "Interrupted: test-server" in result + + @pytest.mark.asyncio + async def test_removal_failed(self, coder, mock_server): + """Test when removal fails.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {"test-server": mock_server} + coder.mcp_manager.get_server.return_value = mock_server + coder.coroutines.interruptible.return_value = (False, False) + result = await RemoveMcpTool.execute(coder, servers=["test-server"]) + assert "Unable to remove server: test-server" in result + + @pytest.mark.asyncio + async def test_remove_all_servers(self, coder): + """Test removing all servers with '*' wildcard.""" + server1 = Mock() + server1.name = "server1" + server2 = Mock() + server2.name = "server2" + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {"server1": server1, "server2": server2} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + coder.coroutines.interruptible.return_value = (True, False) + result = await RemoveMcpTool.execute(coder, servers=["*"]) + assert "Removed server: server1" in result + assert "Removed server: server2" in result + + @pytest.mark.asyncio + async def test_mixed_results(self, coder): + """Test mixed success/failure results.""" + server1 = Mock() + server1.name = "server1" + server2 = Mock() + server2.name = "server2" + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {"server1": server1, "server2": server2} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + call_count = 0 + + async def mock_interruptible_func(*args, **kwargs): + nonlocal call_count + result = (True, False) if call_count == 0 else (False, False) + call_count += 1 + return result + + coder.coroutines.interruptible.side_effect = mock_interruptible_func + result = await RemoveMcpTool.execute(coder, servers=["server1", "server2"]) + assert "Removed server: server1" in result + assert "Unable to remove server: server2" in result diff --git a/tests/tools/test_tools_load_mcp_tool.py b/tests/tools/test_tools_load_mcp_tool.py new file mode 100644 index 00000000000..6c0f5e45cbd --- /dev/null +++ b/tests/tools/test_tools_load_mcp_tool.py @@ -0,0 +1,212 @@ +"""Unit tests for LoadMcpTool.execute.""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +from cecli.tools.load_mcp_tool import LoadMcpTool + + +class DummyIO: + """Mock IO object for testing.""" + + def __init__(self): + self.tool_error = Mock() + self.tool_warning = Mock() + self.tool_output = Mock() + self.interrupt_event = Mock() + + +class DummyCoder: + """Mock Coder object for testing.""" + + def __init__(self): + self.io = DummyIO() + self.mcp_manager = Mock() + self.mcp_manager.servers = [] + self.mcp_manager.connected_servers = {} + self.coroutines = Mock() + self.coroutines.interruptible = AsyncMock() + self.interrupt_event = Mock() + + +@pytest.fixture +def coder(): + """Provide a dummy coder for testing.""" + return DummyCoder() + + +@pytest.fixture +def mock_server(): + """Provide a mock MCP server.""" + server = Mock() + server.name = "test-server" + server.config = {"enabled": True} + return server + + +class TestLoadMcpTool: + """Test cases for LoadMcpTool.""" + + @pytest.mark.asyncio + async def test_no_mcp_servers_found(self, coder): + """Test when no MCP servers are configured.""" + coder.mcp_manager.servers = [] + result = await LoadMcpTool.execute(coder, servers=["test"]) + assert result == "No MCP servers found, nothing to load." + + @pytest.mark.asyncio + async def test_server_not_found(self, coder, mock_server): + """Test when requested server doesn't exist.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.get_server.return_value = None + result = await LoadMcpTool.execute(coder, servers=["nonexistent"]) + assert "MCP server nonexistent does not exist." in result + + @pytest.mark.asyncio + async def test_server_already_loaded(self, coder, mock_server): + """Test when server is already loaded.""" + mock_server.name = "test-server" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {"test-server": mock_server} + coder.mcp_manager.get_server.return_value = mock_server + # Must return tuple (did_connect, interrupted) + coder.coroutines.interruptible.return_value = (True, False) + result = await LoadMcpTool.execute(coder, servers=["test-server"]) + assert "Server already loaded: test-server" in result + + @pytest.mark.asyncio + async def test_server_not_enabled_by_default(self, coder, mock_server): + """Test when server is not enabled by default.""" + mock_server.config = {"enabled": False} + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.get_server.return_value = mock_server + result = await LoadMcpTool.execute(coder, servers=["*"]) + assert "Skipping server (not enabled by default): test-server" in result + + @pytest.mark.asyncio + async def test_successful_load(self, coder, mock_server): + """Test successful server loading.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.return_value = mock_server + coder.coroutines.interruptible.return_value = (True, False) + result = await LoadMcpTool.execute(coder, servers=["test-server"]) + assert "Loaded server: test-server" in result + + @pytest.mark.asyncio + async def test_load_interrupted(self, coder, mock_server): + """Test when loading is interrupted.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.return_value = mock_server + coder.coroutines.interruptible.return_value = (False, True) + result = await LoadMcpTool.execute(coder, servers=["test-server"]) + assert "Interrupted: test-server" in result + + @pytest.mark.asyncio + async def test_load_failed(self, coder, mock_server): + """Test when loading fails.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.return_value = mock_server + coder.coroutines.interruptible.return_value = (False, False) + result = await LoadMcpTool.execute(coder, servers=["test-server"]) + assert "Unable to load server: test-server" in result + + @pytest.mark.asyncio + async def test_load_all_servers(self, coder): + """Test loading all servers with '*' wildcard.""" + server1 = Mock() + server1.name = "server1" + server1.config = {"enabled": True} + server2 = Mock() + server2.name = "server2" + server2.config = {"enabled": True} + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + coder.coroutines.interruptible.return_value = (True, False) + result = await LoadMcpTool.execute(coder, servers=["*"]) + assert "Loaded server: server1" in result + assert "Loaded server: server2" in result + + @pytest.mark.asyncio + async def test_mixed_results(self, coder): + """Test mixed success/failure results.""" + server1 = Mock() + server1.name = "server1" + server1.config = {"enabled": True} + server2 = Mock() + server2.name = "server2" + server2.config = {"enabled": True} + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + + async def mock_interruptible_func(*args, **kwargs): + # First call succeeds, second fails + if not hasattr(mock_interruptible_func, "call_count"): + mock_interruptible_func.call_count = 0 + mock_interruptible_func.call_count += 1 + if mock_interruptible_func.call_count == 1: + return (True, False) + else: + return (False, False) + + coder.coroutines.interruptible.side_effect = mock_interruptible_func + result = await LoadMcpTool.execute(coder, servers=["server1", "server2"]) + assert "Loaded server: server1" in result + assert "Unable to load server: server2" in result + + @pytest.mark.asyncio + async def test_duplicate_iteration_bug_fix(self, coder, mock_server): + """Test that duplicate iteration bug is fixed - server already loaded only processed once.""" + mock_server.name = "test-server" + coder.mcp_manager.servers = [mock_server] + # Server already connected + coder.mcp_manager.connected_servers = {"test-server": mock_server} + coder.mcp_manager.get_server.return_value = mock_server + + result = await LoadMcpTool.execute(coder, servers=["test-server"]) + + # Should only report server already loaded once + assert result.count("Server already loaded: test-server") == 1 + # connect_server should not have been called since it was already loaded + coder.mcp_manager.connect_server.assert_not_called() + + @pytest.mark.asyncio + async def test_wildcard_with_duplicate_iteration_fix(self, coder): + """Test wildcard loading with duplicate iteration fix.""" + server1 = Mock() + server1.name = "server1" + server1.config = {"enabled": True} + server2 = Mock() + server2.name = "server2" + server2.config = {"enabled": True} + coder.mcp_manager.servers = [server1, server2] + # server1 already loaded, server2 not loaded + coder.mcp_manager.connected_servers = {"server1": server1} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + connect_calls = [] + + async def mock_connect_server(server_name): + connect_calls.append(server_name) + if server_name == "server2": + return True, False + return False, False + + coder.mcp_manager.connect_server.side_effect = mock_connect_server + coder.coroutines.interruptible.side_effect = mock_connect_server + result = await LoadMcpTool.execute(coder, servers=["*"]) + + # Should only attempt to load server2 (server1 should be skipped) + assert "Server already loaded: server1" in result + assert "Loaded server: server2" in result + assert connect_calls == ["server2"] # Only server2 should have been connected diff --git a/tests/unit/test_load_mcp.py b/tests/unit/test_load_mcp.py new file mode 100644 index 00000000000..9a0048007f8 --- /dev/null +++ b/tests/unit/test_load_mcp.py @@ -0,0 +1,278 @@ +"""Unit tests for LoadMcpTool.execute.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from cecli.tools.load_mcp_tool import LoadMcpTool + + +class DummyIO: + """Mock IO object for testing.""" + + def __init__(self): + self.tool_error = MagicMock() + self.tool_warning = MagicMock() + self.tool_output = MagicMock() + self.interrupt_event = MagicMock() + + +class DummyCoder: + """Mock Coder object for testing.""" + + def __init__(self): + self.io = DummyIO() + self.mcp_manager = MagicMock() + self.mcp_manager.servers = [] + self.mcp_manager.connected_servers = {} + self.coroutines = MagicMock() + self.interrupt_event = MagicMock() + + +@pytest.fixture +def coder(): + """Provide a dummy coder for testing.""" + return DummyCoder() + + +@pytest.fixture +def mock_server(): + """Provide a mock MCP server.""" + server = MagicMock() + server.name = "test-server" + server.config = {"enabled": True} + return server + + +@pytest.mark.asyncio +async def test_no_mcp_servers_found(coder): + """Test when no MCP servers are configured.""" + coder.mcp_manager.servers = [] + result = await LoadMcpTool.execute(coder, servers=["test"]) + assert result == "No MCP servers found, nothing to load." + + +@pytest.mark.asyncio +async def test_server_not_found(coder, mock_server): + """Test when requested server doesn't exist.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.get_server.return_value = None + result = await LoadMcpTool.execute(coder, servers=["nonexistent"]) + assert "MCP server nonexistent does not exist." in result + + +@pytest.mark.asyncio +async def test_server_already_loaded(coder, mock_server): + """Test when server is already loaded.""" + mock_server.name = "test-server" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {"test-server": mock_server} + coder.mcp_manager.get_server.return_value = mock_server + # Set up connect_server as AsyncMock so assert_not_called works + coder.mcp_manager.connect_server = AsyncMock() + + # Mock interruptible to just execute the coroutine + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + result = await LoadMcpTool.execute(coder, servers=["test-server"]) + assert "Server already loaded: test-server" in result + # connect_server should not have been called since it was already loaded + coder.mcp_manager.connect_server.assert_not_called() + + +@pytest.mark.asyncio +async def test_server_not_enabled_by_default(coder, mock_server): + """Test when server is not enabled by default.""" + mock_server.config = {"enabled": False} + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.get_server.return_value = mock_server + result = await LoadMcpTool.execute(coder, servers=["*"]) + assert "Skipping server (not enabled by default): test-server" in result + + +@pytest.mark.asyncio +async def test_successful_load(coder, mock_server): + """Test successful server loading.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.return_value = mock_server + + # Set up connect_server as AsyncMock that returns (True, False) + async def mock_connect_server(server_name): + return True, False + + coder.mcp_manager.connect_server = mock_connect_server + + # Mock interruptible to just execute the coroutine + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + result = await LoadMcpTool.execute(coder, servers=["test-server"]) + assert "Loaded server: test-server" in result + + +@pytest.mark.asyncio +async def test_load_interrupted(coder, mock_server): + """Test when loading is interrupted.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.return_value = mock_server + + # Set up connect_server as AsyncMock + async def mock_connect_server(server_name): + return True, False + + coder.mcp_manager.connect_server = mock_connect_server + + # Mock interruptible to return interruption + async def mock_interruptible(coro, event): + return False, True + + coder.coroutines.interruptible = mock_interruptible + result = await LoadMcpTool.execute(coder, servers=["test-server"]) + assert "Interrupted: test-server" in result + + +@pytest.mark.asyncio +async def test_load_failed(coder, mock_server): + """Test when loading fails.""" + coder.mcp_manager.servers = [mock_server] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.return_value = mock_server + + # Set up connect_server as AsyncMock that returns failure + async def mock_connect_server(server_name): + return False, False + + coder.mcp_manager.connect_server = mock_connect_server + + # Mock interruptible to just execute the coroutine + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + result = await LoadMcpTool.execute(coder, servers=["test-server"]) + assert "Unable to load server: test-server" in result + + +@pytest.mark.asyncio +async def test_load_all_servers(coder): + """Test loading all servers with '*' wildcard.""" + server1 = MagicMock() + server1.name = "server1" + server1.config = {"enabled": True} + server2 = MagicMock() + server2.name = "server2" + server2.config = {"enabled": True} + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + + # Set up connect_server as AsyncMock + async def mock_connect_server(server_name): + return True, False + + coder.mcp_manager.connect_server = mock_connect_server + + # Mock interruptible to just execute the coroutine + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + result = await LoadMcpTool.execute(coder, servers=["*"]) + assert "Loaded server: server1" in result + assert "Loaded server: server2" in result + + +@pytest.mark.asyncio +async def test_mixed_results(coder): + """Test mixed success/failure results.""" + server1 = MagicMock() + server1.name = "server1" + server1.config = {"enabled": True} + server2 = MagicMock() + server2.name = "server2" + server2.config = {"enabled": True} + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + # First call succeeds, second fails + call_count = 0 + + async def mock_connect_server(server_name): + nonlocal call_count + result = True if call_count == 0 else False + call_count += 1 + return result + + coder.mcp_manager.connect_server = mock_connect_server + + # Mock interruptible to just execute the coroutine + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + result = await LoadMcpTool.execute(coder, servers=["server1", "server2"]) + assert "Loaded server: server1" in result + assert "Unable to load server: server2" in result + + +@pytest.mark.asyncio +async def test_duplicate_iteration_bug_fix(coder, mock_server): + """Test that duplicate iteration bug is fixed - server already loaded only processed once.""" + mock_server.name = "test-server" + coder.mcp_manager.servers = [mock_server] + # Server already connected + coder.mcp_manager.connected_servers = {"test-server": mock_server} + coder.mcp_manager.get_server.return_value = mock_server + # Set up connect_server as AsyncMock + coder.mcp_manager.connect_server = AsyncMock() + result = await LoadMcpTool.execute(coder, servers=["test-server"]) + # Should only report server already loaded once + assert result.count("Server already loaded: test-server") == 1 + # connect_server should not have been called since it was already loaded + coder.mcp_manager.connect_server.assert_not_called() + + +@pytest.mark.asyncio +async def test_wildcard_with_duplicate_iteration_fix(coder): + """Test wildcard loading with duplicate iteration fix.""" + server1 = MagicMock() + server1.name = "server1" + server1.config = {"enabled": True} + server2 = MagicMock() + server2.name = "server2" + server2.config = {"enabled": True} + coder.mcp_manager.servers = [server1, server2] + # server1 already loaded, server2 not loaded + coder.mcp_manager.connected_servers = {"server1": server1} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + connect_calls = [] + + async def mock_connect_server(server_name): + connect_calls.append(server_name) + if server_name == "server2": + return True, False + return False, False + + coder.mcp_manager.connect_server = mock_connect_server + + # Mock interruptible to just execute the coroutine + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + result = await LoadMcpTool.execute(coder, servers=["*"]) + # Should only attempt to load server2 (server1 should be skipped) + assert "Server already loaded: server1" in result + assert "Loaded server: server2" in result + assert connect_calls == ["server2"] # Only server2 should have been connected diff --git a/tests/unit/test_remove_mcp.py b/tests/unit/test_remove_mcp.py new file mode 100644 index 00000000000..27d7e09a7ec --- /dev/null +++ b/tests/unit/test_remove_mcp.py @@ -0,0 +1,261 @@ +"""Unit tests for RemoveMcpTool.execute.""" + +from unittest.mock import MagicMock + +import pytest + +from cecli.tools.remove_mcp_tool import RemoveMcpTool + + +class DummyIO: + """Mock IO object for testing.""" + + def __init__(self): + self.tool_error = MagicMock() + self.tool_warning = MagicMock() + self.tool_output = MagicMock() + self.interrupt_event = MagicMock() + + +class DummyCoder: + """Mock Coder object for testing.""" + + def __init__(self): + self.io = DummyIO() + self.mcp_manager = MagicMock() + self.mcp_manager.servers = [] + self.mcp_manager.connected_servers = {} + self.coroutines = MagicMock() + self.interrupt_event = MagicMock() + + +@pytest.fixture +def coder(): + """Provide a dummy coder for testing.""" + return DummyCoder() + + +@pytest.fixture +def mock_server(): + """Provide a mock MCP server.""" + server = MagicMock() + server.name = "test-server" + return server + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_success(): + """Test successful removal of an MCP server.""" + # Setup + coder = MagicMock() + coder.mcp_manager = MagicMock() + server = MagicMock() + server.name = "test-server" + coder.mcp_manager.get_server.return_value = server + coder.mcp_manager.connected_servers = {"test-server": server} + + # Mock disconnect_server as an async function that returns (True, False) + async def mock_disconnect(server_name): + return True, False + + coder.mcp_manager.disconnect_server = mock_disconnect + + # Mock the interruptible method to execute the coroutine directly without interruption + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines = MagicMock() + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + # Execute + result = await RemoveMcpTool.execute(coder, ["test-server"]) + # Assertions + assert "Removed server: test-server" in result + coder.mcp_manager.disconnect_server.assert_awaited_once_with("test-server") + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_non_existent(): + """Test removing a non-existent MCP server.""" + # Setup + coder = MagicMock() + coder.mcp_manager = MagicMock() + # Create a mock server that exists (to bypass the 'no servers' check) + existing_server = MagicMock() + existing_server.name = "existing-server" + coder.mcp_manager.servers = [existing_server] + # But the one we're looking for doesn't exist + coder.mcp_manager.get_server.return_value = None + # Execute + result = await RemoveMcpTool.execute(coder, ["non-existent-server"]) + # Assertions + assert "MCP server non-existent-server does not exist." in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_not_connected(): + """Test removing a server that is not connected.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server = MagicMock() + server.name = "test-server" + coder.mcp_manager.servers = [server] + coder.mcp_manager.get_server.return_value = server + coder.mcp_manager.connected_servers = {} + result = await RemoveMcpTool.execute(coder, ["test-server"]) + assert "Server test-server is not currently connected." in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_wildcard(): + """Test removing all servers with wildcard '*'.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server1 = MagicMock() + server1.name = "server1" + server2 = MagicMock() + server2.name = "server2" + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {"server1": server1, "server2": server2} + + # Mock disconnect_server as an async function that returns (True, False) + async def mock_disconnect(server_name): + return True, False + + coder.mcp_manager.disconnect_server = mock_disconnect + + # Mock interruptible to execute the coroutine without interruption + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines = MagicMock() + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + result = await RemoveMcpTool.execute(coder, ["*"]) + assert "Removed server: server1" in result + assert "Removed server: server2" in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_interrupted(): + """Test when removal is interrupted.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server = MagicMock() + server.name = "test-server" + coder.mcp_manager.servers = [server] + coder.mcp_manager.get_server.return_value = server + coder.mcp_manager.connected_servers = {"test-server": server} + + async def mock_disconnect(server_name): + return False, True + + coder.mcp_manager.disconnect_server = mock_disconnect + + async def mock_interruptible(coro, event): + return False, True + + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + result = await RemoveMcpTool.execute(coder, ["test-server"]) + assert "Interrupted: test-server" in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_failed(): + """Test when removal fails.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server = MagicMock() + server.name = "test-server" + coder.mcp_manager.servers = [server] + coder.mcp_manager.get_server.return_value = server + coder.mcp_manager.connected_servers = {"test-server": server} + + async def mock_disconnect(server_name): + return False, False + + coder.mcp_manager.disconnect_server = mock_disconnect + + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + result = await RemoveMcpTool.execute(coder, ["test-server"]) + assert "Unable to remove server: test-server" in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_no_servers_configured(): + """Test when no MCP servers are configured at all.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + coder.mcp_manager.servers = [] + result = await RemoveMcpTool.execute(coder, servers=["test"]) + assert result == "No MCP servers are configured." + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_mixed_results(): + """Test mixed success/failure results.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server1 = MagicMock() + server1.name = "server1" + server2 = MagicMock() + server2.name = "server2" + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {"server1": server1, "server2": server2} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + call_count = 0 + + async def mock_disconnect(server_name): + nonlocal call_count + result = (True, False) if call_count == 0 else (False, False) + call_count += 1 + return result + + coder.mcp_manager.disconnect_server = mock_disconnect + + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + result = await RemoveMcpTool.execute(coder, servers=["server1", "server2"]) + assert "Removed server: server1" in result + assert "Unable to remove server: server2" in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_dictionary_iteration_fix(): + """Test that dictionary iteration bug is fixed - iterates over keys correctly.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server1 = MagicMock() + server1.name = "server1" + server2 = MagicMock() + server2.name = "server2" + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {"server1": server1, "server2": server2} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + + async def mock_disconnect(server_name): + return True, False + + coder.mcp_manager.disconnect_server = mock_disconnect + + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + result = await RemoveMcpTool.execute(coder, servers=["*"]) + # Should successfully remove both servers using dictionary keys + assert "Removed server: server1" in result + assert "Removed server: server2" in result diff --git a/tests/unit/test_unit_load_mcp_tool.py b/tests/unit/test_unit_load_mcp_tool.py new file mode 100644 index 00000000000..34dfaae1f09 --- /dev/null +++ b/tests/unit/test_unit_load_mcp_tool.py @@ -0,0 +1,139 @@ +"""Unit tests for load-mcp tool.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from cecli.tools.load_mcp_tool import LoadMcpTool + + +@pytest.fixture +def mock_mcp_manager(): + """Fixture for a mocked McpServerManager.""" + manager = MagicMock() + manager.connected_servers = {} + + # Mock servers + server1 = MagicMock() + server1.name = "test-server" + server1.config = {"enabled": True} + + server2 = MagicMock() + server2.name = "server2" + server2.config = {"enabled": True} + + server3 = MagicMock() + server3.name = "server3" + server3.config = {"enabled": False} + + manager.servers = [server1, server2, server3] + + def get_server_side_effect(name): + if name == "test-server": + return server1 + if name == "server2": + return server2 + if name == "server3": + return server3 + return None + + manager.get_server.side_effect = get_server_side_effect + + async def connect(server_name): + manager.connected_servers[server_name] = "connected" + return True, False # (did_connect, interrupted) + + async def disconnect(server_name): + if server_name in manager.connected_servers: + del manager.connected_servers[server_name] + return True, False + return False, False + + manager.connect_server = AsyncMock(side_effect=connect) + manager.disconnect_server = AsyncMock(side_effect=disconnect) + manager.add_server = AsyncMock() + return manager + + +@pytest.mark.asyncio +async def test_load_mcp_tool_success(mock_mcp_manager): + """Test loading a single MCP server successfully.""" + tool = LoadMcpTool() + + # Mock the coder + coder = MagicMock() + coder.mcp_manager = mock_mcp_manager + + # Mock interruptible to return (await coro, False) + async def mock_interruptible(coro, event): + return await coro + + coder.coroutines = MagicMock() + coder.coroutines.interruptible.side_effect = mock_interruptible + coder.interrupt_event = MagicMock() + + result = await tool.execute(coder, servers=["test-server"]) + + assert "Loaded server: test-server" in result + mock_mcp_manager.connect_server.assert_awaited_once_with("test-server") + + +@pytest.mark.asyncio +async def test_load_mcp_tool_non_existent(mock_mcp_manager): + """Test loading a non-existent MCP server.""" + + tool = LoadMcpTool() + + coder = MagicMock() + coder.mcp_manager = mock_mcp_manager + + result = await tool.execute(coder, servers=["non-existent-server"]) + + assert "MCP server non-existent-server does not exist." in result + mock_mcp_manager.connect_server.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_load_mcp_tool_already_loaded(mock_mcp_manager): + """Test loading an already loaded MCP server.""" + tool = LoadMcpTool() + coder = MagicMock() + coder.mcp_manager = mock_mcp_manager + # Pre-populate connected_servers + server = mock_mcp_manager.get_server("test-server") + coder.mcp_manager.connected_servers = {"test-server": server} + + result = await tool.execute(coder, servers=["test-server"]) + + assert "Server already loaded: test-server" in result + mock_mcp_manager.connect_server.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_load_mcp_tool_wildcard_and_duplicate_fix(mock_mcp_manager): + """Test loading with wildcard and duplicate fix.""" + tool = LoadMcpTool() + coder = MagicMock() + coder.mcp_manager = mock_mcp_manager + + # Mock interruptible to return (await coro, False) + async def mock_interruptible(coro, event): + return await coro + + coder.coroutines = MagicMock() + coder.coroutines.interruptible.side_effect = mock_interruptible + coder.interrupt_event = MagicMock() + + # Set up connected_servers: server1 is already connected + server1 = mock_mcp_manager.get_server("test-server") + coder.mcp_manager.connected_servers = {"test-server": server1} + + result = await tool.execute(coder, servers=["*"]) + + # Check results + assert "Server already loaded: test-server" in result + assert "Loaded server: server2" in result + assert "Skipping server (not enabled by default): server3" in result + + # Verify connect_server was called only once for server2 + mock_mcp_manager.connect_server.assert_awaited_once_with("server2") diff --git a/tests/unit/test_unit_remove_mcp_tool.py b/tests/unit/test_unit_remove_mcp_tool.py new file mode 100644 index 00000000000..0126f1b9a4d --- /dev/null +++ b/tests/unit/test_unit_remove_mcp_tool.py @@ -0,0 +1,280 @@ +"""Unit tests for RemoveMcpTool.execute.""" + +from unittest.mock import MagicMock + +import pytest + +from cecli.tools.remove_mcp_tool import RemoveMcpTool + + +class DummyIO: + """Mock IO object for testing.""" + + def __init__(self): + self.tool_error = MagicMock() + self.tool_warning = MagicMock() + self.tool_output = MagicMock() + self.interrupt_event = MagicMock() + + +class DummyCoder: + """Mock Coder object for testing.""" + + def __init__(self): + self.io = DummyIO() + self.mcp_manager = MagicMock() + self.mcp_manager.servers = [] + self.mcp_manager.connected_servers = {} + self.coroutines = MagicMock() + self.interrupt_event = MagicMock() + + +@pytest.fixture +def coder(): + """Provide a dummy coder for testing.""" + return DummyCoder() + + +@pytest.fixture +def mock_server(): + """Provide a mock MCP server.""" + server = MagicMock() + server.name = "test-server" + return server + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_success(): + """Test successful removal of an MCP server.""" + # Setup + coder = MagicMock() + coder.mcp_manager = MagicMock() + server = MagicMock() + server.name = "test-server" + coder.mcp_manager.get_server.return_value = server + coder.mcp_manager.connected_servers = {"test-server": server} + + # Mock disconnect_server as an AsyncMock that returns (True, False) + async def mock_disconnect(server_name): + return True, False + + coder.mcp_manager.disconnect_server = mock_disconnect + + # Mock the interruptible method to execute the coroutine directly without interruption + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines = MagicMock() + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + + # Execute + result = await RemoveMcpTool.execute(coder, ["test-server"]) + + # Assertions + assert "Removed server: test-server" in result + coder.mcp_manager.disconnect_server.assert_awaited_once_with("test-server") + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_non_existent(): + """Test removing a non-existent MCP server.""" + # Setup + coder = MagicMock() + coder.mcp_manager = MagicMock() + # Create a mock server that exists (to bypass the 'no servers' check) + existing_server = MagicMock() + existing_server.name = "existing-server" + coder.mcp_manager.servers = [existing_server] + # But the one we're looking for doesn't exist + coder.mcp_manager.get_server.return_value = None + + # Execute + result = await RemoveMcpTool.execute(coder, ["non-existent-server"]) + + # Assertions + assert "MCP server non-existent-server does not exist." in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_not_connected(): + """Test removing a server that is not connected.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server = MagicMock() + server.name = "test-server" + coder.mcp_manager.servers = [server] + coder.mcp_manager.get_server.return_value = server + coder.mcp_manager.connected_servers = {} + + result = await RemoveMcpTool.execute(coder, ["test-server"]) + + assert "Server test-server is not currently connected." in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_wildcard(): + """Test removing all servers with wildcard '*'.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + + server1 = MagicMock() + server1.name = "server1" + server2 = MagicMock() + server2.name = "server2" + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {"server1": server1, "server2": server2} + + # Mock disconnect_server as an AsyncMock that returns (True, False) + async def mock_disconnect(server_name): + return True, False + + coder.mcp_manager.disconnect_server = mock_disconnect + + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines = MagicMock() + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + + result = await RemoveMcpTool.execute(coder, ["*"]) + + assert "Removed server: server1" in result + assert "Removed server: server2" in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_interrupted(): + """Test when removal is interrupted.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server = MagicMock() + server.name = "test-server" + coder.mcp_manager.servers = [server] + coder.mcp_manager.get_server.return_value = server + coder.mcp_manager.connected_servers = {"test-server": server} + + async def mock_disconnect(server_name): + return False, True + + coder.mcp_manager.disconnect_server = mock_disconnect + + async def mock_interruptible(coro, event): + return False, True + + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + + result = await RemoveMcpTool.execute(coder, ["test-server"]) + + assert "Interrupted: test-server" in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_failed(): + """Test when removal fails.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server = MagicMock() + server.name = "test-server" + coder.mcp_manager.servers = [server] + coder.mcp_manager.get_server.return_value = server + coder.mcp_manager.connected_servers = {"test-server": server} + + async def mock_disconnect(server_name): + return False, False + + coder.mcp_manager.disconnect_server = mock_disconnect + + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + + result = await RemoveMcpTool.execute(coder, ["test-server"]) + + assert "Unable to remove server: test-server" in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_no_servers_configured(): + """Test when no MCP servers are configured at all.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + coder.mcp_manager.servers = [] + + result = await RemoveMcpTool.execute(coder, servers=["test"]) + + assert result == "No MCP servers are configured." + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_mixed_results(): + """Test mixed success/failure results.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server1 = MagicMock() + server1.name = "server1" + server2 = MagicMock() + server2.name = "server2" + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {"server1": server1, "server2": server2} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + + call_count = 0 + + async def mock_disconnect(server_name): + nonlocal call_count + result = (True, False) if call_count == 0 else (False, False) + call_count += 1 + return result + + coder.mcp_manager.disconnect_server = mock_disconnect + + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + + result = await RemoveMcpTool.execute(coder, servers=["server1", "server2"]) + + assert "Removed server: server1" in result + assert "Unable to remove server: server2" in result + + +@pytest.mark.asyncio +async def test_remove_mcp_tool_dictionary_iteration_fix(): + """Test that dictionary iteration bug is fixed - iterates over keys correctly.""" + coder = MagicMock() + coder.mcp_manager = MagicMock() + server1 = MagicMock() + server1.name = "server1" + server2 = MagicMock() + server2.name = "server2" + coder.mcp_manager.servers = [server1, server2] + coder.mcp_manager.connected_servers = {"server1": server1, "server2": server2} + coder.mcp_manager.get_server.side_effect = lambda name: next( + (s for s in [server1, server2] if s.name == name), None + ) + + async def mock_disconnect(server_name): + return True, False + + coder.mcp_manager.disconnect_server = mock_disconnect + + async def mock_interruptible(coro, event): + return await coro, False + + coder.coroutines.interruptible = mock_interruptible + coder.interrupt_event = MagicMock() + + result = await RemoveMcpTool.execute(coder, servers=["*"]) + + # Should successfully remove both servers using dictionary keys + assert "Removed server: server1" in result + assert "Removed server: server2" in result