df5542f7d6
Three-arm harness under scripts/native-bakeoff/: - arm A: /api/chat with JSON tools (current default) - arm B: /api/generate raw:true with canonical HF jinja template rendered directly - arm C: google-deepmind/gemma JAX ToolSampler (env-gated, JAX required) Interim finding from A+B sweep on matt-strix gemma4:26b Q4: Ollama's bidirectional JSON↔native tool-call translator is faithful. The "long" multi-tool task produces identical behavior (7 steps / 6 tools) on both arms. Earlier arm-B parser bug that looked like a divergence was a harness issue: preserving the model's <|channel>thought\n<channel|> prefix as assistant content tripped the jinja template's tool_response-following conditional, appending a spurious <turn|>\n that corrupted the next step's prompt. Fixed by dropping the channel prefix on the assistant message. Arm C left as scaffolded-but-not-run — the JAX/bf16 reference path would answer "does the GGUF runtime diverge from DeepMind's implementation" but requires a separate env with the `gemma` PyPI package. Parked pending SDXL eviction or vast-h100 session. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
262 lines
11 KiB
Python
262 lines
11 KiB
Python
"""Arm C: google-deepmind/gemma JAX ToolSampler (reference path).
|
|
|
|
This arm runs against the *canonical* JAX reference implementation. No
|
|
Ollama, no llama.cpp, no GGUF quantization, no wire protocol — the
|
|
chat template, token-level sampling, and tool-call parsing all happen
|
|
inside the Python process using the code Google wrote for Gemma 4.
|
|
|
|
**Environment requirement** — this arm cannot run inside the Ollama-only
|
|
environment used by arms A/B. Setup:
|
|
|
|
pip install jax[cuda12] gemma # or jax[cpu] for CPU fallback
|
|
huggingface-cli login # weights download via HF
|
|
|
|
It will download `gm.ckpts.CheckpointPath.GEMMA4_E4B_IT` on first run
|
|
(~8GB). Run this arm on a host with ≥16GB RAM (CPU) or ≥10GB VRAM (GPU).
|
|
|
|
**Known caveat** — the `gm.text.ToolSampler` docstring notes that
|
|
"Gemma 1, 2 and 3 models were not specifically trained for tool use"
|
|
and flags the sampler as a proof-of-concept. Gemma 4 *is* tool-trained
|
|
so it should do better here, but if this arm underperforms A/B it may
|
|
be the sampler wrapper, not the model. The trace logs the raw sampler
|
|
turns so that can be diagnosed post-hoc.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import time
|
|
from typing import Any
|
|
|
|
# The third-party gemma import is guarded so the harness can at least
# import this module on a non-JAX host for syntax checking. If the env
# isn't set up, run() returns an "env_missing" trace instead of raising.
|
|
try:
|
|
from gemma import gm # type: ignore
|
|
_GEMMA_AVAILABLE = True
|
|
except ImportError:
|
|
gm = None # type: ignore
|
|
_GEMMA_AVAILABLE = False
|
|
|
|
from tasks import SYSTEM_PROMPT, FAKE_HISTORY, TASKS, execute_tool_stub # noqa: F401 (TASKS for parity with A/B)
|
|
|
|
|
|
# -------- Tool wrappers: one gm.tools.Tool subclass per stub --------
|
|
#
|
|
# ToolSampler requires DESCRIPTION + EXAMPLE for each tool so the model
|
|
# sees an in-context example of the calling pattern. The EXAMPLE bodies
|
|
# are intentionally short — they're primers, not test cases.
|
|
|
|
def _build_tools():
    """Build the 8 ToolSampler-compatible tool wrappers.

    Each wrapper is a ``gm.tools.Tool`` subclass with a ``DESCRIPTION`` and an
    in-context ``EXAMPLE``. Its ``call`` forwards the keyword arguments to
    ``execute_tool_stub`` under the stub's snake_case tool name.

    The classes are defined inside this function (rather than at module
    level) because their base class, ``gm.tools.Tool``, only exists when the
    guarded ``gemma`` import at the top of the module succeeded. Defining
    them at module level would break importing this file on non-JAX hosts.

    Returns:
        A list with one instance of each tool, in a fixed order.

    Raises:
        ImportError: if the ``gemma`` package could not be imported.
    """
    if gm is None:
        # Explicit raise rather than `assert`: asserts are stripped under -O
        # and would turn this into an AttributeError further down.
        raise ImportError(
            "gemma package not importable — _build_tools() requires a JAX+gemma env."
        )

    class WebSearch(gm.tools.Tool):
        DESCRIPTION = "Search the web for current information."
        EXAMPLE = gm.tools.Example(
            query="recent Home Assistant release notes",
            thought="web_search is the right tool for current events / docs.",
            tool_kwargs={"query": "home assistant latest release"},
            tool_kwargs_doc={"query": "<search query string>"},
            result="1. HA 2026.4 released...",
            answer="Home Assistant 2026.4 is the most recent release.",
        )

        def call(self, query: str) -> str:
            return execute_tool_stub("web_search", {"query": query})

    class SethSearch(gm.tools.Tool):
        DESCRIPTION = "Search Seth's homelab (repos, wiki, media). Use source='sethflix' for movies/TV."
        EXAMPLE = gm.tools.Example(
            query="any cyberpunk movies on sethflix?",
            thought="Use source=sethflix to search the movie library.",
            tool_kwargs={"query": "cyberpunk", "source": "sethflix"},
            tool_kwargs_doc={
                "query": "<search query>",
                "source": "<'sethflix' | 'general'>",
                "limit": "<int, default 10>",
            },
            result="Blade Runner 2049, Ghost in the Shell, ...",
            answer="Yes — Blade Runner 2049, Ghost in the Shell, and a few others.",
        )

        def call(self, query: str, source: str = "general", limit: int = 10) -> str:
            return execute_tool_stub("sethsearch", {"query": query, "source": source, "limit": limit})

    class CheckSethflix(gm.tools.Tool):
        DESCRIPTION = "Verify which comma-separated titles are in sethflix."
        EXAMPLE = gm.tools.Example(
            query="is The Matrix in the library?",
            thought="check_sethflix verifies library membership.",
            tool_kwargs={"titles": "The Matrix"},
            tool_kwargs_doc={"titles": "<comma-separated title list>"},
            result="- The Matrix: IN LIBRARY",
            answer="Yes, The Matrix is in the library.",
        )

        def call(self, titles: str) -> str:
            return execute_tool_stub("check_sethflix", {"titles": titles})

    class MemoryRead(gm.tools.Tool):
        DESCRIPTION = "Look up stored facts about a topic or user."
        EXAMPLE = gm.tools.Example(
            query="what do I have about home automation?",
            thought="memory_read is the right tool.",
            tool_kwargs={"query": "home automation"},
            tool_kwargs_doc={"query": "<topic>", "user": "<optional user filter>"},
            result="- home_automation: Seth uses HA on VM 706...",
            answer="You have notes about HA on VM 706 with Zigbee2MQTT.",
        )

        def call(self, query: str, user: str = "") -> str:
            return execute_tool_stub("memory_read", {"query": query, "user": user})

    class MemoryWrite(gm.tools.Tool):
        DESCRIPTION = "Store a durable fact."
        EXAMPLE = gm.tools.Example(
            query="remember that Seth prefers dark themes",
            thought="memory_write stores a key/content pair.",
            tool_kwargs={"key": "theme_preference", "content": "dark with orange accents"},
            tool_kwargs_doc={"key": "<short id>", "content": "<fact body>", "user": "<optional>"},
            result="stored: theme_preference = dark with orange accents",
            answer="Saved.",
        )

        def call(self, key: str, content: str, user: str = "") -> str:
            return execute_tool_stub("memory_write", {"key": key, "content": content, "user": user})

    class WebFetch(gm.tools.Tool):
        DESCRIPTION = "Fetch the text contents of a URL."
        EXAMPLE = gm.tools.Example(
            query="fetch https://example.com/docs",
            thought="web_fetch pulls page text.",
            tool_kwargs={"url": "https://example.com/docs"},
            tool_kwargs_doc={"url": "<absolute URL>"},
            result="fetched content: ...",
            answer="The page discusses X, Y, Z.",
        )

        def call(self, url: str) -> str:
            return execute_tool_stub("web_fetch", {"url": url})

    class ChatSearch(gm.tools.Tool):
        DESCRIPTION = "Search message history across Matrix rooms."
        EXAMPLE = gm.tools.Example(
            query="have we talked about grafana before?",
            thought="chat_search looks through prior messages.",
            tool_kwargs={"query": "grafana"},
            tool_kwargs_doc={"query": "<search query>"},
            result="[2026-03-14] @seth: grafana dashboard...",
            answer="Yes — you discussed a grafana dashboard on March 14.",
        )

        def call(self, query: str) -> str:
            return execute_tool_stub("chat_search", {"query": query})

    class GenerateImage(gm.tools.Tool):
        DESCRIPTION = "Generate an image via SDXL."
        EXAMPLE = gm.tools.Example(
            query="make me a sunset image",
            thought="generate_image dispatches to SDXL.",
            tool_kwargs={"prompt": "dramatic ocean sunset"},
            tool_kwargs_doc={"prompt": "<image description>"},
            result="image generated: /mxc/abc/sunset.png",
            answer="Done — here's the sunset image.",
        )

        def call(self, prompt: str) -> str:
            return execute_tool_stub("generate_image", {"prompt": prompt})

    return [
        WebSearch(), SethSearch(), CheckSethflix(),
        MemoryRead(), MemoryWrite(), WebFetch(),
        ChatSearch(), GenerateImage(),
    ]
|
|
|
|
|
|
# Attribute names probed on each sampler turn and copied into the trace when
# present. The tool-specific names come first; the generic role/content/text
# names are probed too because the turn schema differs across gemma releases
# (NOTE(review): confirm which ones the installed version actually exposes).
_TURN_ATTRS = (
    "query", "thought", "tool_name", "tool_kwargs", "tool_result", "answer",
    "role", "content", "text",
)


def _summarize_turns(sampler: Any) -> tuple[list[dict[str, Any]], int]:
    """Convert ``sampler.turns`` into trace dicts and count tool-call turns.

    Attributes from ``_TURN_ATTRS`` that are missing or ``None`` are skipped.
    Values that are not str/int/float/bool/list/dict are stringified.

    Returns:
        ``(turns, tool_calls)``. ``turns`` holds one dict per sampler turn,
        with ``step`` numbered from 1. ``tool_calls`` is the number of turns
        that have a truthy ``tool_name``.
    """
    raw_turns = list(getattr(sampler, "turns", []) or [])
    summaries: list[dict[str, Any]] = []
    tool_calls = 0
    for step, turn in enumerate(raw_turns, start=1):
        info: dict[str, Any] = {"step": step, "turn_type": turn.__class__.__name__}
        for attr in _TURN_ATTRS:
            value = getattr(turn, attr, None)
            if value is not None:
                info[attr] = value if isinstance(value, (str, int, float, bool, list, dict)) else str(value)
        if info.get("tool_name"):
            tool_calls += 1
        summaries.append(info)
    return summaries, tool_calls


def _setup_failure_trace(error: str, halt_reason: str, wall_clock_s: float) -> dict[str, Any]:
    """Build the trace returned when the arm cannot start (no sampling ran)."""
    return {
        "arm": "jax-native",
        "error": error,
        "turns": [],
        "final": {"halt_reason": halt_reason, "steps_used": 0, "tool_calls_total": 0, "wall_clock_s": wall_clock_s},
    }


async def run(
    *,
    ollama_url: str,  # unused; kept for CLI parity with arms A/B
    model: str,  # unused; arm C loads its own checkpoint
    task_prompt: str,
    num_ctx: int,  # unused; ToolSampler uses its own seq_len
    num_predict: int,  # unused; not forwarded to ToolSampler
    step_budget: int,  # unused; ToolSampler runs its own step loop
) -> dict[str, Any]:
    """Run ``task_prompt`` through the gemma JAX ToolSampler and return a trace.

    The keyword-only parameters mirror arms A/B so the CLI can call every arm
    the same way. ``ollama_url``, ``model``, ``num_ctx``, ``num_predict`` and
    ``step_budget`` are accepted but not used by this arm.

    The returned dict always has ``arm``, ``turns`` and ``final`` keys.
    ``final["halt_reason"]`` is one of:

    - ``env_missing``
    - ``load_error: ...``
    - ``sampler_error: ...``
    - ``answer_returned``
    - ``no_answer``

    Failures (missing package, model/checkpoint load failure, sampler
    exception) are reported in the trace rather than raised.

    NOTE(review): despite being ``async``, model loading and ``sampler.chat``
    run synchronously and block the event loop for their whole duration.
    That is fine if the harness awaits arms one at a time; confirm before
    running arms concurrently.
    """
    if not _GEMMA_AVAILABLE:
        return _setup_failure_trace(
            "gemma package not importable — run in a JAX+gemma env. See module docstring.",
            "env_missing",
            0,
        )

    # Let JAX use the whole GPU if present (per colab_tool_use.ipynb hint).
    # NOTE(review): XLA presumably reads this only when the JAX backend first
    # initializes. jax is already imported (via gemma) at module load, so this
    # relies on no device having been touched yet — confirm.
    os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.95")

    t_load_start = time.time()
    try:
        model_net = gm.nn.Gemma4_E4B()
        params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA4_E4B_IT)
        tools = _build_tools()
        sampler = gm.text.ToolSampler(
            model=model_net,
            params=params,
            tools=tools,
            print_stream=False,
        )
    except Exception as e:
        # Checkpoint download, HF auth and memory failures are environment
        # problems, just like a missing package. Report them in the same trace
        # shape instead of aborting the caller's whole sweep.
        return _setup_failure_trace(
            f"{type(e).__name__}: {e}",
            f"load_error: {e}",
            round(time.time() - t_load_start, 2),
        )
    load_elapsed_s = round(time.time() - t_load_start, 2)

    # ToolSampler doesn't natively consume a system prompt + pre-populated
    # history. We simulate the same mid-session context by prepending a
    # compact version of FAKE_HISTORY into the user message itself. This is
    # a fidelity compromise documented in the writeup — the A/B arms feed
    # history through proper role-tagged turns. If a delta between arms is
    # traced to this, rebuild the sampler's turn list directly from
    # `sampler.turns` pre-population.
    history_compact = "\n".join(
        f"{m['role'].upper()}: {m['content']}" for m in FAKE_HISTORY[-6:]
    )
    user_msg = (
        f"[prior chat context]\n{history_compact}\n\n"
        f"[2026-04-18 14:20] @seth:sethpc.xyz: {task_prompt}"
    )

    trace: dict[str, Any] = {
        "arm": "jax-native",
        "checkpoint": "GEMMA4_E4B_IT",
        "tools_registered": [t.__class__.__name__ for t in tools],
        "load_elapsed_s": load_elapsed_s,
        "step_budget_note": "ToolSampler manages its own step loop; step_budget ignored",
        "num_predict_note": "not forwarded to ToolSampler; num_predict ignored",
        "started_at": time.time(),
        "turns": [],
        "final": None,
    }

    t0 = time.time()
    try:
        answer = sampler.chat(user_msg)
    except Exception as e:
        # Keep whatever turns the sampler accumulated before failing: a
        # mid-loop failure is exactly what the per-turn log is for.
        trace["turns"], tool_calls_total = _summarize_turns(sampler)
        trace["error"] = f"{type(e).__name__}: {e}"
        trace["final"] = {
            "halt_reason": f"sampler_error: {e}",
            "steps_used": len(trace["turns"]),
            "tool_calls_total": tool_calls_total,
            "wall_clock_s": round(time.time() - trace["started_at"], 2),
            "sampler_elapsed_s": round(time.time() - t0, 2),
        }
        return trace
    sampler_elapsed_s = round(time.time() - t0, 2)

    trace["turns"], tool_calls_total = _summarize_turns(sampler)
    trace["final"] = {
        "halt_reason": "answer_returned" if answer else "no_answer",
        "steps_used": len(trace["turns"]),
        "tool_calls_total": tool_calls_total,
        "wall_clock_s": round(time.time() - trace["started_at"], 2),
        "model_answer": answer,
        "sampler_elapsed_s": sampler_elapsed_s,
    }
    return trace
|