Files
gemma4-research/scripts/native-bakeoff/arms/jax_native.py
T
Mortdecai df5542f7d6 feat: native-bakeoff scaffold — Ollama JSON vs native-token tool-calling
Three-arm harness under scripts/native-bakeoff/:
- arm A: /api/chat with JSON tools (current default)
- arm B: /api/generate raw:true with canonical HF jinja template rendered directly
- arm C: google-deepmind/gemma JAX ToolSampler (env-gated, JAX required)

Interim finding from A+B sweep on matt-strix gemma4:26b Q4: Ollama's
bidirectional JSON↔native tool-call translator is faithful. The "long"
multi-tool task produces identical behavior (7 steps / 6 tools) on both
arms. Earlier arm-B parser bug that looked like a divergence was a
harness issue: preserving the model's <|channel>thought\n<channel|>
prefix as assistant content tripped the jinja template's
tool_response-following conditional, appending a spurious <turn|>\n
that corrupted the next step's prompt. Fixed by dropping the channel
prefix on the assistant message.

Arm C left as scaffolded-but-not-run — the JAX/bf16 reference path
would answer "does the GGUF runtime diverge from DeepMind's
implementation" but requires a separate env with the `gemma` PyPI
package. Parked pending SDXL eviction or vast-h100 session.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-20 05:45:12 -04:00

262 lines
11 KiB
Python

"""Arm C: google-deepmind/gemma JAX ToolSampler (reference path).
This arm runs against the *canonical* JAX reference implementation. No
Ollama, no llama.cpp, no GGUF quantization, no wire protocol — the
chat template, token-level sampling, and tool-call parsing all happen
inside the Python process using the code Google wrote for Gemma 4.
**Environment requirement** — this arm cannot run inside the Ollama-only
environment used by arms A/B. Setup:

    pip install jax[cuda12] gemma   # or jax[cpu] for CPU fallback
    huggingface-cli login           # weights download via HF

It will download `gm.ckpts.CheckpointPath.GEMMA4_E4B_IT` on first run
(~8GB). Run this arm on a host with ≥16GB RAM (CPU) or ≥10GB VRAM (GPU).
**Known caveat** — the `gm.text.ToolSampler` docstring notes that
"Gemma 1, 2 and 3 models were not specifically trained for tool use"
and flags the sampler as a proof-of-concept. Gemma 4 *is* tool-trained
so it should do better here, but if this arm underperforms A/B it may
be the sampler wrapper, not the model. The trace logs the raw sampler
turns so that can be diagnosed post-hoc.
"""
from __future__ import annotations
import os
import time
from typing import Any
# The gemma import is guarded so the harness can still import this
# module on a non-JAX host (e.g. for syntax checking). When the package
# is missing, run() does not raise: it returns an error trace whose
# halt_reason is "env_missing".
# Optional dependency: fall back to a None placeholder when the gemma
# package is not installed, and derive the availability flag from it.
try:
    from gemma import gm  # type: ignore
except ImportError:
    gm = None  # type: ignore
_GEMMA_AVAILABLE = gm is not None
from tasks import SYSTEM_PROMPT, FAKE_HISTORY, TASKS, execute_tool_stub # noqa: F401 (TASKS for parity with A/B)
# -------- Tool wrappers: one gm.tools.Tool subclass per stub --------
#
# ToolSampler requires DESCRIPTION + EXAMPLE for each tool so the model
# sees an in-context example of the calling pattern. The EXAMPLE bodies
# are intentionally short — they're primers, not test cases.
def _build_tools():
    """Build the 8 ToolSampler-compatible tool wrappers.

    Each wrapper is a ``gm.tools.Tool`` subclass defined inside this
    function, so the class bodies (which reference ``gm``) are only
    evaluated when the arm is actually run, not at module import time.
    Every ``call`` method forwards its arguments to ``execute_tool_stub``
    under the corresponding stub name.

    The class names are recorded verbatim in the trace by ``run()``
    (``tools_registered``), so they should not be renamed casually.

    Returns:
        A list with one instance of each tool class, in a fixed order.

    Raises:
        ImportError: if the ``gemma`` package could not be imported
            (``gm`` is ``None``).
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would turn this into an opaque AttributeError.
    if gm is None:
        raise ImportError(
            "gemma package not importable — cannot build ToolSampler tools"
        )

    class WebSearch(gm.tools.Tool):
        DESCRIPTION = "Search the web for current information."
        EXAMPLE = gm.tools.Example(
            query="recent Home Assistant release notes",
            thought="web_search is the right tool for current events / docs.",
            tool_kwargs={"query": "home assistant latest release"},
            tool_kwargs_doc={"query": "<search query string>"},
            result="1. HA 2026.4 released...",
            answer="Home Assistant 2026.4 is the most recent release.",
        )

        def call(self, query: str) -> str:
            return execute_tool_stub("web_search", {"query": query})

    class SethSearch(gm.tools.Tool):
        DESCRIPTION = "Search Seth's homelab (repos, wiki, media). Use source='sethflix' for movies/TV."
        EXAMPLE = gm.tools.Example(
            query="any cyberpunk movies on sethflix?",
            thought="Use source=sethflix to search the movie library.",
            tool_kwargs={"query": "cyberpunk", "source": "sethflix"},
            tool_kwargs_doc={
                "query": "<search query>",
                "source": "<'sethflix' | 'general'>",
                "limit": "<int, default 10>",
            },
            result="Blade Runner 2049, Ghost in the Shell, ...",
            answer="Yes — Blade Runner 2049, Ghost in the Shell, and a few others.",
        )

        def call(self, query: str, source: str = "general", limit: int = 10) -> str:
            return execute_tool_stub("sethsearch", {"query": query, "source": source, "limit": limit})

    class CheckSethflix(gm.tools.Tool):
        DESCRIPTION = "Verify which comma-separated titles are in sethflix."
        EXAMPLE = gm.tools.Example(
            query="is The Matrix in the library?",
            thought="check_sethflix verifies library membership.",
            tool_kwargs={"titles": "The Matrix"},
            tool_kwargs_doc={"titles": "<comma-separated title list>"},
            result="- The Matrix: IN LIBRARY",
            answer="Yes, The Matrix is in the library.",
        )

        def call(self, titles: str) -> str:
            return execute_tool_stub("check_sethflix", {"titles": titles})

    class MemoryRead(gm.tools.Tool):
        DESCRIPTION = "Look up stored facts about a topic or user."
        EXAMPLE = gm.tools.Example(
            query="what do I have about home automation?",
            thought="memory_read is the right tool.",
            tool_kwargs={"query": "home automation"},
            tool_kwargs_doc={"query": "<topic>", "user": "<optional user filter>"},
            result="- home_automation: Seth uses HA on VM 706...",
            answer="You have notes about HA on VM 706 with Zigbee2MQTT.",
        )

        def call(self, query: str, user: str = "") -> str:
            return execute_tool_stub("memory_read", {"query": query, "user": user})

    class MemoryWrite(gm.tools.Tool):
        DESCRIPTION = "Store a durable fact."
        EXAMPLE = gm.tools.Example(
            query="remember that Seth prefers dark themes",
            thought="memory_write stores a key/content pair.",
            tool_kwargs={"key": "theme_preference", "content": "dark with orange accents"},
            tool_kwargs_doc={"key": "<short id>", "content": "<fact body>", "user": "<optional>"},
            result="stored: theme_preference = dark with orange accents",
            answer="Saved.",
        )

        def call(self, key: str, content: str, user: str = "") -> str:
            return execute_tool_stub("memory_write", {"key": key, "content": content, "user": user})

    class WebFetch(gm.tools.Tool):
        DESCRIPTION = "Fetch the text contents of a URL."
        EXAMPLE = gm.tools.Example(
            query="fetch https://example.com/docs",
            thought="web_fetch pulls page text.",
            tool_kwargs={"url": "https://example.com/docs"},
            tool_kwargs_doc={"url": "<absolute URL>"},
            result="fetched content: ...",
            answer="The page discusses X, Y, Z.",
        )

        def call(self, url: str) -> str:
            return execute_tool_stub("web_fetch", {"url": url})

    class ChatSearch(gm.tools.Tool):
        DESCRIPTION = "Search message history across Matrix rooms."
        EXAMPLE = gm.tools.Example(
            query="have we talked about grafana before?",
            thought="chat_search looks through prior messages.",
            tool_kwargs={"query": "grafana"},
            tool_kwargs_doc={"query": "<search query>"},
            result="[2026-03-14] @seth: grafana dashboard...",
            answer="Yes — you discussed a grafana dashboard on March 14.",
        )

        def call(self, query: str) -> str:
            return execute_tool_stub("chat_search", {"query": query})

    class GenerateImage(gm.tools.Tool):
        DESCRIPTION = "Generate an image via SDXL."
        EXAMPLE = gm.tools.Example(
            query="make me a sunset image",
            thought="generate_image dispatches to SDXL.",
            tool_kwargs={"prompt": "dramatic ocean sunset"},
            tool_kwargs_doc={"prompt": "<image description>"},
            result="image generated: /mxc/abc/sunset.png",
            answer="Done — here's the sunset image.",
        )

        def call(self, prompt: str) -> str:
            return execute_tool_stub("generate_image", {"prompt": prompt})

    return [
        WebSearch(), SethSearch(), CheckSethflix(),
        MemoryRead(), MemoryWrite(), WebFetch(),
        ChatSearch(), GenerateImage(),
    ]
async def run(
    *,
    ollama_url: str,  # unused; kept for CLI parity with arms A/B
    model: str,  # unused; arm C loads its own checkpoint
    task_prompt: str,
    num_ctx: int,  # unused; ToolSampler uses its own seq_len
    num_predict: int,
    step_budget: int,
) -> dict[str, Any]:
    """Run one task through the gemma JAX ToolSampler and return a trace.

    Args:
        ollama_url: Unused; accepted for CLI parity with arms A/B.
        model: Unused; this arm always loads ``GEMMA4_E4B_IT``.
        task_prompt: The user request appended after the compacted
            chat history.
        num_ctx: Unused; ToolSampler uses its own sequence length.
        num_predict: Not referenced by this function.
        step_budget: Ignored; ToolSampler manages its own step loop
            (also recorded in the trace as ``step_budget_note``).

    Returns:
        A trace dict. When the gemma package is missing it contains only
        ``arm``, ``error`` and a ``final`` block with
        ``halt_reason == "env_missing"``. Otherwise it contains run
        metadata, one entry per sampler turn under ``turns``, and a
        ``final`` block whose ``halt_reason`` is ``"answer_returned"``,
        ``"no_answer"`` or ``"sampler_error: <message>"``.

    Note:
        Only ``sampler.chat`` is wrapped in the error handler; exceptions
        raised while loading the model, checkpoint or tools propagate to
        the caller.
    """
    if not _GEMMA_AVAILABLE:
        return {
            "arm": "jax-native",
            "error": "gemma package not importable — run in a JAX+gemma env. See module docstring.",
            "final": {"halt_reason": "env_missing", "steps_used": 0, "tool_calls_total": 0, "wall_clock_s": 0},
        }

    # Let JAX use the whole GPU if present (per colab_tool_use.ipynb hint).
    # setdefault keeps any value the operator already exported.
    os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.95")

    # Durations are measured with perf_counter (monotonic) rather than by
    # subtracting time.time() readings, which can jump if the system
    # clock is adjusted mid-run.
    load_t0 = time.perf_counter()
    model_net = gm.nn.Gemma4_E4B()
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA4_E4B_IT)
    tools = _build_tools()
    sampler = gm.text.ToolSampler(
        model=model_net,
        params=params,
        tools=tools,
        print_stream=False,
    )
    load_elapsed_s = round(time.perf_counter() - load_t0, 2)

    # ToolSampler doesn't natively consume a system prompt + pre-populated
    # history. We simulate the same mid-session context by prepending a
    # compact version of FAKE_HISTORY into the user message itself. This is
    # a fidelity compromise documented in the writeup — the A/B arms feed
    # history through proper role-tagged turns. If a delta between arms is
    # traced to this, rebuild the sampler's turn list directly from
    # `sampler.turns` pre-population.
    history_compact = "\n".join(
        f"{m['role'].upper()}: {m['content']}" for m in FAKE_HISTORY[-6:]
    )
    user_msg = (
        f"[prior chat context]\n{history_compact}\n\n"
        f"[2026-04-18 14:20] @seth:sethpc.xyz: {task_prompt}"
    )

    # `started_at` stays an epoch timestamp for the log; `run_t0` is the
    # monotonic reference used for wall_clock_s (model load is excluded).
    run_t0 = time.perf_counter()
    trace: dict[str, Any] = {
        "arm": "jax-native",
        "checkpoint": "GEMMA4_E4B_IT",
        "tools_registered": [t.__class__.__name__ for t in tools],
        "load_elapsed_s": load_elapsed_s,
        "step_budget_note": "ToolSampler manages its own step loop; step_budget ignored",
        "started_at": time.time(),
        "turns": [],
        "final": None,
    }

    try:
        chat_t0 = time.perf_counter()
        answer = sampler.chat(user_msg)
        elapsed = round(time.perf_counter() - chat_t0, 2)
    except Exception as e:
        # Boundary handler: any sampler failure is reported in the trace
        # instead of propagating to the harness.
        trace["final"] = {
            "halt_reason": f"sampler_error: {e}",
            "steps_used": 0,
            "tool_calls_total": 0,
            "wall_clock_s": round(time.perf_counter() - run_t0, 2),
        }
        return trace

    # Extract per-turn info from sampler.turns — the library exposes the
    # full trace (thoughts, tool calls, tool results, final answer).
    sampler_turns = list(getattr(sampler, "turns", []) or [])
    tool_call_total = 0
    for i, t in enumerate(sampler_turns):
        # Different releases of gemma have different turn schemas. We
        # log defensively — whatever attributes the turn object has end
        # up in the JSON so we can inspect post-hoc.
        info: dict[str, Any] = {"step": i + 1, "turn_type": t.__class__.__name__}
        for attr in ("query", "thought", "tool_name", "tool_kwargs", "tool_result", "answer"):
            v = getattr(t, attr, None)
            if v is not None:
                # Plain values are kept as-is; anything else is stringified.
                info[attr] = v if isinstance(v, (str, int, float, bool, list, dict)) else str(v)
        # A turn counts as a tool call when it carries a truthy tool_name.
        if info.get("tool_name"):
            tool_call_total += 1
        trace["turns"].append(info)

    trace["final"] = {
        "halt_reason": "answer_returned" if answer else "no_answer",
        "steps_used": len(sampler_turns),
        "tool_calls_total": tool_call_total,
        "wall_clock_s": round(time.perf_counter() - run_t0, 2),
        "model_answer": answer,
        "sampler_elapsed_s": elapsed,
    }
    return trace