df5542f7d6
Three-arm harness under scripts/native-bakeoff/:
- arm A: /api/chat with JSON tools (current default)
- arm B: /api/generate raw:true with canonical HF jinja template rendered directly
- arm C: google-deepmind/gemma JAX ToolSampler (env-gated, JAX required)

Interim finding from A+B sweep on matt-strix gemma4:26b Q4: Ollama's
bidirectional JSON↔native tool-call translator is faithful. The "long"
multi-tool task produces identical behavior (7 steps / 6 tools) on both
arms.

An earlier arm-B parser bug that looked like a divergence was a harness
issue: preserving the model's <|channel>thought\n<channel|> prefix as
assistant content tripped the jinja template's tool_response-following
conditional, appending a spurious <turn|>\n that corrupted the next
step's prompt. Fixed by dropping the channel prefix on the assistant
message.

Arm C left as scaffolded-but-not-run: the JAX/bf16 reference path would
answer "does the GGUF runtime diverge from DeepMind's implementation?"
but requires a separate env with the `gemma` PyPI package. Parked
pending SDXL eviction or vast-h100 session.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
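The arm-B fix described in the commit message (drop the channel prefix before the assistant turn is re-rendered through the jinja template) can be sketched roughly as follows. This is an illustration under stated assumptions, not the actual arm-B code: `strip_channel_prefix` and `assistant_message` are hypothetical names, the message shape (`role` / `content` / `tool_calls`) assumes HF-style chat messages, and the regex assumes the prefix appears literally at the start of the raw completion. It strips any leading `<|channel>...<channel|>` block, which covers the empty `thought` prefix but is broader than what the commit states.

```python
import re

# Hypothetical helper. Matches a channel block only at the very start of the
# completion; DOTALL lets ".*?" span the newline after "thought".
_LEADING_CHANNEL = re.compile(r"^<\|channel>.*?<channel\|>", re.DOTALL)


def strip_channel_prefix(raw_completion: str) -> str:
    """Remove one leading <|channel>...<channel|> block, if present."""
    return _LEADING_CHANNEL.sub("", raw_completion, count=1)


def assistant_message(raw_completion: str, tool_calls: list[dict]) -> dict:
    """Build the assistant turn that is fed back into the chat template.

    Keeping the channel prefix in "content" is what triggered the spurious
    <turn|> in the harness bug; here the message carries only visible text.
    """
    return {
        "role": "assistant",
        "content": strip_channel_prefix(raw_completion),
        "tool_calls": tool_calls,
    }
```

Anchoring the pattern with `^` (and no `re.MULTILINE`) means a channel marker appearing later in the text is left alone, so only the prefix position the commit describes is affected.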
100 lines
3.4 KiB
Python
"""Native-bakeoff entry point.

Three arms, one invocation. Compares inference paths for Gemma 4:
    A. ollama-json: /api/chat with JSON tools (current default)
    B. ollama-native: /api/generate raw:true with canonical jinja template
    C. jax-native: google-deepmind/gemma reference ToolSampler

Research question: does the inference path materially change behavior,
or is Ollama's JSON tools path faithful to the reference? If arms A and
B diverge, Ollama's tool-call parser is the variable. If B and C diverge,
the cause lies in the llama.cpp runtime, the GGUF quantization, or
Ollama's scheduler.

Arms A and B run against a local Ollama at http://127.0.0.1:11434
by default (override with --ollama-url or $OLLAMA_URL). Arm C needs its
own Python env with JAX + the `gemma` package (see the
`arms/jax_native.py` module docstring).

Usage:
    python3 harness.py --arm ollama-json --task movies --out runs/A/movies.json
    python3 harness.py --arm ollama-native --task movies --out runs/B/movies.json
    python3 harness.py --arm jax-native --task movies --out runs/C/movies.json

    # Each arm defaults to an E4B model (see DEFAULT_MODELS). Override with --model:
    python3 harness.py --arm ollama-json --task movies --model gemma4:26b --out ...
"""

from __future__ import annotations

import argparse
import asyncio
import importlib
import json
import os
import sys
import time
from pathlib import Path

# Make `tasks` and sibling `arms/` importable regardless of where the
# harness is invoked from.
_HERE = Path(__file__).resolve().parent
sys.path.insert(0, str(_HERE))

from tasks import TASKS  # noqa: E402


# Arm name -> module under arms/. Each module exposes an async run(...)
# that returns a trace dict.
ARMS = {
    "ollama-json": "arms.ollama_json",
    "ollama-native": "arms.ollama_native",
    "jax-native": "arms.jax_native",
}


# Per-arm default model: Gemma 4 E4B instruction-tuned on every arm.
DEFAULT_MODELS = {
    "ollama-json": "gemma4:e4b-it-q8_0",
    "ollama-native": "gemma4:e4b-it-q8_0",
    "jax-native": "google-deepmind/gemma:GEMMA4_E4B_IT",
}


async def _main() -> int:
    ap = argparse.ArgumentParser(description="Three-arm native Gemma 4 bakeoff harness.")
    ap.add_argument("--arm", required=True, choices=list(ARMS))
    ap.add_argument("--task", required=True, choices=list(TASKS))
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--model", default=None, help="override default model for this arm")
    ap.add_argument("--ollama-url", default=os.environ.get("OLLAMA_URL", "http://127.0.0.1:11434"))
    ap.add_argument("--num-ctx", type=int, default=8192)
    ap.add_argument("--num-predict", type=int, default=2048)
    ap.add_argument("--step-budget", type=int, default=20)
    args = ap.parse_args()

    # Import only the selected arm, so arm C's JAX dependencies are needed
    # only when arm C is actually run.
    arm_mod = importlib.import_module(ARMS[args.arm])
    model = args.model or DEFAULT_MODELS[args.arm]
    task_prompt = TASKS[args.task]

    trace = await arm_mod.run(
        ollama_url=args.ollama_url,
        model=model,
        task_prompt=task_prompt,
        num_ctx=args.num_ctx,
        num_predict=args.num_predict,
        step_budget=args.step_budget,
    )
    trace.setdefault("task", args.task)
    trace.setdefault("task_prompt", task_prompt)

    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(json.dumps(trace, indent=2, default=str))

    # One-line summary from the arm's "final" block; "?" marks a missing field.
    f = trace.get("final") or {}
    print(
        f"arm={args.arm:14s} task={args.task:8s} "
        f"steps={f.get('steps_used', '?')} tools={f.get('tool_calls_total', '?')} "
        f"halt={f.get('halt_reason', '?')} wall={f.get('wall_clock_s', '?')}s"
    )
    return 0


if __name__ == "__main__":
    sys.exit(asyncio.run(_main()))
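The harness forwards `--num-ctx` and `--num-predict` to every arm. For the two Ollama arms, one plausible mapping onto Ollama's REST API is sketched below: arm A posts messages plus a `tools` list to `/api/chat`, and arm B posts an already-rendered prompt to `/api/generate` with `raw: true`, which tells Ollama not to apply its own template. The helper names (`chat_payload`, `generate_payload`, `post_json`) are hypothetical; the real `arms/ollama_json.py` and `arms/ollama_native.py` may build their requests differently. `urllib` is used only to keep the sketch dependency-free.

```python
import json
import urllib.request


def _options(num_ctx: int, num_predict: int) -> dict:
    # Ollama reads context size and max new tokens from the "options" object.
    return {"num_ctx": num_ctx, "num_predict": num_predict}


def chat_payload(model: str, messages: list[dict], tools: list[dict],
                 num_ctx: int, num_predict: int) -> dict:
    """Arm A style body for POST /api/chat: Ollama renders the template and
    translates JSON tools to and from the model's native format."""
    return {
        "model": model,
        "messages": messages,
        "tools": tools,
        "stream": False,
        "options": _options(num_ctx, num_predict),
    }


def generate_payload(model: str, prompt: str, num_ctx: int, num_predict: int) -> dict:
    """Arm B style body for POST /api/generate: `prompt` is already rendered
    with the canonical jinja template, and raw=True skips Ollama's template."""
    return {
        "model": model,
        "prompt": prompt,
        "raw": True,
        "stream": False,
        "options": _options(num_ctx, num_predict),
    }


def post_json(url: str, body: dict, timeout: float = 600.0) -> dict:
    """Blocking POST of a JSON body; returns the decoded JSON reply."""
    req = urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())
```

Because the harness awaits `run()`, a blocking call like `post_json` would need wrapping inside an arm, for example `await asyncio.to_thread(post_json, f"{ollama_url}/api/generate", body)`. A non-streaming `/api/generate` reply carries the completion text in its `response` field; `/api/chat` returns a `message` object instead.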
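Each run writes a JSON trace whose `final` block holds the fields the harness prints. The docstring's attribution rule (an A/B divergence points at Ollama's tool-call parser, a B/C divergence at the runtime, quantization, or scheduler) can be applied mechanically to those files, roughly as sketched here. The sketch compares only `steps_used`, `tool_calls_total`, and `halt_reason`, ignores `wall_clock_s`, and assumes nothing about the per-step trace schema, which each arm defines. Matching counts are necessary but not sufficient evidence of identical behavior. All names are hypothetical.

```python
import json
from pathlib import Path

# Behavioral fields from trace["final"]; wall_clock_s is left out on purpose
# because timing differs across runtimes even when behavior matches.
BEHAVIOR_KEYS = ("steps_used", "tool_calls_total", "halt_reason")

# Attribution rule from the harness docstring, keyed by (arm, arm) pair.
SUSPECTS = {
    ("ollama-json", "ollama-native"): "Ollama's tool-call parser",
    ("ollama-native", "jax-native"): "llama.cpp runtime / GGUF quantization / Ollama's scheduler",
}


def load_final(path: Path) -> dict:
    """Read a trace written by harness.py and return its "final" block."""
    return json.loads(Path(path).read_text()).get("final") or {}


def diff_finals(a: dict, b: dict) -> dict:
    """Map each behavioral key that differs to its (a, b) values."""
    return {k: (a.get(k), b.get(k)) for k in BEHAVIOR_KEYS if a.get(k) != b.get(k)}


def compare(arm_a: str, final_a: dict, arm_b: str, final_b: dict) -> str:
    """Summarize divergence between two arms and name the docstring's suspect."""
    diffs = diff_finals(final_a, final_b)
    if not diffs:
        return f"{arm_a} vs {arm_b}: no divergence in {', '.join(BEHAVIOR_KEYS)}"
    suspect = SUSPECTS.get((arm_a, arm_b), "unattributed")
    return f"{arm_a} vs {arm_b}: diverge on {diffs}; suspect: {suspect}"
```

With the output paths from the usage block, a check of arms A and B would be `compare("ollama-json", load_final("runs/A/movies.json"), "ollama-native", load_final("runs/B/movies.json"))`. Pairs the docstring does not cover (A against C) are reported as unattributed rather than guessed.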