"""Arm C: google-deepmind/gemma JAX ToolSampler (reference path).

This arm runs against the *canonical* JAX reference implementation. No
Ollama, no llama.cpp, no GGUF quantization, no wire protocol — the chat
template, token-level sampling, and tool-call parsing all happen inside the
Python process using the code Google wrote for Gemma 4.

**Environment requirement** — this arm cannot run inside the Ollama-only
environment used by arms A/B. Setup:

    pip install jax[cuda12] gemma   # or jax[cpu] for CPU fallback
    huggingface-cli login           # weights download via HF

It will download `gm.ckpts.CheckpointPath.GEMMA4_E4B_IT` on first run
(~8GB). Run this arm on a host with ≥16GB RAM (CPU) or ≥10GB VRAM (GPU).

**Known caveat** — the `gm.text.ToolSampler` docstring notes that "Gemma 1,
2 and 3 models were not specifically trained for tool use" and flags the
sampler as a proof-of-concept. Gemma 4 *is* tool-trained so it should do
better here, but if this arm underperforms A/B it may be the sampler
wrapper, not the model. The trace logs the raw sampler turns so that can be
diagnosed post-hoc.
"""

from __future__ import annotations

import os
import time
from typing import Any

# Local imports are guarded so the harness can at least import this
# module on a non-JAX host for syntax checking. When the import fails,
# run() returns an "env_missing" trace (it does not raise), and
# _build_tools() raises ImportError.
try:
    from gemma import gm  # type: ignore
    _GEMMA_AVAILABLE = True
except ImportError:
    gm = None  # type: ignore
    _GEMMA_AVAILABLE = False

from tasks import SYSTEM_PROMPT, FAKE_HISTORY, TASKS, execute_tool_stub  # noqa: F401 (TASKS for parity with A/B)

# Attributes probed on every sampler turn object. Whichever of these a turn
# exposes (and is not None) is copied into the trace.
_TURN_ATTRS = ("query", "thought", "tool_name", "tool_kwargs", "tool_result", "answer")


# -------- Tool wrappers: one gm.tools.Tool subclass per stub -------- #
# ToolSampler requires DESCRIPTION + EXAMPLE for each tool so the model
# sees an in-context example of the calling pattern. The EXAMPLE bodies
# are intentionally short — they're primers, not test cases.
#
# NOTE(review): most `tool_kwargs_doc` values below are empty strings while
# `source` carries a "<...>" placeholder. Possibly the other placeholders
# were lost somewhere upstream — confirm against the intended prompts
# before relying on them.

def _build_tools():
    """Build the 8 ToolSampler-compatible wrappers.

    The wrapper classes subclass `gm.tools.Tool`, so they are defined inside
    this function rather than at module level; that keeps the module
    importable when `gemma` is missing (`gm` is None in that case).

    Returns:
        A list with one instance of each wrapper, in registration order.
        Each wrapper's `call` forwards its arguments to `execute_tool_stub`
        under the stub's tool name.

    Raises:
        ImportError: if the `gemma` package could not be imported.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would turn this into an opaque AttributeError on
    # `None.tools` a few lines further down.
    if gm is None:
        raise ImportError(
            "gemma package not importable — run in a JAX+gemma env. See module docstring."
        )

    class WebSearch(gm.tools.Tool):
        DESCRIPTION = "Search the web for current information."
        EXAMPLE = gm.tools.Example(
            query="recent Home Assistant release notes",
            thought="web_search is the right tool for current events / docs.",
            tool_kwargs={"query": "home assistant latest release"},
            tool_kwargs_doc={"query": ""},
            result="1. HA 2026.4 released...",
            answer="Home Assistant 2026.4 is the most recent release.",
        )

        def call(self, query: str) -> str:
            return execute_tool_stub("web_search", {"query": query})

    class SethSearch(gm.tools.Tool):
        DESCRIPTION = "Search Seth's homelab (repos, wiki, media). Use source='sethflix' for movies/TV."
        EXAMPLE = gm.tools.Example(
            query="any cyberpunk movies on sethflix?",
            thought="Use source=sethflix to search the movie library.",
            tool_kwargs={"query": "cyberpunk", "source": "sethflix"},
            tool_kwargs_doc={
                "query": "",
                "source": "<'sethflix' | 'general'>",
                "limit": "",
            },
            result="Blade Runner 2049, Ghost in the Shell, ...",
            answer="Yes — Blade Runner 2049, Ghost in the Shell, and a few others.",
        )

        def call(self, query: str, source: str = "general", limit: int = 10) -> str:
            return execute_tool_stub(
                "sethsearch", {"query": query, "source": source, "limit": limit}
            )

    class CheckSethflix(gm.tools.Tool):
        DESCRIPTION = "Verify which comma-separated titles are in sethflix."
        EXAMPLE = gm.tools.Example(
            query="is The Matrix in the library?",
            thought="check_sethflix verifies library membership.",
            tool_kwargs={"titles": "The Matrix"},
            tool_kwargs_doc={"titles": ""},
            result="- The Matrix: IN LIBRARY",
            answer="Yes, The Matrix is in the library.",
        )

        def call(self, titles: str) -> str:
            return execute_tool_stub("check_sethflix", {"titles": titles})

    class MemoryRead(gm.tools.Tool):
        DESCRIPTION = "Look up stored facts about a topic or user."
        EXAMPLE = gm.tools.Example(
            query="what do I have about home automation?",
            thought="memory_read is the right tool.",
            tool_kwargs={"query": "home automation"},
            tool_kwargs_doc={"query": "", "user": ""},
            result="- home_automation: Seth uses HA on VM 706...",
            answer="You have notes about HA on VM 706 with Zigbee2MQTT.",
        )

        def call(self, query: str, user: str = "") -> str:
            return execute_tool_stub("memory_read", {"query": query, "user": user})

    class MemoryWrite(gm.tools.Tool):
        DESCRIPTION = "Store a durable fact."
        EXAMPLE = gm.tools.Example(
            query="remember that Seth prefers dark themes",
            thought="memory_write stores a key/content pair.",
            tool_kwargs={"key": "theme_preference", "content": "dark with orange accents"},
            tool_kwargs_doc={"key": "", "content": "", "user": ""},
            result="stored: theme_preference = dark with orange accents",
            answer="Saved.",
        )

        def call(self, key: str, content: str, user: str = "") -> str:
            return execute_tool_stub(
                "memory_write", {"key": key, "content": content, "user": user}
            )

    class WebFetch(gm.tools.Tool):
        DESCRIPTION = "Fetch the text contents of a URL."
        EXAMPLE = gm.tools.Example(
            query="fetch https://example.com/docs",
            thought="web_fetch pulls page text.",
            tool_kwargs={"url": "https://example.com/docs"},
            tool_kwargs_doc={"url": ""},
            result="fetched content: ...",
            answer="The page discusses X, Y, Z.",
        )

        def call(self, url: str) -> str:
            return execute_tool_stub("web_fetch", {"url": url})

    class ChatSearch(gm.tools.Tool):
        DESCRIPTION = "Search message history across Matrix rooms."
        EXAMPLE = gm.tools.Example(
            query="have we talked about grafana before?",
            thought="chat_search looks through prior messages.",
            tool_kwargs={"query": "grafana"},
            tool_kwargs_doc={"query": ""},
            result="[2026-03-14] @seth: grafana dashboard...",
            answer="Yes — you discussed a grafana dashboard on March 14.",
        )

        def call(self, query: str) -> str:
            return execute_tool_stub("chat_search", {"query": query})

    class GenerateImage(gm.tools.Tool):
        DESCRIPTION = "Generate an image via SDXL."
        EXAMPLE = gm.tools.Example(
            query="make me a sunset image",
            thought="generate_image dispatches to SDXL.",
            tool_kwargs={"prompt": "dramatic ocean sunset"},
            tool_kwargs_doc={"prompt": ""},
            result="image generated: /mxc/abc/sunset.png",
            answer="Done — here's the sunset image.",
        )

        def call(self, prompt: str) -> str:
            return execute_tool_stub("generate_image", {"prompt": prompt})

    return [
        WebSearch(),
        SethSearch(),
        CheckSethflix(),
        MemoryRead(),
        MemoryWrite(),
        WebFetch(),
        ChatSearch(),
        GenerateImage(),
    ]


def _collect_turns(sampler: Any) -> tuple[list[dict[str, Any]], int]:
    """Snapshot `sampler.turns` into plain dicts for the trace.

    Different releases of gemma have different turn schemas, so this logs
    defensively: each name in `_TURN_ATTRS` that a turn object exposes with
    a non-None value is copied over. Values that are not str / int / float /
    bool / list / dict are stringified so they can be inspected post-hoc.

    Args:
        sampler: Object whose optional `turns` attribute is iterated. A
            missing, None, or empty `turns` yields no entries.

    Returns:
        `(turn_infos, tool_call_total)`: the per-turn dicts (1-based
        `step`, `turn_type`, plus any probed attributes) and the number of
        turns carrying a truthy `tool_name`.
    """
    turn_infos: list[dict[str, Any]] = []
    tool_call_total = 0
    for step, turn in enumerate(getattr(sampler, "turns", []) or [], start=1):
        info: dict[str, Any] = {"step": step, "turn_type": turn.__class__.__name__}
        for attr in _TURN_ATTRS:
            value = getattr(turn, attr, None)
            if value is None:
                continue
            if isinstance(value, (str, int, float, bool, list, dict)):
                info[attr] = value
            else:
                info[attr] = str(value)
        if info.get("tool_name"):
            tool_call_total += 1
        turn_infos.append(info)
    return turn_infos, tool_call_total


async def run(
    *,
    ollama_url: str,  # unused; kept for CLI parity with arms A/B
    model: str,  # unused; arm C loads its own checkpoint
    task_prompt: str,
    num_ctx: int,  # unused; ToolSampler uses its own seq_len
    num_predict: int,
    step_budget: int,
) -> dict[str, Any]:
    """Run one task prompt through the gemma ToolSampler and return a trace.

    Args:
        ollama_url: Unused; kept for CLI parity with arms A/B.
        model: Unused; this arm loads its own checkpoint.
        task_prompt: The user request, appended after the compacted history.
        num_ctx: Unused; ToolSampler uses its own seq_len.
        num_predict: Not read anywhere in this function.
        step_budget: Ignored (recorded as such in `step_budget_note`).

    Returns:
        A trace dict. If `gemma` is not importable it holds only `arm`,
        `error` and a `final` block with `halt_reason="env_missing"`.
        Otherwise it holds load timing, registered tool class names, the
        serialized sampler turns and a `final` summary. A sampler failure is
        reported as `halt_reason="sampler_error: ..."` together with
        whatever turns the sampler had recorded up to that point.

    Exceptions raised while constructing the model, loading parameters or
    building the sampler are not caught here and propagate to the caller.

    NOTE(review): `sampler.chat` is a synchronous call inside an `async def`,
    so the event loop is blocked for the whole generation. Presumably fine
    if the harness runs one arm at a time — confirm against the caller.
    """
    if not _GEMMA_AVAILABLE:
        return {
            "arm": "jax-native",
            "error": "gemma package not importable — run in a JAX+gemma env. See module docstring.",
            "final": {
                "halt_reason": "env_missing",
                "steps_used": 0,
                "tool_calls_total": 0,
                "wall_clock_s": 0,
            },
        }

    # Let JAX use the whole GPU if present (per colab_tool_use.ipynb hint).
    # setdefault keeps any value the operator already exported.
    os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.95")

    t_load_start = time.time()
    model_net = gm.nn.Gemma4_E4B()
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA4_E4B_IT)
    tools = _build_tools()
    sampler = gm.text.ToolSampler(
        model=model_net,
        params=params,
        tools=tools,
        print_stream=False,
    )
    load_elapsed_s = round(time.time() - t_load_start, 2)

    # ToolSampler doesn't natively consume a system prompt + pre-populated
    # history. We simulate the same mid-session context by prepending a
    # compact version of FAKE_HISTORY into the user message itself. This is
    # a fidelity compromise documented in the writeup — the A/B arms feed
    # history through proper role-tagged turns. If a delta between arms is
    # traced to this, rebuild the sampler's turn list directly from
    # `sampler.turns` pre-population.
    history_compact = "\n".join(
        f"{m['role'].upper()}: {m['content']}" for m in FAKE_HISTORY[-6:]
    )
    user_msg = (
        f"[prior chat context]\n{history_compact}\n\n"
        f"[2026-04-18 14:20] @seth:sethpc.xyz: {task_prompt}"
    )

    trace: dict[str, Any] = {
        "arm": "jax-native",
        "checkpoint": "GEMMA4_E4B_IT",
        "tools_registered": [t.__class__.__name__ for t in tools],
        "load_elapsed_s": load_elapsed_s,
        "step_budget_note": "ToolSampler manages its own step loop; step_budget ignored",
        "started_at": time.time(),
        "turns": [],
        "final": None,
    }

    try:
        t0 = time.time()
        answer = sampler.chat(user_msg)
        elapsed = round(time.time() - t0, 2)
    except Exception as e:
        # Keep whatever the sampler recorded before it failed: the partial
        # turns are the main evidence for diagnosing the failure. The
        # snapshot is best-effort so that a second failure while reading a
        # half-updated sampler cannot mask the original error.
        try:
            partial_turns, partial_tool_calls = _collect_turns(sampler)
        except Exception:
            partial_turns, partial_tool_calls = [], 0
        trace["turns"] = partial_turns
        trace["final"] = {
            "halt_reason": f"sampler_error: {e}",
            "steps_used": len(partial_turns),
            "tool_calls_total": partial_tool_calls,
            "wall_clock_s": round(time.time() - trace["started_at"], 2),
        }
        return trace

    # Extract per-turn info from sampler.turns — the library exposes the
    # full trace (thoughts, tool calls, tool results, final answer).
    turn_infos, tool_call_total = _collect_turns(sampler)
    trace["turns"] = turn_infos

    trace["final"] = {
        "halt_reason": "answer_returned" if answer else "no_answer",
        "steps_used": len(turn_infos),
        "tool_calls_total": tool_call_total,
        "wall_clock_s": round(time.time() - trace["started_at"], 2),
        "model_answer": answer,
        "sampler_elapsed_s": elapsed,
    }
    return trace