feat: native-bakeoff scaffold — Ollama JSON vs native-token tool-calling
Three-arm harness under scripts/native-bakeoff/:
- arm A: /api/chat with JSON tools (current default)
- arm B: /api/generate raw:true with canonical HF jinja template rendered directly
- arm C: google-deepmind/gemma JAX ToolSampler (env-gated, JAX required)

Interim finding from the A+B sweep on matt-strix gemma4:26b Q4: Ollama's bidirectional JSON↔native tool-call translator is faithful. The "long" multi-tool task produces identical behavior (7 steps / 6 tools) on both arms.

An earlier arm-B parser bug that looked like a divergence was a harness issue: preserving the model's <|channel>thought\n<channel|> prefix as assistant content tripped the jinja template's tool_response-following conditional, appending a spurious <turn|>\n that corrupted the next step's prompt. Fixed by dropping the channel prefix on the assistant message.

Arm C is left scaffolded but not run — the JAX/bf16 reference path would answer "does the GGUF runtime diverge from DeepMind's implementation?" but requires a separate env with the `gemma` PyPI package. Parked pending SDXL eviction or a vast-h100 session.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,99 @@
|
||||
"""Native-bakeoff entry point.
|
||||
|
||||
Three arms, one invocation. Compares inference paths for Gemma 4:
|
||||
A. ollama-json — /api/chat with JSON tools (current default)
|
||||
B. ollama-native — /api/generate raw:true with canonical jinja template
|
||||
C. jax-native — google-deepmind/gemma reference ToolSampler
|
||||
|
||||
Research question: does the inference path materially change behavior,
|
||||
or is Ollama's JSON tools path faithful to the reference? If arms A and
|
||||
B diverge, Ollama's parser is the variable. If B and C diverge, that's
|
||||
the llama.cpp runtime / GGUF quantization / Ollama's scheduler.
|
||||
|
||||
Arms A and B run against a local Ollama at http://127.0.0.1:11434
|
||||
by default. Arm C needs its own Python env with JAX + the `gemma`
|
||||
package (see `arms/jax_native.py` module docstring).
|
||||
|
||||
Usage:
|
||||
python3 harness.py --arm ollama-json --task movies --out runs/A/movies.json
|
||||
python3 harness.py --arm ollama-native --task movies --out runs/B/movies.json
|
||||
python3 harness.py --arm jax-native --task movies --out runs/C/movies.json
|
||||
|
||||
# Default model targets: E4B. Override with --model:
|
||||
python3 harness.py --arm ollama-json --task movies --model gemma4:26b --out ...
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Make `tasks` and sibling `arms/` importable regardless of where the
|
||||
# harness is invoked from.
|
||||
_HERE = Path(__file__).resolve().parent
|
||||
sys.path.insert(0, str(_HERE))
|
||||
|
||||
from tasks import TASKS # noqa: E402
|
||||
|
||||
|
||||
# Maps each accepted --arm value to the dotted module path that is imported
# at run time; that module's `run` coroutine is what the harness awaits.
ARMS = dict(
    (
        ("ollama-json", "arms.ollama_json"),
        ("ollama-native", "arms.ollama_native"),
        ("jax-native", "arms.jax_native"),
    )
)
|
||||
|
||||
|
||||
# Model identifier used for each arm when --model is not given. Both Ollama
# arms default to the same tag, so it is written once and shared.
DEFAULT_MODELS = {
    **dict.fromkeys(("ollama-json", "ollama-native"), "gemma4:e4b-it-q8_0"),
    "jax-native": "google-deepmind/gemma:GEMMA4_E4B_IT",
}
|
||||
|
||||
|
||||
async def _main() -> int:
    """Run one bakeoff arm against one task and persist its trace as JSON.

    Command-line flags select the arm (a key of ``ARMS``), the task (a key
    of ``TASKS``) and the output path. ``--model`` overrides the arm's entry
    in ``DEFAULT_MODELS``; ``--ollama-url`` defaults to the ``OLLAMA_URL``
    environment variable, or ``http://127.0.0.1:11434`` when that is unset.

    The arm module is imported by name and its ``run`` coroutine is awaited
    with keyword arguments only. The returned trace is used as a mapping
    (``setdefault`` and ``get`` are called on it).

    Returns:
        0 once the trace has been written and the summary line printed.
    """
    parser = argparse.ArgumentParser(
        description="Three-arm native Gemma 4 bakeoff harness."
    )
    parser.add_argument("--arm", required=True, choices=list(ARMS))
    parser.add_argument("--task", required=True, choices=list(TASKS))
    parser.add_argument("--out", required=True, type=Path)
    parser.add_argument(
        "--model", default=None, help="override default model for this arm"
    )
    parser.add_argument(
        "--ollama-url",
        default=os.environ.get("OLLAMA_URL", "http://127.0.0.1:11434"),
    )
    parser.add_argument("--num-ctx", type=int, default=8192)
    parser.add_argument("--num-predict", type=int, default=2048)
    parser.add_argument("--step-budget", type=int, default=20)
    args = parser.parse_args()

    arm_module = importlib.import_module(ARMS[args.arm])
    task_prompt = TASKS[args.task]

    # Every arm receives the same keyword set, whichever arm was selected.
    run_kwargs = {
        "ollama_url": args.ollama_url,
        "model": args.model if args.model else DEFAULT_MODELS[args.arm],
        "task_prompt": task_prompt,
        "num_ctx": args.num_ctx,
        "num_predict": args.num_predict,
        "step_budget": args.step_budget,
    }
    trace = await arm_module.run(**run_kwargs)

    # Only fill these in when the arm did not record them itself.
    trace.setdefault("task", args.task)
    trace.setdefault("task_prompt", task_prompt)

    # default=str: values the json module cannot encode are stringified
    # instead of aborting the write.
    serialized = json.dumps(trace, indent=2, default=str)
    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(serialized)

    # A missing or empty "final" section degrades to "?" placeholders.
    f = trace.get("final") or {}
    summary_line = (
        f"arm={args.arm:14s} task={args.task:8s} "
        f"steps={f.get('steps_used', '?')} tools={f.get('tool_calls_total', '?')} "
        f"halt={f.get('halt_reason', '?')} wall={f.get('wall_clock_s', '?')}s"
    )
    print(summary_line)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Drive the async entry point to completion and use its return value as
    # the process exit status.
    raise SystemExit(asyncio.run(_main()))
|
||||
Reference in New Issue
Block a user