diff --git a/src/workbench/server.py b/src/workbench/server.py new file mode 100644 index 0000000..22eda9f --- /dev/null +++ b/src/workbench/server.py @@ -0,0 +1,237 @@ +"""MCP server — 6 workbench tools with HTTP/WS server management and persistence.""" + +from __future__ import annotations + +import asyncio +import json +import os +import socket +from pathlib import Path +from typing import Optional + +from aiohttp import web +from mcp.server.fastmcp import FastMCP + +from workbench.project import ( + create_project, project_exists, project_path, + append_log, read_log, list_projects, log_session_event, + write_server_info, read_server_info, clear_server_info, + WORKBENCH_DIR, +) + + +def _get_lan_ip() -> str: + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("192.168.0.1", 80)) + ip = s.getsockname()[0] + s.close() + return ip + except Exception: + return "127.0.0.1" + + +def _find_free_port(start: int = 8070) -> int: + for port in range(start, start + 100): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("0.0.0.0", port)) + s.close() + return port + except OSError: + continue + raise RuntimeError(f"No free port found in range {start}-{start + 99}") + + +class WorkbenchServer: + """Core server logic — testable without MCP transport.""" + + def __init__(self, workbench_dir: Path = WORKBENCH_DIR): + self.workbench_dir = Path(workbench_dir) + self._active: dict[str, dict] = {} + self._runners: dict[str, web.AppRunner] = {} + + def _is_server_alive(self, port: int) -> bool: + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(1) + s.connect(("127.0.0.1", port)) + s.close() + return True + except (OSError, ConnectionRefusedError): + return False + + async def reconnect_existing_servers(self) -> None: + if not self.workbench_dir.exists(): + return + for d in self.workbench_dir.iterdir(): + if not d.is_dir(): + continue + name = d.name + info = read_server_info(name, workbench_dir=self.workbench_dir) + if info is None: + continue + port = info["port"] + if self._is_server_alive(port): + self._active[name] = {"port": port, "ws_clients": set()} + else: + clear_server_info(name, workbench_dir=self.workbench_dir) + + async def _start_http_server(self, name: str) -> int: + port = _find_free_port() + pdir = project_path(name, self.workbench_dir) + + app = web.Application() + app["project_name"] = name + app["workbench_server"] = self + + async def ws_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + proj = request.app["project_name"] + if proj in self._active: + self._active[proj]["ws_clients"].add(ws) + try: + async for msg in ws: + pass + finally: + if proj in self._active: + self._active[proj]["ws_clients"].discard(ws) + return ws + + async def static_handler(request): + proj = request.app["project_name"] + path = request.match_info.get("path", "index.html") or "index.html" + file_path = project_path(proj, self.workbench_dir) / path + if not file_path.exists(): + return web.Response(status=404, text="Not found") + return web.FileResponse(file_path) + + app.router.add_get("/ws", ws_handler) + app.router.add_get("/{path:.*}", static_handler) + app.router.add_get("/", static_handler) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "0.0.0.0", port) + await site.start() + + self._runners[name] = runner + self._active[name] = {"port": port, "ws_clients": set()} + write_server_info(name, pid=os.getpid(), port=port, workbench_dir=self.workbench_dir) + return port + + async def _broadcast_ws(self, name: str, message: dict) -> None: + if name not in self._active: + return + clients = self._active[name].get("ws_clients", set()) + dead = set() + for ws in clients: + try: + await ws.send_json(message) + except Exception: + dead.add(ws) + clients -= dead + + async def workbench_scaffold(self, name: str, title: str, description: str = "") -> str: + pdir = project_path(name, self.workbench_dir) + create_project(name, title, description, workbench_dir=self.workbench_dir) + + info = read_server_info(name, workbench_dir=self.workbench_dir) + if info and self._is_server_alive(info["port"]): + port = info["port"] + if name not in self._active: + self._active[name] = {"port": port, "ws_clients": set()} + else: + if info: + clear_server_info(name, workbench_dir=self.workbench_dir) + port = await self._start_http_server(name) + if name not in self._active: + self._active[name] = {"port": port, "ws_clients": set()} + + ip = _get_lan_ip() + log_session_event(name, "session_start", workbench_dir=self.workbench_dir) + return json.dumps({"path": str(pdir), "url": f"http://{ip}:{port}"}) + + async def workbench_state(self, project: str, state: str) -> str: + if not project_exists(project, workbench_dir=self.workbench_dir): + return json.dumps({"error": f"Project '{project}' not found. Run workbench_scaffold first."}) + state_obj = json.loads(state) + pdir = project_path(project, self.workbench_dir) + (pdir / "state.json").write_text(json.dumps(state_obj, indent=2)) + await self._broadcast_ws(project, {"type": "state", "state": state_obj}) + return json.dumps({"ok": True}) + + async def workbench_log(self, project: str, entry: str, data: str = "{}") -> str: + if not project_exists(project, workbench_dir=self.workbench_dir): + return json.dumps({"error": f"Project '{project}' not found."}) + data_obj = json.loads(data) if data and data != "{}" else None + append_log(project, entry, data=data_obj, workbench_dir=self.workbench_dir) + await self._broadcast_ws(project, {"type": "log", "entry": entry}) + return json.dumps({"ok": True}) + + async def workbench_read_log(self, project: str, tail: int = 20) -> str: + if not project_exists(project, workbench_dir=self.workbench_dir): + return json.dumps({"error": f"Project '{project}' not found."}) + entries = read_log(project, tail=tail, workbench_dir=self.workbench_dir) + return json.dumps({"entries": entries}) + + async def workbench_list(self) -> str: + projects = list_projects(workbench_dir=self.workbench_dir) + for p in projects: + p["active"] = p["name"] in self._active + if p["active"]: + ip = _get_lan_ip() + port = self._active[p["name"]]["port"] + p["url"] = f"http://{ip}:{port}" + return json.dumps({"projects": projects}) + + async def workbench_stop(self, project: str) -> str: + if project not in self._active: + return json.dumps({"error": f"Project '{project}' is not running."}) + if project_exists(project, workbench_dir=self.workbench_dir): + entries = read_log(project, tail=999999, workbench_dir=self.workbench_dir) + log_session_event(project, "session_end", workbench_dir=self.workbench_dir, log_entries=len(entries)) + if project in self._runners: + await self._runners[project].cleanup() + del self._runners[project] + clear_server_info(project, workbench_dir=self.workbench_dir) + del self._active[project] + return json.dumps({"ok": True}) + + +def create_mcp_server(workbench_dir: Path = WORKBENCH_DIR) -> FastMCP: + srv = WorkbenchServer(workbench_dir=workbench_dir) + mcp = FastMCP("workbench", instructions="Workbench — build interactive web pages served over LAN. Call workbench_scaffold first.") + + @mcp.tool() + async def workbench_scaffold(name: str, title: str, description: str = "") -> str: + """Create a workbench project and start the HTTP server. Returns the LAN URL to open in a browser. If the project already exists and its server is running, reattaches without starting a duplicate. Always safe to call — never creates duplicates.""" + return await srv.workbench_scaffold(name, title, description) + + @mcp.tool() + async def workbench_state(project: str, state: str) -> str: + """Push a state update to the browser via WebSocket. The state is a JSON string — include 'template' (HTML string) to replace the page content, 'styles' (CSS string) to inject styles, and 'script' (JS string) to execute code. The AI has full control over the page.""" + return await srv.workbench_state(project, state) + + @mcp.tool() + async def workbench_log(project: str, entry: str, data: str = "{}") -> str: + """Append a log entry to the session log. Shows in the browser log feed. entry: human-readable markdown string. data: optional JSON for the machine-readable log.""" + return await srv.workbench_log(project, entry, data) + + @mcp.tool() + async def workbench_read_log(project: str, tail: int = 20) -> str: + """Read recent session log entries so the AI can resume a previous session.""" + return await srv.workbench_read_log(project, tail) + + @mcp.tool() + async def workbench_list() -> str: + """List all workbench projects and whether their HTTP server is running.""" + return await srv.workbench_list() + + @mcp.tool() + async def workbench_stop(project: str) -> str: + """Stop the HTTP server for a project and end the session.""" + return await srv.workbench_stop(project) + + return mcp diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..aaf0f5a --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,105 @@ +import asyncio +import json +import os +from pathlib import Path +from unittest.mock import patch, AsyncMock, MagicMock + +import pytest + +from workbench.server import WorkbenchServer + + +@pytest.fixture +def server(tmp_workbench): + return WorkbenchServer(workbench_dir=tmp_workbench) + + +@pytest.mark.asyncio +async def test_scaffold_creates_project(server, tmp_workbench): + with patch.object(server, "_start_http_server", new_callable=AsyncMock, return_value=8070): + result = await server.workbench_scaffold("test-proj", "Test Project") + data = json.loads(result) + assert "url" in data + assert (tmp_workbench / "test-proj" / "index.html").exists() + + +@pytest.mark.asyncio +async def test_scaffold_reattaches_to_running_server(server, tmp_workbench): + from workbench.project import create_project, write_server_info + create_project("test-proj", "Test", workbench_dir=tmp_workbench) + write_server_info("test-proj", pid=os.getpid(), port=9999, workbench_dir=tmp_workbench) + with patch.object(server, "_is_server_alive", return_value=True): + with patch.object(server, "_start_http_server", new_callable=AsyncMock) as mock_start: + result = await server.workbench_scaffold("test-proj", "Test") + mock_start.assert_not_called() + data = json.loads(result) + assert "9999" in data["url"] + + +@pytest.mark.asyncio +async def test_scaffold_replaces_dead_server(server, tmp_workbench): + from workbench.project import create_project, write_server_info + create_project("test-proj", "Test", workbench_dir=tmp_workbench) + write_server_info("test-proj", pid=99999, port=9999, workbench_dir=tmp_workbench) + with patch.object(server, "_is_server_alive", return_value=False): + with patch.object(server, "_start_http_server", new_callable=AsyncMock, return_value=8070): + result = await server.workbench_scaffold("test-proj", "Test") + data = json.loads(result) + assert "8070" in data["url"] + + +@pytest.mark.asyncio +async def test_list_empty(server): + result = await server.workbench_list() + data = json.loads(result) + assert data["projects"] == [] + + +@pytest.mark.asyncio +async def test_log_writes_to_disk(server, tmp_workbench): + with patch.object(server, "_start_http_server", new_callable=AsyncMock, return_value=8070): + await server.workbench_scaffold("test-proj", "Test") + result = await server.workbench_log("test-proj", "R412 measured 1.05M") + data = json.loads(result) + assert data["ok"] is True + jsonl = (tmp_workbench / "test-proj" / "session.jsonl").read_text().strip() + assert "R412 measured 1.05M" in jsonl + + +@pytest.mark.asyncio +async def test_state_saves_to_disk(server, tmp_workbench): + with patch.object(server, "_start_http_server", new_callable=AsyncMock, return_value=8070): + await server.workbench_scaffold("test-proj", "Test") + state = json.dumps({"template": "

Hello

"}) + result = await server.workbench_state("test-proj", state) + data = json.loads(result) + assert data["ok"] is True + saved = json.loads((tmp_workbench / "test-proj" / "state.json").read_text()) + assert saved["template"] == "

Hello

" + + +@pytest.mark.asyncio +async def test_stop_cleans_up(server, tmp_workbench): + with patch.object(server, "_start_http_server", new_callable=AsyncMock, return_value=8070): + await server.workbench_scaffold("test-proj", "Test") + server._runners["test-proj"] = AsyncMock() + result = await server.workbench_stop("test-proj") + data = json.loads(result) + assert data["ok"] is True + assert not (tmp_workbench / "test-proj" / ".server.json").exists() + + +@pytest.mark.asyncio +async def test_reconnect_on_startup(tmp_workbench): + from workbench.project import create_project, write_server_info + create_project("live-proj", "Live", workbench_dir=tmp_workbench) + write_server_info("live-proj", pid=os.getpid(), port=8070, workbench_dir=tmp_workbench) + create_project("dead-proj", "Dead", workbench_dir=tmp_workbench) + write_server_info("dead-proj", pid=99999, port=8071, workbench_dir=tmp_workbench) + srv = WorkbenchServer(workbench_dir=tmp_workbench) + with patch.object(srv, "_is_server_alive", side_effect=lambda port: port == 8070): + await srv.reconnect_existing_servers() + assert "live-proj" in srv._active + assert srv._active["live-proj"]["port"] == 8070 + assert "dead-proj" not in srv._active + assert not (tmp_workbench / "dead-proj" / ".server.json").exists()