feat: native-bakeoff scaffold — Ollama JSON vs native-token tool-calling
Three-arm harness under scripts/native-bakeoff/: - arm A: /api/chat with JSON tools (current default) - arm B: /api/generate raw:true with canonical HF jinja template rendered directly - arm C: google-deepmind/gemma JAX ToolSampler (env-gated, JAX required) Interim finding from A+B sweep on matt-strix gemma4:26b Q4: Ollama's bidirectional JSON↔native tool-call translator is faithful. The "long" multi-tool task produces identical behavior (7 steps / 6 tools) on both arms. Earlier arm-B parser bug that looked like a divergence was a harness issue: preserving the model's <|channel>thought\n<channel|> prefix as assistant content tripped the jinja template's tool_response-following conditional, appending a spurious <turn|>\n that corrupted the next step's prompt. Fixed by dropping the channel prefix on the assistant message. Arm C left as scaffolded-but-not-run — the JAX/bf16 reference path would answer "does the GGUF runtime diverge from DeepMind's implementation" but requires a separate env with the `gemma` PyPI package. Parked pending SDXL eviction or vast-h100 session. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,121 @@
|
||||
"""Arm A: Ollama /api/chat with JSON tools.
|
||||
|
||||
This is the baseline — what mort-bot, OpenWebUI, and every other Ollama
|
||||
client does. Ollama's server translates the OpenAI-style JSON tools
|
||||
array into Gemma's native <|tool>declaration:...<tool|> tokens and
|
||||
parses the model's <|tool_call>call:...<tool_call|> output back into
|
||||
structured tool_calls. This arm measures what we already live with.
|
||||
|
||||
Think setting: fixed to `false` per round-3 bakeoff finding (26B silently
|
||||
stops on think:true in multi-turn tool loops). For E4B the finding was
|
||||
less load-bearing but we hold think:false constant across arms so
|
||||
only the inference path varies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from tasks import SYSTEM_PROMPT, TOOLS, FAKE_HISTORY, execute_tool_stub
|
||||
|
||||
|
||||
async def run(
|
||||
*,
|
||||
ollama_url: str,
|
||||
model: str,
|
||||
task_prompt: str,
|
||||
num_ctx: int,
|
||||
num_predict: int,
|
||||
step_budget: int,
|
||||
) -> dict[str, Any]:
|
||||
messages = [{"role": "system", "content": SYSTEM_PROMPT}] + list(FAKE_HISTORY)
|
||||
messages.append({"role": "user", "content": f"[2026-04-18 14:20] @seth:sethpc.xyz: {task_prompt}"})
|
||||
|
||||
trace: dict[str, Any] = {
|
||||
"arm": "ollama-json",
|
||||
"model": model,
|
||||
"num_ctx": num_ctx,
|
||||
"num_predict": num_predict,
|
||||
"started_at": time.time(),
|
||||
"turns": [],
|
||||
"final": None,
|
||||
}
|
||||
|
||||
tool_call_total = 0
|
||||
halt: str | None = None
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for step in range(1, step_budget + 1):
|
||||
t0 = time.time()
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": TOOLS,
|
||||
"stream": False,
|
||||
"think": False,
|
||||
"options": {"num_ctx": num_ctx, "num_predict": num_predict,
|
||||
"temperature": 0.7, "top_p": 0.95, "top_k": 64},
|
||||
"keep_alive": "2h",
|
||||
}
|
||||
try:
|
||||
async with session.post(
|
||||
f"{ollama_url}/api/chat", json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=300),
|
||||
) as resp:
|
||||
r = await resp.json()
|
||||
except Exception as e:
|
||||
halt = f"error: {e}"
|
||||
trace["turns"].append({"step": step, "error": str(e)})
|
||||
break
|
||||
|
||||
msg = r.get("message", {}) or {}
|
||||
content = msg.get("content", "") or ""
|
||||
tool_calls = msg.get("tool_calls") or []
|
||||
history_chars = sum(len(m.get("content", "") or "") for m in messages)
|
||||
|
||||
trace["turns"].append({
|
||||
"step": step,
|
||||
"elapsed_s": round(time.time() - t0, 2),
|
||||
"prompt_eval_count": r.get("prompt_eval_count"),
|
||||
"eval_count": r.get("eval_count"),
|
||||
"content_len": len(content),
|
||||
"tool_call_count": len(tool_calls),
|
||||
"history_chars_before_append": history_chars,
|
||||
})
|
||||
messages.append(msg)
|
||||
|
||||
if not tool_calls:
|
||||
halt = "no_tool_calls"
|
||||
break
|
||||
|
||||
tool_call_total += len(tool_calls)
|
||||
for tc in tool_calls:
|
||||
fn = tc.get("function", {})
|
||||
name = fn.get("name")
|
||||
args = fn.get("arguments") or {}
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except Exception:
|
||||
args = {}
|
||||
result = execute_tool_stub(name, args)
|
||||
messages.append({"role": "tool", "content": result})
|
||||
|
||||
if step == step_budget:
|
||||
halt = "step_budget"
|
||||
break
|
||||
|
||||
trace["final"] = {
|
||||
"halt_reason": halt,
|
||||
"steps_used": len(trace["turns"]),
|
||||
"tool_calls_total": tool_call_total,
|
||||
"wall_clock_s": round(time.time() - trace["started_at"], 2),
|
||||
"final_message_count": len(messages),
|
||||
"final_history_chars": sum(len(m.get("content", "") or "") for m in messages),
|
||||
}
|
||||
return trace
|
||||
Reference in New Issue
Block a user