{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n", "\n", "Join Discord if you need help + ⭐ Star us on Github ⭐\n", "\n", "To install Unsloth on your local device, follow [our guide](https://unsloth.ai/docs/get-started/install). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).\n", "\n", "You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and how to save it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Installation" ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "%%capture\nimport os, re\nif \"COLAB_\" not in \"\".join(os.environ.keys()):\n !pip install unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r'[\\d]{1,}\\.[\\d]{1,}', str(torch.__version__)).group(0)\n xformers = 'xformers==' + {'2.10':'0.0.34','2.9':'0.0.33.post1','2.8':'0.0.32.post2'}.get(v, \"0.0.34\")\n !pip install sentencepiece protobuf \"datasets==4.3.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install --no-deps transformers==5.5.0\n!pip install torchcodec\nimport torch; torch._dynamo.config.recompile_limit = 64;" }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": "#@title Colab Extra Install { display-mode: \"form\" }\n%%capture\nimport os\n!pip install --upgrade -qqq uv\nif \"COLAB_\" not in \"\".join(os.environ.keys()):\n # If you're not in Colab, just use pip install!\n !pip install unsloth vllm\nelse:\n try: import numpy, PIL; _numpy = f'numpy=={numpy.__version__}'; _pil = f'pillow=={PIL.__version__}'\n except: _numpy = \"numpy\"; _pil = \"pillow\"\n try: import subprocess; is_t4 = \"Tesla T4\" in str(subprocess.check_output([\"nvidia-smi\"]))\n except: is_t4 = False\n _vllm, _triton = ('vllm==0.9.2', 'triton==3.2.0') if is_t4 else ('vllm==0.15.1', 'triton')\n !uv pip install -qqq --upgrade {_vllm} {_numpy} {_pil} torchvision bitsandbytes xformers unsloth\n !uv pip install -qqq {_triton}\n!uv pip install transformers==4.56.2\n!uv pip install --no-deps trl==0.22.2" }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Unsloth" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Goal: Make faster kernels with Reinforcement Learning\n", "\n", "Our goal is to make a faster matrix multiplication kernel by doing RL on Gemma 4 with Unsloth.\n", "\n", "You will learn how to:\n", "1. Counteract **reward hacking** such as cheating, caching, and laziness.\n", "2. Time kernels, check their correctness, and enforce time limits.\n", "3. Write good **reward functions**.\n", "4. Do RL seriously to produce optimized kernels." ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "from unsloth import FastVisionModel\n", "import torch\n", "max_seq_length = 4096 # Can increase for longer reasoning traces\n", "lora_rank = 32 # Larger rank = smarter, but slower\n", "\n", "gemma4_models = [\n", " # Gemma-4 instruct models:\n", " \"unsloth/gemma-4-E2B-it\",\n", " \"unsloth/gemma-4-E4B-it\",\n", " \"unsloth/gemma-4-31B-it\",\n", " \"unsloth/gemma-4-26B-A4B-it\",\n", " # Gemma-4 base models:\n", " \"unsloth/gemma-4-E2B\",\n", " \"unsloth/gemma-4-E4B\",\n", " \"unsloth/gemma-4-31B\",\n", " \"unsloth/gemma-4-26B-A4B\",\n", "] # More models at https://huggingface.co/unsloth\n", "\n", "model, tokenizer = FastVisionModel.from_pretrained(\n", " model_name = \"unsloth/gemma-4-E2B-it\",\n", " max_seq_length = max_seq_length,\n", " load_in_4bit = False, # False for LoRA 16bit\n", " fast_inference = False, # Set to True to enable vLLM fast inference\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now add a small number of LoRA weights to Gemma 4, so we only need to train those instead of the full model." 
] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "model = FastVisionModel.get_peft_model(\n", " model,\n", " r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n", " target_modules = [\n", " \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n", " \"gate_proj\", \"up_proj\", \"down_proj\",\n", " ],\n", " lora_alpha = lora_rank*2, # *2 speeds up training\n", " use_gradient_checkpointing = \"unsloth\", # Reduces memory usage\n", " random_state = 3407,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Optimized matrix multiplication\n", "\n", "NumPy performs matrix multiplication on CPUs with optimized BLAS routines. On GPUs, PyTorch calls CUDA-accelerated cuBLAS kernels under the hood.\n", "\n", "To generate random matrices for testing matrix multiplication, we can use the function below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "def generate_random_matrices(seed = 3407, n = 256):\n", " random_state = np.random.RandomState(seed)\n", " # The seed only fixes the matrix shapes; the values come from NumPy's global RNG\n", " n, k, m = random_state.randint(1, n+1, size = 3)\n", " A = np.random.uniform(-10, 10, size = (n, k))\n", " B = np.random.uniform(-10, 10, size = (k, m))\n", " return A, A.tolist(), B, B.tolist()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's generate small matrices and look at their product:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-2.8313286 4.54613909 -7.95265309 6.53459836 2.87235103]\n", " [ 7.0739631 3.76278879 9.31565599 -8.52884711 9.96832952]\n", " [ 8.41214082 6.51136046 -3.79347975 -2.46773693 -2.32292989]\n", " [ 3.91302932 4.98335304 -5.33855089 5.71057634 -2.79871647]]\n", "[[ 0.39218774 -9.6181377 -3.49736707]\n", " [-0.33354865 -1.05626139 3.87231208]\n", " [ 0.49494174 5.91863954 -6.83183693]\n", " [ 5.1465162 -7.51648113 1.00445384]\n", " [ 9.63213377 -4.92327556 3.323014 ]]\n", "[[ 54.73441488 -87.89725072 97.94605887]\n", " [ 58.25238906 -1.8467447 -49.25453031]\n", " [ -35.82528794 -80.25394462 11.51225408]\n", " [ -0.33785799 -103.64132345 38.51974367]]\n" ] } ], "source": [ "A, A_list, B, B_list = generate_random_matrices(seed = 42, n = 5)\n", "print(A)\n", "print(B)\n", "print(np.matmul(A, B))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can ask an LLM to generate a simple matrix multiplication kernel in pure Python, then measure how far the kernel's result is from the true result:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def calculate_difference(pred, real):\n", " if pred is None: return 5, 5\n", " assert real is not None\n", " import numpy as np\n", " try:\n", " difference = pred - real\n", " except:\n", " return 5, 5\n", " # Largest absolute error (np.abs so that negative errors also count)\n", " amax_error = float(np.amax(np.abs(difference)))\n", " mse_error = float(np.mean(np.square(difference)))\n", " return amax_error, mse_error" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Kernel generated by GPT-5\n", "def matmul(A, B):\n", " z, s = zip, sum\n", " Bt = list(z(*B))\n", " return [[s(a*b for a, b in z(row, col)) for col in Bt] for row in A]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The error below is very small, so the kernel is correct!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(7.105427357601002e-15, 4.6783406255758477e-29)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prediction = matmul(A_list, B_list)\n", "calculate_difference(prediction, np.matmul(A, B))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Countering Reward Hacking\n", "\n", "The ultimate goal of RL is to maximize some reward (say speed, revenue, or some other metric).\n", "\n", "But RL can **cheat**. When the RL algorithm learns a trick or exploits a loophole to increase the reward without actually doing the intended task, this is called \"Reward Hacking\".\n", "\n", "Some good examples are listed at https://en.wikipedia.org/wiki/Reward_hacking\n", "\n", "For matrix multiplication kernels, we might see the following issues:\n", "\n", "* Laziness: RL learns to use NumPy, Torch, or other libraries, which call optimized kernels.\n", "* Caching: RL learns to cache the output instead of recomputing it.\n", "* Cheating: RL learns to find the actual output by inspecting Python global variables.\n", "* Timer tampering: RL learns to edit the timing function so it reports 0 elapsed time.\n", "\n", "And possibly more. We shall try to address each!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Countering Reward Hacking 1: Stop laziness\n", "We can stop the RL algorithm from calling optimized code by checking whether the generated code imports any non-standard Python libraries. 
We used GPT-5 to help generate this check `check_only_stdlib_imports`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title (Collapsible code)\n", "import ast\n", "import sys\n", "import sysconfig\n", "from pathlib import Path\n", "\n", "def _stdlib_names():\n", " \"\"\"\n", " Build a set of canonical stdlib top-level module/package names.\n", " Uses sys.stdlib_module_names when available (3.10+), with a\n", " filesystem fallback for older versions/edge cases.\n", " \"\"\"\n", " names = {m.lower() for m in getattr(sys, \"stdlib_module_names\", set())}\n", " names |= {m.lower() for m in sys.builtin_module_names}\n", " names.add(\"__future__\") # special-case\n", "\n", " # Fallback/augmentation: scan the stdlib directory\n", " try:\n", " stdlib_dir = Path(sysconfig.get_path(\"stdlib\"))\n", " if stdlib_dir.exists():\n", " for p in stdlib_dir.iterdir():\n", " if p.name == \"site-packages\":\n", " continue\n", " if p.suffix == \".py\":\n", " names.add(p.stem.lower())\n", " elif p.is_dir() and (p / \"__init__.py\").exists():\n", " names.add(p.name.lower())\n", " except Exception:\n", " # conservative fallback; the names set above will still work well\n", " pass\n", "\n", " return names\n", "\n", "_STDLIB_SET = _stdlib_names()\n", "\n", "def check_only_stdlib_imports(code: str):\n", " \"\"\"\n", " Return (ok: bool, details: dict)\n", "\n", " ok == True -> all absolute imports are from the stdlib.\n", " ok == False -> details['non_stdlib'] lists offending top-level modules.\n", "\n", " details includes:\n", " - stdlib: sorted list of stdlib imports found\n", " - non_stdlib: sorted list of non-stdlib imports found\n", " - relative_imports: count of relative imports (always allowed here)\n", " \"\"\"\n", " try:\n", " tree = ast.parse(code)\n", " except SyntaxError as e:\n", " return False, {\n", " \"error\": f\"SyntaxError: {e}\",\n", " \"stdlib\": [],\n", " \"non_stdlib\": [],\n", " \"relative_imports\": 0,\n", " }\n", 
"\n", " abs_imports = set()\n", " relative_count = 0\n", "\n", " class Visitor(ast.NodeVisitor):\n", " def visit_Import(self, node: ast.Import):\n", " for alias in node.names:\n", " abs_imports.add(alias.name.split(\".\")[0])\n", " def visit_ImportFrom(self, node: ast.ImportFrom):\n", " nonlocal relative_count\n", " if (node.level or 0) > 0:\n", " # relative import\n", " relative_count += 1\n", " else:\n", " if node.module:\n", " abs_imports.add(node.module.split(\".\")[0])\n", "\n", " Visitor().visit(tree)\n", "\n", " stdlib_found = sorted(m for m in abs_imports if m.lower() in _STDLIB_SET)\n", " non_stdlib = sorted(m for m in abs_imports if m.lower() not in _STDLIB_SET)\n", "\n", " return len(non_stdlib) == 0, {\n", " \"stdlib\": stdlib_found,\n", " \"non_stdlib\": non_stdlib,\n", " \"relative_imports\": relative_count,\n", " }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, let's call `check_only_stdlib_imports` on a random piece of matrix multiplication code generated by GPT-5:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Only stdlib imports? 
False\n", "{'stdlib': [], 'non_stdlib': ['numpy', 'torch'], 'relative_imports': 0}\n" ] } ], "source": [ "sample = \"\"\"\n", "def matmul(A, B):\n", " import numpy as np\n", " from torch import matmul\n", " z, s = zip, sum\n", " Bt = list(z(*B))\n", " return [[s(a*b for a, b in z(row, col)) for col in Bt] for row in A]\n", "\"\"\"\n", "ok, info = check_only_stdlib_imports(sample)\n", "print(\"Only stdlib imports?\", ok)\n", "print(info)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Countering Reward Hacking 2: Stop cheating\n", "We can stop the RL algorithm from using global or cached variables by restricting its `locals` and `globals`.\n", "\n", "We use `exec` to create the function, passing an empty dict as `locals` so we can retrieve the new function from it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output_function = {}\n", "exec(sample, {}, output_function)\n", "output_function[\"matmul\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also disallow global variable access by rebuilding the function with an empty globals dict via `types.FunctionType(f.__code__, {})`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Success\n", "name 'np' is not defined\n" ] } ], "source": [ "import types\n", "output_function[\"matmul\"] = types.FunctionType(output_function[\"matmul\"].__code__, {})\n", "\n", "def import_numpy():\n", " np.matmul\n", " print(\"Success\")\n", "\n", "import_numpy()\n", "import_numpy = types.FunctionType(import_numpy.__code__, {})\n", "try:\n", " import_numpy()\n", "except Exception as e:\n", " print(str(e))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def create_locked_down_function(function):\n", " output_function = {}\n", 
" exec(function, {}, output_function) # Define the function without access to our globals\n", " new_matmul = output_function[\"matmul\"]\n", " # Rebuild the function with empty globals so it cannot read notebook variables\n", " new_matmul = types.FunctionType(new_matmul.__code__, {})\n", " return new_matmul" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Countering Reward Hacking 3: Stop caching\n", "We can stop the RL algorithm from benefiting from cached data by flushing the CPU cache with a large scratch buffer before each timing trial. We also benchmark carefully over multiple trials and loops.\n", "\n", "We also add a **timer** so a generated kernel cannot run in an endless loop." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os, gc, time, statistics\n", "import signal\n", "from contextlib import contextmanager\n", "class TimeoutError(Exception): pass\n", "\n", "@contextmanager\n", "def time_limit(seconds):\n", " # Uses SIGALRM, so this only works on Unix and in the main thread\n", " def _handler(signum, frame):\n", " raise TimeoutError(f\"Timed out after {seconds}s\")\n", " old = signal.signal(signal.SIGALRM, _handler)\n", " signal.setitimer(signal.ITIMER_REAL, seconds)\n", " try:\n", " yield\n", " finally:\n", " signal.setitimer(signal.ITIMER_REAL, 0.0)\n", " signal.signal(signal.SIGALRM, old)\n", "\n", "class Benchmarker:\n", " def __init__(self, trials = 3, loops = 1, timeout = 30):\n", " # 2 GiB scratch buffer used to flush CPU caches between trials\n", " self.buffer = np.zeros(2 * 1024 * 1024 * 1024, dtype = np.uint8)\n", " self.trials = trials\n", " self.loops = loops\n", " assert timeout > 0 # Cannot be 0 since it won't work!\n", " self.timeout = timeout\n", " def thrash(self):\n", " # Write to the whole buffer to evict cached data\n", " self.buffer ^= 1\n", " return int(self.buffer[::4096].sum())\n", "\n", " def benchmark(self, function, arguments):\n", " assert len(arguments) == self.loops\n", " samples = []\n", " exceptions = []\n", " timed_out = 0\n", " for _ in range(self.trials):\n", " gc.collect(); gc.disable(); self.thrash()\n", " t_start = time.perf_counter_ns()\n", " for i in range(self.loops):\n", " try:\n", " with time_limit(self.timeout):\n", " function(*arguments[i])\n", " except TimeoutError as e:\n", " timed_out += 1\n", " except Exception as e:\n", " exceptions.append(str(e))\n", " t_end = time.perf_counter_ns()\n", " gc.enable()\n", " samples.append((t_end - t_start) // max(1, self.loops))\n", " return {\n", " \"median_ns\": int(statistics.median(samples)),\n", " \"mean_ns\": int(statistics.fmean(samples)),\n", " \"stdev_ns\": int(statistics.pstdev(samples) if len(samples) > 1 else 0),\n", " \"exceptions\" : exceptions,\n", " \"timeouts\" : timed_out,\n", " }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, we benchmark our earlier matmul kernel with a 10 second timeout:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'median_ns': 70895404,\n", " 'mean_ns': 70895404,\n", " 'stdev_ns': 0,\n", " 'exceptions': [],\n", " 'timeouts': 0}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "A, A_list, B, B_list = generate_random_matrices(seed = 0, n = 256)\n", "Benchmarker(trials = 1, timeout = 10).benchmark(output_function[\"matmul\"], [(A_list, B_list)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data & RL task setup\n", "\n", "We now need a prompt describing the task the model should solve. 
For our matrix multiplication example, we use the prompt below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Create a new fast matrix multiplication function using only native Python code.\n", "You are given a list of list of numbers.\n", "Output your new function in backticks using the format below:\n", "```python\n", "def matmul(A, B):\n", " return ...\n", "```\n" ] } ], "source": [ "prompt = \"\"\"\n", "Create a new fast matrix multiplication function using only native Python code.\n", "You are given a list of list of numbers.\n", "Output your new function in backticks using the format below:\n", "```python\n", "def matmul(A, B):\n", " return ...\n", "```\n", "\"\"\".strip()\n", "print(prompt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's prompt Gemma 4 without RL and see how it goes:" ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "text = tokenizer.apply_chat_template(\n", " [{\"role\": \"user\", \"content\": prompt.strip()}],\n", " tokenize = False,\n", " add_generation_prompt = True,\n", ")\n", "\n", "from transformers import TextStreamer\n", "print(\"=\" * 50)\n", "print(\"BASE MODEL OUTPUT (before RL training):\")\n", "print(\"=\" * 50)\n", "\n", "inputs = tokenizer(\n", " text = text,\n", " add_special_tokens = False,\n", " return_tensors = \"pt\",\n", ").to(\"cuda\")\n", "\n", "text_streamer = TextStreamer(tokenizer, skip_prompt = True)\n", "result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 512,\n", " use_cache = True, temperature = 1.0, top_p = 0.95, top_k = 64)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Reward functions\n", "\n", "We now write `extract_function`, which extracts the function wrapped in triple backticks.\n", "\n", "We then define 4 reward functions:\n", "\n", "1. 
`function_works`, which rewards the model if the extracted code is a valid Python function.\n", "2. `no_cheating`, which checks whether the function imports non-stdlib modules and penalizes it if so.\n", "3. `correctness_check`, which checks whether the kernel's output is correct, so it can't just generate gibberish!\n", "4. `speed_check`, which measures performance relative to calling NumPy's matmul directly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def matmul(A, B):\n", " return ...\n" ] } ], "source": [ "def extract_function(text):\n", " if text.count(\"```\") >= 2:\n", " first = text.find(\"```\") + 3\n", " second = text.find(\"```\", first)\n", " fx = text[first : second].strip()\n", " fx = fx.removeprefix(\"python\\n\")\n", " fx = fx[fx.find(\"def\"):]\n", " if fx.startswith(\"def matmul(A, B):\"): return fx\n", " return None\n", "print(extract_function(prompt))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is our `function_works` reward function. It runs the code with Python's `exec`, guarded so that no local or global variables leak in. We first call `check_only_stdlib_imports` to catch syntax errors before even executing the function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(False,\n", " {'error': \"SyntaxError: expected '(' (<unknown>, line 1)\",\n", " 'stdlib': [],\n", " 'non_stdlib': [],\n", " 'relative_imports': 0})" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ok, info = check_only_stdlib_imports(\"def a\")\n", "ok, info" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def function_works(completions, **kwargs):\n", " scores = []\n", " for completion in completions:\n", " score = 0\n", " response = completion[0][\"content\"]\n", " function = extract_function(response)\n", " print(function)\n", " if function is not None:\n", " ok, info = check_only_stdlib_imports(function)\n", " if function is None or \"error\" in info:\n", " score = -2.0\n", " else:\n", " try:\n", " new_matmul = create_locked_down_function(function)\n", " score = 1.0\n", " except:\n", " score = -0.5\n", " scores.append(score)\n", " return scores" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`no_cheating` checks whether the function cheated by importing NumPy, Torch, or other non-stdlib modules that call optimized code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def no_cheating(completions, **kwargs):\n", " scores = []\n", " for completion in completions:\n", " score = 0\n", " response = completion[0][\"content\"]\n", " function = extract_function(response)\n", " if function is not None:\n", " ok, info = check_only_stdlib_imports(function)\n", " else:\n", " ok = False\n", " scores.append(1.0 if ok else -20.0) # Penalize heavily!\n", " return scores" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, `correctness_check` checks whether the kernel is correct. We penalize the kernel when its maximum absolute error or its mean squared error reaches 0.5 or more (more heavily from 1 upward), and reward errors close to machine epsilon.\n", "\n", "This time we actually have to execute the code!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "np.float64(2.220446049250313e-16)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.finfo(np.float64).eps" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def correctness_check(completions, **kwargs):\n", " scores = []\n", " # Generate random matrices with each dimension at most 128\n", " A, A_list, B, B_list = generate_random_matrices(seed = np.random.randint(10000), n = 128)\n", " for completion in completions:\n", " score = 0\n", " response = completion[0][\"content\"]\n", " function = extract_function(response)\n", " if function is not None:\n", " ok, info = check_only_stdlib_imports(function)\n", " if function is None or \"error\" in info:\n", " scores.append(0)\n", " continue\n", " try:\n", " new_matmul = create_locked_down_function(function)\n", " except:\n", " scores.append(0)\n", " continue\n", " try:\n", " pred = new_matmul(A_list.copy(), B_list.copy())\n", " except:\n", " # Failed!\n", " scores.append(-2.0)\n", " continue\n", " true = np.matmul(A, B)\n", " amax_error, mse_error = calculate_difference(pred, true)\n", "\n", " # Check correctness and score! (tolerance is 100x float64 machine epsilon)\n", " machine_epsilon = 100*np.finfo(np.float64).eps\n", " if amax_error >= 3: score = -3.0\n", " elif amax_error >= 2: score = -2.5\n", " elif amax_error >= 1: score = -2.0\n", " elif amax_error >= 0.5: score = -1.0\n", " elif amax_error >= 100*machine_epsilon: score = 0.0\n", " elif amax_error >= machine_epsilon: score = 1.0\n", " else: score = 3.0\n", "\n", " if mse_error >= 3: score += -3.0\n", " elif mse_error >= 2: score += -2.5\n", " elif mse_error >= 1: score += -2.0\n", " elif mse_error >= 0.5: score += -1.0\n", " elif mse_error >= 100*machine_epsilon: score += 0.0\n", " elif mse_error >= machine_epsilon: score += 1.0\n", " else: score += 3.0\n", " scores.append(score)\n", " return scores" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "Finally, we set up benchmarking for `speed_check`. We limit each run to 10 seconds and use 3 trials.\n", "\n", "As a sanity check, we first benchmark NumPy and then the placeholder `matmul` from our prompt. The placeholder only returns `...`, so it looks faster than NumPy; `correctness_check` is what penalizes such kernels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'median_ns': 205566,\n", " 'mean_ns': 231173,\n", " 'stdev_ns': 39247,\n", " 'exceptions': [],\n", " 'timeouts': 0}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "A, A_list, B, B_list = generate_random_matrices(seed = 0, n = 256)\n", "benchmarker = Benchmarker(trials = 3, timeout = 10)\n", "numpy_results = benchmarker.benchmark(np.matmul, [(A, B)])\n", "numpy_results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'median_ns': 84237,\n", " 'mean_ns': 87442,\n", " 'stdev_ns': 4538,\n", " 'exceptions': [],\n", " 'timeouts': 0}" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_matmul = create_locked_down_function(extract_function(prompt))\n", "new_results = benchmarker.benchmark(new_matmul, [(A_list, B_list)])\n", "new_results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We take the ratio of the two median run times and negate it for slower kernels. 
If the new kernel is faster (its time ratio to NumPy is below 1), we instead use the inverted ratio, so larger speedups earn larger rewards. Both cases are scaled down by 100." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.02440329071548132" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "negative = -(new_results[\"median_ns\"] / numpy_results[\"median_ns\"]) / 100\n", "positive = +(numpy_results[\"median_ns\"] / new_results[\"median_ns\"]) / 100\n", "reward = negative if new_results[\"median_ns\"] >= numpy_results[\"median_ns\"] else positive\n", "reward" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3.333333333333333" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Made-up timings: the new kernel takes 3 ns versus 1000 ns for NumPy\n", "new_results[\"median_ns\"] = 3\n", "numpy_results[\"median_ns\"] = 1000\n", "negative = -(new_results[\"median_ns\"] / numpy_results[\"median_ns\"]) / 100\n", "positive = +(numpy_results[\"median_ns\"] / new_results[\"median_ns\"]) / 100\n", "reward = negative if new_results[\"median_ns\"] >= numpy_results[\"median_ns\"] else positive\n", "reward" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gc\n", "def speed_check(completions, **kwargs):\n", " scores = []\n", " # Generate random matrices with each dimension at most 256\n", " A, A_list, B, B_list = generate_random_matrices(seed = np.random.randint(10000), n = 256)\n", " numpy_results = benchmarker.benchmark(np.matmul, [(A, B)])\n", " for completion in completions:\n", " score = 0\n", " response = completion[0][\"content\"]\n", " function = extract_function(response)\n", " if function is not None:\n", " ok, info = check_only_stdlib_imports(function)\n", " if function is None or \"error\" in info:\n", " scores.append(0)\n", " continue\n", " try:\n", " new_matmul = create_locked_down_function(function)\n", " except:\n", " scores.append(0)\n", " continue\n", " new_results = 
benchmarker.benchmark(new_matmul, [(A_list.copy(), B_list.copy())])\n", "\n", " # Compute the score and clip it to [-10, 10]\n", " negative = -(new_results[\"median_ns\"] / numpy_results[\"median_ns\"]) / 100\n", " positive = +(numpy_results[\"median_ns\"] / new_results[\"median_ns\"]) / 100\n", " score = negative if new_results[\"median_ns\"] >= numpy_results[\"median_ns\"] else positive\n", " if score >= 10: score = 10\n", " if score <= -10: score = -10\n", " scores.append(score)\n", " # Free memory to counteract OOMs\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", " return scores" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create a dataset containing 1,000 copies of our prompt." ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "from datasets import Dataset\n", "dataset = Dataset.from_list([{\"prompt\" : [{\"role\": \"user\", \"content\": prompt.strip()}], \"answer\" : 0}]*1000)\n", "maximum_length = len(tokenizer.apply_chat_template([{\"role\":\"user\", \"content\":prompt.strip()}], add_generation_prompt = True, tokenize = True))\n", "print(maximum_length)\n", "dataset[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### Train the model\n", "\n", "Now set up the GRPO Trainer and all its configurations! We also support GSPO, GAPO, Dr GRPO and more! Go to our docs https://unsloth.ai/docs/ for more info!" 
] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Leave room for the prompt (plus 1 token safety margin)\n", "max_completion_length = max_seq_length - (maximum_length + 1)\n", "\n", "from trl import GRPOConfig, GRPOTrainer\n", "training_args = GRPOConfig(\n", " temperature = 1.0,\n", " top_p = 0.95,\n", " top_k = 64,\n", " learning_rate = 5e-5,\n", " weight_decay = 0.001,\n", " warmup_ratio = 0.1,\n", " lr_scheduler_type = \"linear\",\n", " optim = \"adamw_8bit\",\n", " logging_steps = 1,\n", " per_device_train_batch_size = 1,\n", " gradient_accumulation_steps = 2, # Increase to 4 for smoother training\n", " num_generations = 2, # Decrease if out of memory\n", " max_completion_length = max_completion_length,\n", " # num_train_epochs = 1, # Set to 1 for a full training run\n", " max_steps = 100,\n", " save_steps = 100,\n", " report_to = \"none\", # Can use Weights & Biases, TrackIO\n", " output_dir = \"outputs\",\n", " epsilon = 0.2,\n", " epsilon_high = 0.28, # one sided\n", " delta = 1.5, # two sided\n", " loss_type = 'bnpo',\n", " mask_truncated_completions = True,\n", " # For optional training + evaluation\n", " # fp16_full_eval = True,\n", " # per_device_eval_batch_size = 4,\n", " # eval_accumulation_steps = 1,\n", " # eval_strategy = \"steps\",\n", " # eval_steps = 1,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's run the trainer! During training, you'll see a table of rewards. The goal is to see the `reward` column increase!\n", "\n", "You might have to wait 150 to 200 steps before seeing any improvement. You'll probably get 0 reward for the first 100 steps. 
Please be patient!\n", "\n", "| Step | Training Loss | reward | reward_std | completion_length | kl |\n", "|------|---------------|-----------|------------|-------------------|----------|\n", "| 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |\n", "| 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |\n", "| 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# For optional training + evaluation\n", "# new_dataset = dataset.train_test_split(test_size = 0.01)\n", "\n", "trainer = GRPOTrainer(\n", " model = model,\n", " processing_class = tokenizer,\n", " reward_funcs = [\n", " function_works,\n", " no_cheating,\n", " correctness_check,\n", " speed_check,\n", " ],\n", " args = training_args,\n", " train_dataset = dataset,\n", "\n", " # For optional training + evaluation\n", " # train_dataset = new_dataset[\"train\"],\n", " # eval_dataset = new_dataset[\"test\"],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And let's train the model!\n", "\n", "**NOTE:** On a free T4 GPU, one generation might take 5 minutes since it's an older GPU. An A100 or H100 will be much faster!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. 
Updated tokens: {'bos_token_id': 199998}.\n", "==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1\n", " \\\\ /| Num examples = 1,000 | Num Epochs = 1 | Total steps = 100\n", "O^O/ \\_/ \\ Batch size per device = 2 | Gradient accumulation steps = 1\n", "\\ / Data Parallel GPUs = 1 | Total batch size (2 x 1 x 1) = 2\n", " \"-____-\" Trainable parameters = 1,990,656 of 20,916,747,840 (0.01% trained)\n", "`generation_config` default values have been modified to match model-specific defaults: {'max_length': 131072}. If this is not desired, please set these values explicitly.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "def matmul(A, B):\n", " \"\"\"\n", " Fast matrix multiplication using only native Python code.\n", " \n", " Parameters\n", " ----------\n", " A : list of list of numbers\n", " Left matrix of dimensions (m x p).\n", " B : list of list of numbers\n", " Right matrix of dimensions (p x n).\n", " \n", " Returns\n", " -------\n", " C : list of list of numbers\n", " Resulting matrix of dimensions (m x n) such that C = A × B.\n", " \"\"\"\n", " # Transpose B to allow column access as rows.\n", " Bt = list(zip(*B))\n", " # Compute the dot product of each row from A with each column from B\n", " return [[sum(a * b for a, b in zip(row, col))\n", " for col in Bt]\n", " for row in A]\n", "def matmul(A, B):\n", " return ...\n" ] }, { "data": { "text/html": [ "\n", "
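The logged columns in the table below can be cross-checked by hand. In this run the `reward` column equals the sum of the four per-function means (at step 1: 1 + 1 + (-1) + (-1.960532) = -0.960532). The sketch below reproduces that aggregation and the group-relative advantage GRPO derives from it, assuming equal reward weights (TRL weights every reward function by 1.0 unless `reward_weights` is set) and an advantage of (reward - group mean) / (group sample std + eps), with `eps = 1e-4` as an assumed small constant. The per-completion scores are hypothetical: the log only records aggregates, so they are one assignment consistent with the means and standard deviations logged at step 2 (two completions per prompt, since `num_generations = 2`). The helpers `total_rewards` and `group_advantages` are illustrative names, not part of TRL.

```python
import statistics

# Order matches reward_funcs in the trainer cell.
REWARD_FUNCS = ["function_works", "no_cheating", "correctness_check", "speed_check"]

def total_rewards(per_function):
    """Sum each completion's scores across reward functions (assumes every weight is 1.0)."""
    n = len(next(iter(per_function.values())))
    return [sum(per_function[name][i] for name in REWARD_FUNCS) for i in range(n)]

def group_advantages(totals, eps=1e-4):
    """GRPO-style advantage: (reward - group mean) / (group sample std + eps)."""
    mean = statistics.mean(totals)
    std = statistics.stdev(totals)  # sample (n - 1) std, consistent with the logged stds
    return [(r - mean) / (std + eps) for r in totals]

# Hypothetical per-completion scores for one prompt with num_generations = 2.
# One assignment that reproduces the means/stds logged at step 2.
step2 = {
    "function_works":    [1.0, -2.0],
    "no_cheating":       [1.0, -20.0],
    "correctness_check": [4.0, 0.0],
    "speed_check":       [-7.009202, 0.0],
}

totals = total_rewards(step2)
reward = statistics.mean(totals)       # compare with the "reward" column
reward_std = statistics.stdev(totals)  # compare with the "reward_std" column
advantages = group_advantages(totals)

# A tie (e.g. both completions scoring -22) has zero spread, hence zero advantage.
tie_advantages = group_advantages([-22.0, -22.0])

print(f"reward = {reward:.6f}, reward_std = {reward_std:.6f}")
print("advantages:", [round(a, 4) for a in advantages])
print("tie advantages:", tie_advantages)
```

With `num_generations = 2`, the normalized advantages come out at roughly +0.707 and -0.707 whenever the two completions score differently, and at 0 when they tie. Rows where `reward_std` is 0 (for example, both completions scoring -22) therefore carry no policy-gradient signal, which is likely one reason progress in the table is slow and noisy.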
 \n", " [100/100 1:36:19, Epoch 0/1]\n", "
| Step | Training Loss | reward | reward_std | completions / mean_length | completions / min_length | completions / max_length | completions / clipped_ratio | completions / mean_terminated_length | completions / min_terminated_length | completions / max_terminated_length | kl | rewards / function_works / mean | rewards / function_works / std | rewards / no_cheating / mean | rewards / no_cheating / std | rewards / correctness_check / mean | rewards / correctness_check / std | rewards / speed_check / mean | rewards / speed_check / std |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.000000 | -0.960532 | 4.244743 | 536.000000 | 392.000000 | 680.000000 | 0.000000 | 536.000000 | 392.000000 | 680.000000 | 0.002798 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -1.960532 | 2.826324 |
| 2 | 0.000000 | -11.504601 | 14.842735 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000834 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -3.504601 | 4.956255 |
| 3 | 0.000000 | -1.232768 | 3.847022 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000691 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -2.232768 | 3.224046 |
| 4 | 0.000000 | -9.112391 | 18.225832 | 541.000000 | 364.000000 | 718.000000 | 0.500000 | 364.000000 | 364.000000 | 364.000000 | 0.004645 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -1.112391 | 1.573158 |
| 5 | 0.000000 | 1.982523 | 0.584465 | 503.000000 | 352.000000 | 654.000000 | 0.000000 | 503.000000 | 352.000000 | 654.000000 | 0.004241 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -4.017477 | 0.584465 |
| 6 | 0.000000 | -8.959490 | 18.442066 | 629.500000 | 541.000000 | 718.000000 | 0.500000 | 541.000000 | 541.000000 | 541.000000 | 0.002716 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.959490 | 1.356924 |
| 7 | 0.000000 | 5.517008 | 0.094176 | 440.500000 | 394.000000 | 487.000000 | 0.000000 | 440.500000 | 394.000000 | 487.000000 | 0.001769 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -0.482992 | 0.094176 |
| 8 | 0.000000 | -9.263465 | 18.012180 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000987 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -1.263465 | 1.786810 |
| 9 | 0.000000 | -13.000000 | 12.727922 | 586.000000 | 454.000000 | 718.000000 | 0.500000 | 454.000000 | 454.000000 | 454.000000 | 0.002943 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -5.000000 | 7.071068 |
| 10 | 0.000000 | -3.985678 | 0.000226 | 635.500000 | 553.000000 | 718.000000 | 0.500000 | 553.000000 | 553.000000 | 553.000000 | 0.001814 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -6.000000 | 0.000000 | 0.014322 | 0.000225 |
| 11 | 0.000000 | -8.366700 | 19.280397 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001235 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.366700 | 0.518593 |
| 12 | 0.000000 | -9.327222 | 17.922014 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000735 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -1.327222 | 1.876975 |
| 13 | 0.000000 | -12.989250 | 12.743125 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001106 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.010750 | 0.015203 |
| 14 | 0.000000 | 1.785522 | 2.598972 | 640.500000 | 563.000000 | 718.000000 | 0.500000 | 563.000000 | 563.000000 | 563.000000 | 0.003206 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 4.242640 | -1.214478 | 1.643669 |
| 15 | 0.000000 | -9.018981 | 18.357933 | 603.000000 | 488.000000 | 718.000000 | 0.500000 | 488.000000 | 488.000000 | 488.000000 | 0.006529 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -1.018981 | 1.441056 |
| 16 | 0.000000 | -3.985248 | 0.000232 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000625 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -6.000000 | 0.000000 | 0.014752 | 0.000231 |
| 17 | 0.000000 | -8.496688 | 19.096567 | 625.500000 | 533.000000 | 718.000000 | 0.500000 | 533.000000 | 533.000000 | 533.000000 | 0.003519 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.496688 | 0.702423 |
| 18 | 0.000000 | -12.985329 | 12.748671 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001027 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.014671 | 0.020748 |
| 19 | 0.000000 | -1.173172 | 3.936521 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001025 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -2.173172 | 3.134547 |
| 20 | 0.000000 | -2.000000 | 0.000000 | 391.000000 | 297.000000 | 485.000000 | 0.000000 | 391.000000 | 297.000000 | 485.000000 | 0.005021 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 6.000000 | 0.000000 | -10.000000 | 0.000000 |
| 21 | 0.000000 | -0.387958 | 5.063533 | 593.500000 | 469.000000 | 718.000000 | 0.500000 | 469.000000 | 469.000000 | 469.000000 | 0.004264 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -1.387958 | 2.007535 |
| 22 | 0.000000 | -12.992634 | 12.738339 | 524.500000 | 331.000000 | 718.000000 | 0.500000 | 331.000000 | 331.000000 | 331.000000 | 0.005515 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.007366 | 0.010417 |
| 23 | 0.000000 | -22.000000 | 0.000000 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000955 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 24 | 0.000000 | -12.989729 | 12.742447 | 635.000000 | 552.000000 | 718.000000 | 0.500000 | 552.000000 | 552.000000 | 552.000000 | 0.002888 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.010271 | 0.014526 |
| 25 | 0.000000 | -22.000000 | 0.000000 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001271 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 26 | 0.000000 | -22.000000 | 0.000000 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001055 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 27 | 0.000000 | -9.105782 | 18.235178 | 534.000000 | 350.000000 | 718.000000 | 0.500000 | 350.000000 | 350.000000 | 350.000000 | 0.021608 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -1.105782 | 1.563811 |
| 28 | 0.000000 | 2.792898 | 3.932606 | 645.500000 | 573.000000 | 718.000000 | 0.500000 | 573.000000 | 573.000000 | 573.000000 | 0.006835 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 4.242640 | -0.207102 | 0.310035 |
| 29 | 0.000000 | -3.970244 | 0.000759 | 616.500000 | 515.000000 | 718.000000 | 0.500000 | 515.000000 | 515.000000 | 515.000000 | 0.011646 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -6.000000 | 0.000000 | 0.029756 | 0.000759 |
| 30 | 0.000000 | -12.977889 | 12.759192 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001129 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.022111 | 0.031270 |
| 31 | 0.000000 | -12.983095 | 12.751829 | 586.500000 | 455.000000 | 718.000000 | 0.500000 | 455.000000 | 455.000000 | 455.000000 | 0.032435 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.016905 | 0.023908 |
| 32 | 0.000000 | -8.083347 | 19.681118 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001218 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.083347 | 0.117870 |
| 33 | 0.000000 | -22.000000 | 0.000000 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001185 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 34 | 0.000000 | -4.000000 | 0.000000 | 577.500000 | 477.000000 | 678.000000 | 0.000000 | 577.500000 | 477.000000 | 678.000000 | 0.021551 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -10.000000 | 0.000000 |
| 35 | -0.000000 | 3.214173 | 0.016615 | 609.500000 | 577.000000 | 642.000000 | 0.000000 | 609.500000 | 577.000000 | 642.000000 | 0.004937 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -2.785827 | 0.016615 |
| 36 | 0.000000 | -22.000000 | 0.000000 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001002 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 37 | 0.000000 | -8.405643 | 19.225323 | 691.000000 | 664.000000 | 718.000000 | 0.500000 | 664.000000 | 664.000000 | 664.000000 | 0.001766 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.405643 | 0.573666 |
| 38 | 0.000000 | 2.510188 | 0.017700 | 601.000000 | 541.000000 | 661.000000 | 0.000000 | 601.000000 | 541.000000 | 661.000000 | 0.006895 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -3.489812 | 0.017700 |
| 39 | 0.000000 | 1.143930 | 1.851457 | 676.000000 | 634.000000 | 718.000000 | 0.500000 | 634.000000 | 634.000000 | 634.000000 | 0.003330 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 4.242640 | -1.856070 | 2.391184 |
| 40 | 0.000000 | 0.305945 | 0.040185 | 385.500000 | 260.000000 | 511.000000 | 0.000000 | 385.500000 | 260.000000 | 511.000000 | 0.021996 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -5.694055 | 0.040185 |
| 41 | 0.000000 | -2.385927 | 0.019569 | 435.000000 | 378.000000 | 492.000000 | 0.000000 | 435.000000 | 378.000000 | 492.000000 | 0.004062 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -8.385927 | 0.019569 |
| 42 | 0.000000 | -3.964993 | 0.000042 | 625.000000 | 532.000000 | 718.000000 | 0.500000 | 532.000000 | 532.000000 | 532.000000 | 0.007571 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -6.000000 | 0.000000 | 0.035007 | 0.000042 |
| 43 | 0.000000 | -22.000000 | 0.000000 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001561 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 44 | 0.000000 | -3.956534 | 0.000491 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001198 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -6.000000 | 0.000000 | 0.043466 | 0.000490 |
| 45 | 0.000000 | -3.973095 | 0.000793 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001338 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -6.000000 | 0.000000 | 0.026904 | 0.000793 |
| 46 | 0.000000 | -3.976156 | 0.033721 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001170 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -4.976156 | 7.104789 |
| 47 | 0.000000 | -0.779847 | 0.030023 | 598.000000 | 478.000000 | 718.000000 | 0.500000 | 478.000000 | 478.000000 | 478.000000 | 0.003763 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -6.779847 | 0.030023 |
| 48 | 0.000000 | -0.400116 | 5.054048 | 587.000000 | 544.000000 | 630.000000 | 0.000000 | 587.000000 | 544.000000 | 630.000000 | 0.002484 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -1.400116 | 2.017020 |
| 49 | 0.000000 | -0.343705 | 5.124783 | 487.000000 | 256.000000 | 718.000000 | 0.500000 | 256.000000 | 256.000000 | 256.000000 | 0.020560 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -1.343706 | 1.946285 |
| 50 | 0.000000 | 0.349097 | 6.115803 | 524.000000 | 330.000000 | 718.000000 | 0.500000 | 330.000000 | 330.000000 | 330.000000 | 0.011521 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -0.650903 | 0.955265 |
| 51 | 0.000000 | -3.959916 | 0.001149 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001324 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -6.000000 | 0.000000 | 0.040084 | 0.001149 |
| 52 | 0.000000 | -0.386721 | 5.073168 | 651.000000 | 584.000000 | 718.000000 | 0.500000 | 584.000000 | 584.000000 | 584.000000 | 0.004181 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -1.386721 | 1.997900 |
| 53 | 0.000000 | -12.981807 | 12.753652 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001321 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.018194 | 0.025730 |
| 54 | 0.000000 | -3.962446 | 0.002950 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001248 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -6.000000 | 0.000000 | 0.037554 | 0.002950 |
| 55 | 0.000000 | -8.976932 | 18.417400 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001200 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.976932 | 1.381590 |
| 56 | 0.000000 | -8.290108 | 19.388716 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000950 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.290108 | 0.410275 |
| 57 | 0.000000 | 1.558185 | 0.646650 | 337.500000 | 222.000000 | 453.000000 | 0.000000 | 337.500000 | 222.000000 | 453.000000 | 0.008259 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -4.441815 | 0.646650 |
| 58 | 0.000000 | -1.931802 | 2.792675 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001027 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 1.000000 | 4.242640 | -4.931802 | 7.035316 |
| 59 | 0.000000 | -4.000000 | 0.000000 | 674.500000 | 631.000000 | 718.000000 | 0.500000 | 631.000000 | 631.000000 | 631.000000 | 0.003288 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -10.000000 | 0.000000 |
| 60 | 0.000000 | -22.000000 | 0.000000 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001974 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 61 | 0.000000 | -8.848903 | 18.598459 | 673.500000 | 629.000000 | 718.000000 | 0.500000 | 629.000000 | 629.000000 | 629.000000 | 0.001706 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.848903 | 1.200530 |
| 62 | 0.000000 | 4.090808 | 0.014869 | 707.500000 | 697.000000 | 718.000000 | 0.500000 | 697.000000 | 697.000000 | 697.000000 | 0.000990 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -1.909192 | 0.014869 |
| 63 | 0.000000 | -11.091834 | 15.426476 | 678.000000 | 638.000000 | 718.000000 | 0.500000 | 638.000000 | 638.000000 | 638.000000 | 0.002370 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -3.091834 | 4.372514 |
| 64 | 0.000000 | 0.816241 | 6.788723 | 504.000000 | 398.000000 | 610.000000 | 0.000000 | 504.000000 | 398.000000 | 610.000000 | 0.009033 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -0.183759 | 0.282345 |
| 65 | 0.000000 | -12.971285 | 12.768532 | 639.500000 | 561.000000 | 718.000000 | 0.500000 | 561.000000 | 561.000000 | 561.000000 | 0.004788 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.028715 | 0.040609 |
| 66 | 0.000000 | -3.978978 | 0.000921 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001102 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -6.000000 | 0.000000 | 0.021022 | 0.000921 |
| 67 | 0.000000 | 0.548499 | 6.408890 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000945 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -0.451501 | 0.662178 |
| 68 | 0.000000 | -9.647604 | 17.468925 | 570.500000 | 423.000000 | 718.000000 | 0.500000 | 423.000000 | 423.000000 | 423.000000 | 0.025197 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -1.647604 | 2.330064 |
| 69 | 0.000000 | -10.832811 | 15.792789 | 559.500000 | 401.000000 | 718.000000 | 0.500000 | 401.000000 | 401.000000 | 401.000000 | 0.038960 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -2.832811 | 4.006200 |
| 70 | 0.000000 | -22.000000 | 0.000000 | 690.500000 | 663.000000 | 718.000000 | 0.500000 | 663.000000 | 663.000000 | 663.000000 | 0.004275 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 71 | 0.000000 | -12.983641 | 12.751058 | 465.500000 | 213.000000 | 718.000000 | 0.500000 | 213.000000 | 213.000000 | 213.000000 | 0.048212 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.016360 | 0.023136 |
| 72 | 0.000000 | -12.985830 | 12.747961 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001176 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.014170 | 0.020039 |
| 73 | 0.000100 | 1.293883 | 0.501316 | 506.500000 | 295.000000 | 718.000000 | 0.500000 | 295.000000 | 295.000000 | 295.000000 | 0.086380 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -4.706117 | 0.501316 |
| 74 | 0.000000 | -12.637522 | 13.240543 | 587.000000 | 486.000000 | 688.000000 | 0.000000 | 587.000000 | 486.000000 | 688.000000 | 0.041948 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -4.637521 | 6.558445 |
| 75 | 0.000000 | -8.195321 | 19.522764 | 644.000000 | 570.000000 | 718.000000 | 0.500000 | 570.000000 | 570.000000 | 570.000000 | 0.018705 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.195321 | 0.276226 |
| 76 | 0.000000 | -9.506197 | 17.668905 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001101 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -1.506197 | 2.130084 |
| 77 | 0.000000 | -9.085239 | 18.264231 | 641.500000 | 565.000000 | 718.000000 | 0.500000 | 565.000000 | 565.000000 | 565.000000 | 0.038641 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -1.085240 | 1.534761 |
| 78 | 0.000000 | 4.289712 | 0.292143 | 683.000000 | 648.000000 | 718.000000 | 0.500000 | 648.000000 | 648.000000 | 648.000000 | 0.008802 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -1.710288 | 0.292143 |
| 79 | 0.000000 | -12.986875 | 12.746484 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001148 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.013125 | 0.018562 |
| 80 | 0.000000 | -22.000000 | 0.000000 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001387 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 81 | 0.000000 | -22.000000 | 0.000000 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000819 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 82 | 0.000000 | -22.000000 | 0.000000 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001463 | -2.000000 | 0.000000 | -20.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 83 | 0.000000 | -13.000000 | 12.727922 | 662.500000 | 607.000000 | 718.000000 | 0.500000 | 607.000000 | 607.000000 | 607.000000 | 0.027296 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -5.000000 | 7.071068 |
| 84 | 0.000100 | -8.069077 | 19.701302 | 584.000000 | 450.000000 | 718.000000 | 0.500000 | 450.000000 | 450.000000 | 450.000000 | 0.104870 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.069076 | 0.097689 |
| 85 | 0.000200 | -9.363983 | 17.870026 | 569.000000 | 420.000000 | 718.000000 | 0.500000 | 420.000000 | 420.000000 | 420.000000 | 0.166438 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -1.363983 | 1.928963 |
| 86 | 0.000300 | -13.000000 | 12.727922 | 527.500000 | 337.000000 | 718.000000 | 0.500000 | 337.000000 | 337.000000 | 337.000000 | 0.278213 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -5.000000 | 7.071068 |
| 87 | 0.000300 | -0.112757 | 5.169596 | 457.000000 | 196.000000 | 718.000000 | 0.500000 | 196.000000 | 196.000000 | 196.000000 | 0.325931 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -1.112757 | 1.901471 |
| 88 | 0.000200 | -3.634885 | 0.447414 | 587.000000 | 456.000000 | 718.000000 | 0.500000 | 456.000000 | 456.000000 | 456.000000 | 0.199767 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -4.634885 | 6.623653 |
| 89 | 0.000400 | 0.871788 | 6.865792 | 508.000000 | 298.000000 | 718.000000 | 0.500000 | 298.000000 | 298.000000 | 298.000000 | 0.363610 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -0.128212 | 0.205277 |
| 90 | 0.000000 | -4.042986 | 0.094284 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001259 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -6.000000 | 0.000000 | -0.042986 | 0.094284 |
| 91 | 0.000000 | -12.983546 | 12.751191 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001957 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | -3.000000 | 4.242640 | 0.016454 | 0.023269 |
| 92 | 0.000000 | -9.239710 | 18.045776 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001780 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -1.239710 | 1.753215 |
| 93 | 0.000300 | -0.628619 | 4.722605 | 554.000000 | 390.000000 | 718.000000 | 0.500000 | 390.000000 | 390.000000 | 390.000000 | 0.312774 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -1.628619 | 2.348463 |
| 94 | 0.000000 | -8.356527 | 19.294785 | 692.500000 | 667.000000 | 718.000000 | 0.500000 | 667.000000 | 667.000000 | 667.000000 | 0.015856 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.356527 | 0.504206 |
| 95 | 0.000000 | 0.819553 | 6.786077 | 710.000000 | 702.000000 | 718.000000 | 0.500000 | 702.000000 | 702.000000 | 702.000000 | 0.002237 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | -1.000000 | 7.071068 | -0.180447 | 0.284991 |
| 96 | 0.000000 | 5.888716 | 0.034997 | 718.000000 | 718.000000 | 718.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001013 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -0.111284 | 0.034997 |
| 97 | 0.000400 | 1.610486 | 0.819715 | 558.000000 | 398.000000 | 718.000000 | 0.500000 | 398.000000 | 398.000000 | 398.000000 | 0.391605 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 4.000000 | 0.000000 | -4.389514 | 0.819715 |
| 98 | 0.000300 | -8.591236 | 18.962856 | 579.000000 | 440.000000 | 718.000000 | 0.500000 | 440.000000 | 440.000000 | 440.000000 | 0.310268 | -0.500000 | 2.121320 | -9.500000 | 14.849242 | 2.000000 | 2.828427 | -0.591236 | 0.836134 |
990.000000-22.0000000.000000718.000000718.000000718.0000001.0000000.0000000.0000000.0000000.001404-2.0000000.000000-20.0000000.0000000.0000000.0000000.0000000.000000
1000.000100-11.68190614.591989655.000000592.000000718.0000000.500000592.000000592.000000592.0000000.089281-0.5000002.121320-9.50000014.8492422.0000002.828427-3.6819065.207002

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Unsloth: Will smartly offload gradients to save VRAM!\n", "def matmul(A, B):\n", " # ensure dimensions\n", " m = len(A)\n", " n = len(A[0])\n", " p = len(B[0]) if B else 0\n", " return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(p)] for i in range(m)]\n", "None\n", "def matmul(A, B):\n", " # A: m x k\n", " # B: k x n\n", " # returns m x n\n", " ...\n", "def matmul(A, B):\n", " # A: r x p, B: p x c\n", " r, p = len(A), len(A[0]) if A else 0\n", " p2, c = len(B), len(B[0]) if B else 0\n", " if r == 0 or c == 0 or p != p2:\n", " raise ValueError(\"Incompatible dimensions for multiplication\")\n", " # transpose B to improve locality\n", " B_T = list(zip(*B)) # c x p\n", " result = [[0] * c for _ in range(r)]\n", " for i in range(r):\n", " Ai = A[i]\n", " Ri = result[i]\n", " for j, Bj in enumerate(B_T):\n", " s = 0\n", " for k in range(p):\n", " s += Ai[k] * Bj[k]\n", " Ri[j] = s\n", " return result\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B.\n", "\n", " Parameters\n", " ----------\n", " A : list[list[float]]\n", " Left‑hand matrix with shape m x k.\n", " B : list[list[float]]\n", " Right‑hand matrix with shape k x n.\n", "\n", " Returns\n", " -------\n", " list[list[float]]\n", " Resulting matrix with shape m x n.\n", "\n", " Raises\n", " ------\n", " ValueError\n", " If dimensions are incompatible.\n", " \"\"\"\n", " if not A or not B:\n", " return []\n", "\n", " # Dimensions\n", " rows_a, cols_a = len(A), len(A[0])\n", " rows_b, cols_b = len(B), len(B[0])\n", "\n", " if cols_a != rows_b:\n", " raise ValueError(\"Incompatible dimensions for multiplication\")\n", "\n", " # Convert B into columns for faster row‑by‑row multiplication\n", " B_t = [list(col) for col in zip(*B)]\n", "\n", " # Compute each row of the result\n", " result = [\n", " [sum(a * b for a, b in zip(row_a, col_b)) for col_b in 
B_t]\n", " for row_a in A\n", " ]\n", "\n", " return result\n", "None\n", "def matmul(A, B):\n", " \"\"\"Compute the matrix product of A and B using pure Python.\"\"\"\n", " # Quick sanity checks\n", " m, n = len(A), len(A[0])\n", " p, q = len(B), len(B[0])\n", " assert n == p, \"Number of columns of A must equal number of rows of B\"\n", " # Prepare the result matrix with zeros\n", " result = [[0] * q for _ in range(m)]\n", " # This variant loops over the outermost index that is likely to be cache-friendly\n", " # and prefetches the row of B so that we access its elements in order.\n", " for i in range(m):\n", " row_a = A[i]\n", " for k in range(n):\n", " aik = row_a[k]\n", " if aik: # skip the zero case to save work\n", " row_b = B[k]\n", " for j in range(q):\n", " result[i][j] += aik * row_b[j]\n", " return result\n", "def matmul(A, B):\n", " m = len(A)\n", " if m==0: return []\n", " n = len(A[0]) or 0\n", " p = len(B[0]) if B else 0\n", " # ensure B has n rows\n", " # Use list comprehension summing product over k\n", " # Compute B transposed for column-wise access: BT = list(zip(*B))\n", " BT = list(zip(*B))\n", " return [[sum(a*b for a,b in zip(row,col)) for col in BT] for row in A]\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices where each matrix is represented as a list of lists\n", " and the elements are integers or floats.\n", "\n", " Parameters\n", " ----------\n", " A : list[list[int|float]]\n", " Left‑hand matrix of size (m × p).\n", " B : list[list[int|float]]\n", " Right‑hand matrix of size (p × n).\n", "\n", " Returns\n", " -------\n", " list[list[int|float]]\n", " The product matrix C = A @ B of size (m × n).\n", " \"\"\"\n", "\n", " # Dimensions: A: m×p, B: p×n\n", " m, p = len(A), len(A[0])\n", " p2, n = len(B), len(B[0])\n", " if p != p2:\n", " raise ValueError(\"Inner dimensions must agree for matrix multiplication\")\n", "\n", " # Pre‑transpose B so that column access is contiguous\n", " # This reduces random memory access 
during the dot product\n", " B_T = [list(col) for col in zip(*B)]\n", "\n", " result = []\n", " for i, row in enumerate(A):\n", " # Compute each entry of the i‑th row of the result\n", " result_row = []\n", " for j, col in enumerate(B_T):\n", " # use a local variable for speed\n", " dot = 0\n", " for a, b in zip(row, col):\n", " dot += a * b\n", " result_row.append(dot)\n", " result.append(result_row)\n", "\n", " return result\n", "None\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B using plain Python.\n", " A: list of lists (m × n), B: list of lists (n × p).\n", " Returns the product (m × p) as a new list of lists.\n", " \"\"\"\n", " if not A or not B:\n", " return []\n", "\n", " m = len(A)\n", " n = len(A[0])\n", " p = len(B[0])\n", "\n", " # Quick compatibility check\n", " assert len(B) == n, \"Incompatible matrix dimensions\"\n", "\n", " # Allocate result matrix\n", " result = [[0] * p for _ in range(m)]\n", "\n", " # Standard triple‑loop multiplication, with a small speed‑up:\n", " # pull outer indices, cache row and column values locally,\n", " # and skip inner loop when the coefficient is zero.\n", " for i in range(m):\n", " Ai = A[i]\n", " Ri = result[i]\n", " for k in range(n):\n", " aik = Ai[k]\n", " if aik: # skip zero entries\n", " Bk = B[k]\n", " for j in range(p):\n", " Ri[j] += aik * Bk[j]\n", " return result\n", "def matmul(A, B):\n", " # number of rows in A\n", " m = len(A)\n", " # number of columns in B\n", " p = len(B[0]) if B else 0\n", " # number of columns in A (used as number of rows in B)\n", " k = len(A[0]) if A else 0\n", "\n", " # If shapes are incompatible, raise an error\n", " if len(B) != k:\n", " raise ValueError(\"Incompatible matrices: A.shape[1] != B.shape[0]\")\n", "\n", " # Result matrix initialised with zeros\n", " result = [[0] * p for _ in range(m)]\n", "\n", " # Triple‑loop multiplication\n", " for i in range(m):\n", " for j in range(p):\n", " # Compute the dot product of row i of A and column j of 
B\n", " s = 0\n", " for t in range(k):\n", " s += A[i][t] * B[t][j]\n", " result[i][j] = s\n", " return result\n", "None\n", "def matmul(A, B):\n", " # Determine dimensions\n", " m, n = len(A), len(A[0])\n", " nB, p = len(B), len(B[0])\n", " assert n == nB, \"Incompatible matrices\"\n", " result = [[0]*p for _ in range(m)]\n", " # Multiply\n", " for i in range(m):\n", " Ai = A[i]\n", " for k in range(n):\n", " aik = Ai[k]\n", " if aik:\n", " Bk = B[k]\n", " for j in range(p):\n", " result[i][j] += aik * Bk[j]\n", " return result\n", "def matmul(A, B):\n", " \"\"\"Fast matrix product using only native Python.\n", "\n", " Args:\n", " A (List[List[Number]]): left matrix (n × m)\n", " B (List[List[Number]]): right matrix (m × p)\n", "\n", " Returns:\n", " List[List[Number]]: the product A * B (n × p)\n", "\n", " Raises:\n", " ValueError: if the matrices cannot be multiplied\n", " \"\"\"\n", " # Basic shape checks\n", " if not A or not A[0] or not B or not B[0]:\n", " raise ValueError(\"Input matrices must be non‑empty\")\n", " n, m = len(A), len(A[0])\n", " if m != len(B):\n", " raise ValueError(\"Number of columns of A must equal number of rows of B\")\n", "\n", " # Transpose B once for better locality\n", " B_T = list(zip(*B)) # now each element of B_T is a tuple representing a column\n", "\n", " # Compute the product\n", " result = []\n", " for row in A:\n", " # dot(row, col) for each column\n", " result.append([sum(a * b for a, b in zip(row, col)) for col in B_T])\n", "\n", " return result\n", "None\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " ...\n", "def matmul(A, B):\n", " n = len(A)\n", " m = len(B[0])\n", " k = len(B)\n", " result = [[0]*m for _ in range(n)]\n", " for i in range(n):\n", " for j in range(m):\n", " s = 0\n", " for l in range(k):\n", " s += A[i][l] * B[l][j]\n", " result[i][j] = s\n", " return result\n", "None\n", "None\n", "def matmul(A, B):\n", " m = len(A)\n", " n = len(B[0])\n", " p = len(B)\n", " return [[sum(A[i][k] 
* B[k][j] for k in range(p)) for j in range(n)] for i in range(m)]\n", "None\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " import math\n", " # Assume square matrices of same size and power of 2\n", " n = len(A)\n", " if n == 0:\n", " return []\n", " def add(X, Y):\n", " return [[X[i][j] + Y[i][j] for j in range(n)] for i in range(n)]\n", " def sub(X, Y):\n", " return [[X[i][j] - Y[i][j] for j in range(n)] for i in range(n)]\n", " def split(M):\n", " k = n // 2\n", " return [ [row[:k] for row in M[:k]],\n", " [row[:k] for row in M[k:]],\n", " [row[k:] for row in M[:k]],\n", " [row[k:] for row in M[k:]] ]\n", " def combine(A11, A12, A21, A22):\n", " k = len(A11)\n", " result = [ [0]* (k*2) for _ in range(k*2) ]\n", " for i in range(k):\n", " result[i][:k] = A11[i]\n", " result[i][k:] = A12[i]\n", " result[i+k][:k] = A21[i]\n", " result[i+k][k:] = A22[i]\n", " return result\n", " def strassen(X, Y):\n", " if n == 1:\n", " return [[X[0][0]*Y[0][0]]]\n", " a, b, c, d = split(X)\n", " e, f, g, h = split(Y)\n", " p1 = strassen(a, sub(f, h))\n", " p2 = strassen(add(a, b), h)\n", " p3 = strassen(add(c, d), e)\n", " p4 = strassen(d, sub(g, e))\n", " p5 = strassen(add(a, d), add(e, h))\n", " p6 = strassen(sub(b, d), add(g, h))\n", " p7 = strassen(sub(a, c), add(e, f))\n", " c11 = add(sub(add(p5, p4), p2), p6)\n", " c12 = add(p1, p2)\n", " c21 = add(p3, p4)\n", " c22 = sub(sub(add(p1, p5), p3), p7)\n", " return combine(c11, c12, c21, c22)\n", " return strassen(A, B)\n", "def matmul(A, B):\n", " n = len(A)\n", " m = len(A[0])\n", " p = len(B[0])\n", " assert len(B) == m\n", " BT = list(zip(*B)) # transposed as tuples\n", " C = [[sum(Ai[k] * BTj[k] for k in range(m)) for BTj in BT] for Ai in A]\n", " return C\n", "def matmul(A, B):\n", " n = len(A)\n", " m = len(B[0])\n", " p = len(B)\n", " # transpose B\n", " B_T = list(map(list, zip(*B))) # list of columns\n", " return [[sum(a*b for a,b in zip(row, col)) for col in B_T] for row in A]\n", "None\n", "def 
matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " ...\n", "def matmul(A, B):\n", " \"\"\"Matrix multiplication using only native Python (no external libraries).\n", "\n", " Works for arbitrary sized matrices with compatible dimensions.\n", " The algorithm transposes matrix B to enhance cache locality,\n", " then uses a list‑comprehension to calculate the dot–product of\n", " corresponding rows and columns.\n", "\n", " Parameters\n", " ----------\n", " A : list[list[Number]]\n", " Left matrix of shape (m, n).\n", " B : list[list[Number]]\n", " Right matrix of shape (n, p).\n", "\n", " Returns\n", " -------\n", " list[list[Number]]\n", " Resulting product matrix of shape (m, p).\n", " \"\"\"\n", " # Transpose B for efficient column access\n", " B_transposed = list(zip(*B)) # tuples, one per column of B\n", " return [\n", " [\n", " # dot product of row from A and column from B\n", " sum(a * b for a, b in zip(row_a, col_b))\n", " for col_b in B_transposed\n", " ]\n", " for row_a in A\n", " ]\n", "None\n", "None\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " # check dimension matches\n", " nrows_a = len(A)\n", " ncols_a = len(A[0]) if A else 0\n", " nrows_b = len(B)\n", " ncols_b = len(B[0]) if B else 0\n", " if ncols_a != nrows_b:\n", " raise ValueError(\"Incompatible dimensions\")\n", " # transpose B for cache locality\n", " BT = list(zip(*B)) # tuple of tuples used as rows\n", " result = [[0]*ncols_b for _ in range(nrows_a)]\n", " for i in range(nrows_a):\n", " ai = A[i]\n", " ri = result[i]\n", " for j in range(ncols_b):\n", " s = 0\n", " bj = BT[j]\n", " for k in range(ncols_a):\n", " s += ai[k] * bj[k]\n", " ri[j] = s\n", " return result\n", "def matmul(A, B):\n", " B_T = list(zip(*B))\n", " res = [[sum(a*b for a,b in zip(row, col)) for col in B_T] for row in A]\n", " return res\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two square matrices A and B.\n", "\n", " Parameters\n", " 
----------\n", " A : list[list[float]]\n", " First n × n matrix.\n", " B : list[list[float]]\n", " Second n × n matrix.\n", "\n", " Returns\n", " -------\n", " list[list[float]]\n", " The product matrix C = A @ B.\n", " \"\"\"\n", " # Transpose B once for O(1) column access\n", " B_T = list(zip(*B))\n", "\n", " # Compute C[i][j] = dot(A[i], B_T[j])\n", " return [[sum(a * b for a, b in zip(row, col))\n", " for col in B_T]\n", " for row in A]\n", "def matmul(A, B): return ...\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B using native Python lists.\n", " `A` and `B` must be rectangular (i.e. all rows the same length).\n", "\n", " Returns a new matrix containing the product.\n", " Raises ValueError if the inner dimensions do not match.\n", " \"\"\"\n", " if not A or not B:\n", " return []\n", "\n", " n_rows_A, n_cols_A = len(A), len(A[0])\n", " n_rows_B, n_cols_B = len(B), len(B[0])\n", "\n", " if n_cols_A != n_rows_B:\n", " raise ValueError(\"cannot multiply: inner dimensions do not match\")\n", "\n", " # preallocate result matrix\n", " result = [[0] * n_cols_B for _ in range(n_rows_A)]\n", "\n", " for i in range(n_rows_A):\n", " # Local references for speed\n", " row_a = A[i]\n", " for k in range(n_cols_A):\n", " aik = row_a[k]\n", " if aik == 0:\n", " continue # skip zero multiplications\n", " row_b = B[k]\n", " for j in range(n_cols_B):\n", " result[i][j] += aik * row_b[j]\n", "\n", " return result\n", "None\n", "def matmul(A, B):\n", " return ...\n", "None\n", "None\n", "None\n", "def matmul(A, B):\n", " return ...\n", "None\n", "None\n", "None\n", "None\n", "def matmul(A, B):\n", " \"\"\"Return the matrix product of A and B.\n", "\n", " A must be an `n×m` matrix, B an `m×p` matrix.\n", " Matrices are represented as nested lists of numbers.\n", " \"\"\"\n", " if not A or not B:\n", " return []\n", "\n", " # Transpose B once so we can iterate over rows efficiently\n", " B_T = list(zip(*B))\n", "\n", " # Compute each entry using a 
dot‑product of corresponding rows\n", " return [\n", " [sum(a * b for a, b in zip(row_a, col_b)) for col_b in B_T]\n", " for row_a in A\n", " ]\n", "None\n", "def matmul(A, B):\n", " # A is m x p, B is p x n\n", " m = len(A)\n", " p = len(A[0]) # find p\n", " n = len(B[0])\n", " # create result m x n\n", " result = [... for i in ...]\n", "def matmul(A, B):\n", " m, p = len(A), len(A[0]) if A else 0\n", " p2, n = len(B), len(B[0]) if B else 0\n", " if p != p2: raise ValueError(\"Incompatible dimensions\")\n", " # Precompute transpose of B for cache-friendly access\n", " Bt = list(zip(*B))\n", " return [[sum(a * b for a, b in zip(row, col)) for col in Bt] for row in A]\n", "def matmul(A, B): return ...\n", "def matmul(A, B):\n", " return ...\n", "None\n", "def matmul(A, B):\n", " return ...\n", "None\n", "def matmul(A, B): return ...\n", "None\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A (m×n) and B (n×p) given as lists of lists.\n", " Uses only native Python code and is tuned for speed by transposing B.\n", " \"\"\"\n", " if not A:\n", " return []\n", "\n", " n_rows_A, n_cols_A = len(A), len(A[0])\n", " # Basic consistency check – assume all rows have equal length\n", " # and B has compatible dimensions.\n", " n_rows_B, n_cols_B = len(B), len(B[0]) if B else 0\n", "\n", " # Transpose B to access its columns as tuples (faster indexing)\n", " B_T = list(zip(*B)) # shape: (p × n)\n", "\n", " result = []\n", " for row_A in A: # iterate over rows of A\n", " # Compute the dot product of row_A with each column of B\n", " res_row = [sum(a * b for a, b in zip(row_A, col_B)) for col_B in B_T]\n", " result.append(res_row)\n", "\n", " return result\n", "None\n", "None\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B (lists of lists) using native Python.\n", " The matrices are assumed to be square and of compatible dimensions.\n", " \"\"\"\n", " n = len(A)\n", " # Transpose B to improve cache locality for the inner sum\n", " B_T = 
list(zip(*B)) # each element is a tuple\n", " return [[sum(a * b for a, b in zip(row, col))\n", " for col in B_T]\n", " for row in A]\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B using only native Python code.\n", "\n", " Parameters\n", " ----------\n", " A : list of lists (m x k)\n", " B : list of lists (k x n)\n", "\n", " Returns\n", " -------\n", " C : list of lists (m x n)\n", " Product matrix such that C[i][j] = sum(A[i][p] * B[p][j] for p in range(k))\n", "\n", " Raises\n", " ------\n", " ValueError\n", " If the number of columns in A does not equal the number of rows in B.\n", " \"\"\"\n", " if not A or not B:\n", " raise ValueError(\"Both matrices must be non‑empty.\")\n", " m, k1 = len(A), len(A[0])\n", " k2, n = len(B), len(B[0])\n", " if k1 != k2:\n", " raise ValueError(f\"Incompatible shapes: {m}x{k1} multiplied by {k2}x{n}\")\n", " # Pre‑compute columns of B for faster access\n", " B_cols = list(zip(*B)) # n tuples each of length k\n", " # Compute the product\n", " return [[sum(a * b for a, b in zip(row, col))\n", " for col in B_cols] for row in A]\n", "def matmul(A, B):\n", " \"\"\"\n", " Fast matrix multiplication for plain Python objects.\n", " A : list of m rows, each a list of n numbers\n", " B : list of n rows, each a list of p numbers\n", " Returns a new matrix of shape m × p.\n", " \"\"\"\n", " m = len(A)\n", " if m == 0:\n", " return []\n", " n = len(A[0])\n", " # Verify dimension compatibility\n", " if len(B) != n or any(len(row) != n for row in A):\n", " raise ValueError(\"Inner dimensions must agree.\")\n", " p = len(B[0])\n", "\n", " # Allocate result matrix\n", " result = [[0] * p for _ in range(m)]\n", "\n", " # Standard triple-loop, optimized for speed in pure Python\n", " for i in range(m):\n", " row_a = A[i]\n", " row_res = result[i]\n", " for k in range(n):\n", " a_val = row_a[k]\n", " row_b = B[k]\n", " # Use local variables for performance\n", " for j in range(p):\n", " row_res[j] += a_val * row_b[j]\n", 
"\n", " return result\n", "def matmul(A, B):\n", " \"\"\"Multiply two matrices A and B.\n", "\n", " A: list of m rows, each containing n elements.\n", " B: list of n rows, each containing p elements.\n", " Returns a new list of list representing the product matrix of shape (m, p).\n", " \"\"\"\n", " # Basic sanity checks\n", " if not A or not B:\n", " raise ValueError(\"Input matrices must be non‑empty.\")\n", " m, n = len(A), len(A[0])\n", " nB, p = len(B), len(B[0])\n", " if n != nB:\n", " raise ValueError(\"Number of columns of A must equal number of rows of B.\")\n", " for row in A:\n", " if len(row) != n:\n", " raise ValueError(\"All rows of A must have the same length.\")\n", " for row in B:\n", " if len(row) != p:\n", " raise ValueError(\"All rows of B must have the same length.\")\n", "\n", " # Allocate the result matrix (m x p) initialized to 0\n", " result = [[0] * p for _ in range(m)]\n", "\n", " # Perform multiplication\n", " for i in range(m):\n", " rowA = A[i]\n", " rowR = result[i]\n", " for k in range(n):\n", " aik = rowA[k]\n", " if aik: # Skip work for zero multiplication\n", " rowB = B[k]\n", " for j in range(p):\n", " rowR[j] += aik * rowB[j]\n", " return result\n", "None\n", "None\n", "def matmul(A, B):\n", " # Validate dimensions\n", " if not A or not B:\n", " return []\n", " if len(A[0]) != len(B):\n", " raise ValueError(\"Matrix dimensions do not match for multiplication.\")\n", " B_cols = list(zip(*B)) # transpose B\n", " result = [[sum(a*b for a,b in zip(row, col)) for col in B_cols] for row in A]\n", " return result\n", "None\n", "def matmul(A, B):\n", " \"\"\"Multiply matrices A × B (list‑of‑list format) using an optimized pure‑Python routine.\"\"\"\n", " n, p = len(A), len(A[0]) # rows of A, columns of B (must match)\n", " m = len(B[0]) # columns of B\n", " # Result matrix initialized with zeros\n", " result = [[0] * m for _ in range(n)]\n", " for i in range(n):\n", " rowA = A[i]\n", " rowR = result[i]\n", " for k in range(p):\n", " aik 
= rowA[k]\n", " if aik: # skip zero entries for a small extra speedup\n", " rowBk = B[k]\n", " for j in range(m):\n", " rowR[j] += aik * rowBk[j]\n", " return result\n", "def matmul(A, B):\n", " \"\"\"Fast matrix multiplication using only native Python code.\n", "\n", " Parameters\n", " ----------\n", " A : list[list[float]]\n", " Left hand matrix of shape (n, m).\n", " B : list[list[float]]\n", " Right hand matrix of shape (m, q).\n", "\n", " Returns\n", " -------\n", " list[list[float]]\n", " The product matrix of shape (n, q).\n", "\n", " Notes\n", " -----\n", " The routine pre-allocates the result matrix and uses local variable\n", " bindings to reduce attribute look‑ups inside the innermost loop,\n", " which gives a noticeable speed boost for large matrices.\n", " \"\"\"\n", " # Basic dimensions checking\n", " n, m = len(A), len(A[0])\n", " m2, q = len(B), len(B[0])\n", " if m != m2:\n", " raise ValueError(\"Number of columns in A must equal number of rows in B.\")\n", "\n", " # Pre‑allocate result matrix with zeros\n", " result = [[0.0] * q for _ in range(n)]\n", "\n", " # Perform multiplication\n", " for i in range(n):\n", " rowA = A[i]\n", " rowR = result[i]\n", " for k in range(m):\n", " aik = rowA[k]\n", " rowB = B[k]\n", " for j in range(q):\n", " rowR[j] += aik * rowB[j]\n", " return result\n", "def matmul(A, B):\n", " B_T = list(zip(*B)) # transpose B for inner product\n", " return [[sum(a*b for a,b in zip(row, col)) for col in B_T] for row in A]\n", "def matmul(A, B):\n", " # A: m x n, B: n x p\n", " m, n = len(A), len(A[0])\n", " n2, p = len(B), len(B[0])\n", " assert n==n2\n", " # initialize result matrix\n", " C = [[0]*p for _ in range(m)]\n", " # transpose B for better locality\n", " B_T = list(zip(*B))\n", " for i in range(m):\n", " Ai = A[i]\n", " Ci = C[i]\n", " for k in range(n):\n", " aik = Ai[k]\n", " Bk = B_T[k]\n", " for j in range(p):\n", " Ci[j] += aik * Bk[j]\n", " return C\n", "def matmul(A, B):\n", " \"\"\"Fast matrix multiplication 
using pure Python.\n", "\n", " Arguments:\n", " A: List of lists, the left matrix of size m x n.\n", " B: List of lists, the right matrix of size n x p.\n", "\n", " Returns a new matrix C of size m x p where C[i][j] = sum(A[i][k] * B[k][j] for k in range(n)).\n", "\n", " This implementation transposes B once to enable efficient column access\n", " and uses nested list comprehensions together with the built‑in `sum`\n", " function, which is implemented in C.\n", " \"\"\"\n", " # Verify dimensions\n", " if not A or not B:\n", " raise ValueError(\"Matrices cannot be empty\")\n", " n = len(A[0])\n", " if any(len(row) != n for row in A):\n", " raise ValueError(\"All rows in A must have the same length\")\n", " if len(B) != n:\n", " raise ValueError(\"Number of columns in A must equal number of rows in B\")\n", " p = len(B[0])\n", " if any(len(row) != p for row in B):\n", " raise ValueError(\"All rows in B must have the same length\")\n", "\n", " # Transpose B to get columns as rows\n", " B_cols = list(zip(*B))\n", "\n", " # Compute product\n", " return [[sum(a * b for a, b in zip(row, col)) for col in B_cols] for row in A]\n", "def matmul(A, B):\n", " BT = list(zip(*B))\n", " return [[sum(a*b for a,b in zip(row, col)) for col in BT] for row in A]\n", "def matmul(A, B):\n", " \"\"\"Multiplies matrix A by matrix B using pure native Python.\n", "\n", " Parameters\n", " ----------\n", " A : list[list[float]]\n", " The first matrix of shape (l, m).\n", " B : list[list[float]]\n", " The second matrix of shape (m, n).\n", "\n", " Returns\n", " -------\n", " list[list[float]]\n", " The product matrix of shape (l, n).\n", " \"\"\"\n", " # Pre‑get dimensions for speed\n", " l = len(A) # Number of rows in A\n", " m = len(A[0]) # Number of columns in A / rows in B\n", " n = len(B[0]) # Number of columns in B\n", "\n", " # The result will have shape (l, n)\n", " result = [[0.0] * n for _ in range(l)]\n", "\n", " for i in range(l):\n", " Ai = A[i]\n", " for k in range(m):\n", " aik = 
Ai[k]\n", " Bk = B[k]\n", " # Unroll column loop to reduce attribute lookups\n", " for j in range(n):\n", " result[i][j] += aik * Bk[j]\n", "\n", " return result\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A (n x m) and B (m x p) using only native Python.\n", " Returns the resulting matrix as a list of lists.\n", " \"\"\"\n", " n = len(A)\n", " m = len(A[0])\n", " if len(B) != m:\n", " raise ValueError(\"Number of columns in A must equal number of rows in B\")\n", " p = len(B[0])\n", "\n", " # Pre‑allocate result matrix with zeros\n", " result = [[0] * p for _ in range(n)]\n", "\n", " for i in range(n):\n", " row_a = A[i]\n", " for k in range(m):\n", " aik = row_a[k] # A[i][k]\n", " row_b = B[k] # The k-th row of B\n", " for j in range(p):\n", " result[i][j] += aik * row_b[j] # C[i][j] += A[i][k] * B[k][j]\n", " return result\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " return ...\n", "None\n", "def matmul(A, B):\n", " # assume A dims m x n, B n x p\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " # Should handle maybe rectangular matrices\n", " ...\n", "def matmul(A, B): return ...\n", "def matmul(A, B): return ...\n", "def matmul(A, B):\n", " result = []\n", " for i in range(len(A)):\n", " res_row = []\n", " for j in range(len(B[0])):\n", " sum_val = 0\n", " for k in range(len(B)):\n", " sum_val += A[i][k] * B[k][j]\n", " res_row.append(sum_val)\n", " result.append(res_row)\n", " return result\n", "def matmul(A, B): return ...\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices represented as lists of lists using only\n", " standard Python code. 
Raises a ValueError if the matrices cannot\n", " be multiplied.\n", " \n", " Parameters\n", " ----------\n", " A : List[List[Number]]\n", " The left-hand-side matrix of shape (m, n).\n", " B : List[List[Number]]\n", " The right-hand-side matrix of shape (n, p).\n", "\n", " Returns\n", " -------\n", " C : List[List[Number]]\n", " The product matrix of shape (m, p).\n", " \"\"\"\n", " if not A or not B:\n", " raise ValueError(\"Empty matrices cannot be multiplied\")\n", "\n", " # Verify inner dimensions match\n", " n = len(A[0])\n", " for row in A:\n", " if len(row) != n:\n", " raise ValueError(\"All rows of A must have the same length\")\n", " if len(B) != n:\n", " raise ValueError(\"Number of columns in A must equal number of rows in B\")\n", "\n", " p = len(B[0])\n", " for row in B:\n", " if len(row) != p:\n", " raise ValueError(\"All rows of B must have the same length\")\n", "\n", " # Transpose B so that columns can be accessed as tuples\n", " B_T = list(zip(*B)) # Each element is a tuple of length n\n", "\n", " # Compute the product\n", " C = []\n", " for a_row in A:\n", " c_row = []\n", " for b_col in B_T:\n", " dot_product = sum(x * y for x, y in zip(a_row, b_col))\n", " c_row.append(dot_product)\n", " C.append(c_row)\n", "\n", " return C\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B represented as nested lists.\n", " \n", " Parameters:\n", " A (list[list[float]]): Matrix of size (m x n).\n", " B (list[list[float]]): Matrix of size (n x p).\n", " \n", " Returns:\n", " list[list[float]]: Resultant matrix of size (m x p).\n", " \"\"\"\n", " # Pre‑compute the columns of B\n", " B_cols = list(zip(*B))\n", " \n", " # Compute the product row by row\n", " return [\n", " [\n", " sum(a * b for a, b in zip(row, col))\n", " for col in B_cols\n", " ]\n", " for row in A\n", " ]\n", "def matmul(A, B):\n", " # Basic dimension check\n", " if not A or not B:\n", " return []\n", " n_rows_a = len(A)\n", " n_cols_a = len(A[0])\n", " n_rows_b = 
len(B)\n", " n_cols_b = len(B[0])\n", " if n_cols_a != n_rows_b:\n", " raise ValueError(\"Incompatible dimensions for matrix multiplication\")\n", " result = [[0] * n_cols_b for _ in range(n_rows_a)]\n", " for i in range(n_rows_a):\n", " for k in range(n_cols_a):\n", " aik = A[i][k]\n", " if aik == 0:\n", " continue\n", " for j in range(n_cols_b):\n", " result[i][j] += aik * B[k][j]\n", " return result\n", "def matmul(A, B): return ...\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " # A: m x n, B: n x p\n", " # returns m x p\n", " # check sizes\n", " m, n = len(A), len(A[0])\n", " assert len(B) == n\n", " p = len(B[0])\n", " # Precompute transpose of B\n", " Bt = list(zip(*B))\n", " res = [[sum(a*b for a,b in zip(row, col)) for col in Bt] for row in A]\n", " return res\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " \"\"\"Multiply two matrices A and B (list of lists) purely in plain Python.\"\"\"\n", " m, n = len(A), len(A[0])\n", " nB, p = len(B), len(B[0])\n", " if n != nB:\n", " raise ValueError(\"Inner matrix dimensions must agree.\")\n", " # Pre‑allocate result matrix\n", " C = [[0] * p for _ in range(m)]\n", " # Perform multiplication in the ctr order: i, k, j\n", " for i in range(m):\n", " rowA = A[i]\n", " rowC = C[i]\n", " for k in range(n):\n", " aik = rowA[k]\n", " if aik: # skip zero multiplications\n", " rowB = B[k]\n", " for j in range(p):\n", " rowC[j] += aik * rowB[j]\n", " return C\n", "def matmul(A, B): return ...\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " B_T = list(zip(*B))\n", " return [[sum(a*b for a,b in zip(row, col)) for col in B_T] for row in A]\n", "def matmul(A, B):\n", " return ...\n", "None\n", "def matmul(A, B):\n", " ...\n", "def matmul(A, B):\n", " ...\n", "def matmul(A, B):\n", " ...\n", "None\n", "def matmul(A, B):\n", " # assume A rows, B columns\n", " m, k1 = len(A), len(A[0]) if A else 0\n", " k2, n = len(B), len(B[0]) if B else 0\n", " assert k1 == k2, \"Inner 
dimensions must match\"\n", " B_T = list(zip(*B)) # transposed B\n", " result = [[sum(a*b for a,b in zip(row, col)) for col in B_T] for row in A]\n", " return result\n", "None\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B (both given as lists of lists) using\n", " only native Python constructs.\n", "\n", " The function first validates that the inner dimensions match (number of\n", " columns in A must equal the number of rows in B). Then it computes the\n", " product using a straightforward triple‑loop but expressed in a compact\n", " list‑comprehension. The columns of B are accessed by transposing B\n", " once via ``zip(*B)``, which avoids explicit indexing and gives good\n", " cache locality in Python.\n", "\n", " Parameters\n", " ----------\n", " A : List[List[Number]]\n", " Left‑hand matrix of size m × n.\n", "\n", " B : List[List[Number]]\n", " Right‑hand matrix of size n × p.\n", "\n", " Returns\n", " -------\n", " List[List[Number]]\n", " The matrix product of A and B, which has shape m × p.\n", "\n", " Raises\n", " ------\n", " ValueError\n", " If the matrices cannot be multiplied due to incompatible dimensions.\n", " \"\"\"\n", " # Validate dimensions\n", " if not A or not B:\n", " raise ValueError(\"Matrices cannot be empty\")\n", " n_cols_A = len(A[0])\n", " n_rows_B = len(B)\n", " if n_cols_A != n_rows_B:\n", " raise ValueError(\"Inner dimensions must match: \"\n", " f\"{n_cols_A} != {n_rows_B}\")\n", "\n", " # Transpose B once to make column access efficient\n", " B_cols = list(zip(*B))\n", "\n", " # Compute product using list comprehensions\n", " return [\n", " [sum(a * b for a, b in zip(row, col)) for col in B_cols]\n", " for row in A\n", " ]\n", "def matmul(A, B):\n", " if not A or not B:\n", " return []\n", " m, p = len(A), len(A[0])\n", " n = len(B[0])\n", " # ensure inner dimension matches\n", " result = [[0]*n for _ in range(m)]\n", " for i in range(m):\n", " for k in range(p):\n", " aik = A[i][k]\n", " for j in 
range(n):\n", " result[i][j] += aik * B[k][j]\n", " return result\n", "def matmul(A, B):\n", " n = len(A)\n", " m = len(B[0])\n", " p = len(B)\n", " return [[sum(A[i][k] * B[k][j] for k in range(p)) for j in range(m)] for i in range(n)]\n", "def matmul(A, B):\n", " # Basic check dims\n", " n = len(A)\n", " assert n > 0\n", " m = len(A[0])\n", " # B has size m x p\n", " assert len(B) == m\n", " p = len(B[0])\n", " # Pre-allocate result\n", " C = [[0]*p for _ in range(n)]\n", " # Compute transpose of B for better locality\n", " B_T = list(map(list, zip(*B)))\n", " for i in range(n):\n", " Ai = A[i]\n", " Ci = C[i]\n", " for j in range(p):\n", " Bj = B_T[j]\n", " s = 0\n", " for k in range(m):\n", " s += Ai[k] * Bj[k]\n", " Ci[j] = s\n", " return C\n", "def matmul(A, B):\n", " \"\"\"Return the product of two matrices in Theta(m^3/mn)*n^2 time,\n", " optimizing cache usage if --precache-multiple was selected.\n", " \"\"\"\n", " _check_input(A, B)\n", " m = len(A); n = len(A[0]); k = len(B[0]); # (m x n) * (n x k) => (m x k)\n", " _use_a = (m / n <= 100.0 * n / k)\n", " M_local_cache_avail_prefs = sympy.cache.get('M_local_cache_avail_prefs', 0)\n", " M_local_cache_avail_nonp = sympy.cache.get('M_local_cache_avail_nonp', 0)\n", "\n", " if sympy.cache and M_local_cache_avail_nonp and _use_a:\n", " local_cache_key_A = (m,n,1)\n", " local_cache_key_B = (n,k,1)\n", " if local_cache_key_A not in sympy.cache:\n", " sympy.cache[local_cache_key_A] = [(r,c) for r in range(m) for c in range(n)]\n", " if local_cache_key_B not in sympy.cache:\n", " sympy.cache[local_cache_key_B] = [(r,c) for r in range(n) for c in range(k)]\n", " A_local, B_local = sympy.cache[local_cache_key_A], sympy.cache[local_cache_key_B]\n", " else:\n", " A_local, B_local = None, None\n", "\n", " def fast_matmul(A_local, B_local, _):\n", " result = [[sum(A_local[i][j] * B_local[j][l] for j in range(n))\n", " for l in range(k)]\n", " for i in range(m)]\n", " return result\n", "\n", " if A_local and B_local:\n", 
" return fast_matmul(A_local, B_local, k)\n", " else:\n", " return fast_matmul(A, B, k)\n", "def matmul(A, B):\n", " n, m = len(A), len(B)\n", " assert all(len(row)==m for row in A) and all(len(row)==len(B[0]) for row in B)\n", " # maybe compute B transposed\n", " BT = list(zip(*B))\n", " result = [[sum(a*b for a,b in zip(row, col)) for col in BT] for row in A]\n", " return result\n", "def matmul(A, B):\n", " m = len(A)\n", " n = len(A[0])\n", " p = len(B[0])\n", " # compute result matrix C(m x p)\n", " result = [[sum(A[i][k]*B[k][j] for k in range(n)) for j in range(p)] for i in range(m)]\n", " return result\n", "None\n", "None\n", "None\n", "def matmul(A, B):\n", " n = len(A)\n", " m = len(B[0])\n", " p = len(B)\n", " result = [[0]*m for _ in range(n)]\n", " for i in range(n):\n", " Ai=A[i]\n", " Ri=result[i]\n", " for k in range(p):\n", " aik=Ai[k]\n", " Bk=B[k]\n", " for j in range(m):\n", " Ri[j]+=aik*Bk[j]\n", " return result\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A (n x m) and B (m x p) without using external libraries.\n", " \"\"\"\n", " n=len(A)\n", " m=len(A[0])\n", " p=len(B[0])\n", " # Check compatibility\n", " if len(B)!=m:\n", " raise ValueError(\"Incompatible dimensions\")\n", " # Transpose B for cache-friendly access\n", " B_T=[list(col) for col in zip(*B)]\n", " # Prepare result matrix\n", " result=[[0]*p for _ in range(n)]\n", " for i in range(n):\n", " row=A[i]\n", " # local assignments\n", " res_row=result[i]\n", " for j in range(p):\n", " col=B_T[j]\n", " s=0\n", " for a, b in zip(row, col):\n", " s += a*b\n", " res_row[j]=s\n", " return result\n", "def matmul(A, B):\n", " \"\"\"Multiply two matrices A (m x n) and B (n x p) using pure Python.\n", "\n", " Args:\n", " A: list of m lists, each of length n.\n", " B: list of n lists, each of length p.\n", "\n", " Returns:\n", " C: list of m lists, each of length p, the product.\n", " \"\"\"\n", " m = len(A); n = len(A[0]); p = len(B[0])\n", " # Preallocate result\n", " C = 
[[0]*p for _ in range(m)]\n", " # transpose B for better cache (although Python's memory model)\n", " B_T = [list(col) for col in zip(*B)]\n", " # iterate over rows of A and rows of B_T\n", " for i in range(m):\n", " Ai = A[i]\n", " Ci = C[i]\n", " for k in range(p):\n", " s = 0\n", " Bk = B_T[k]\n", " for j in range(n):\n", " s += Ai[j] * Bk[j]\n", " Ci[k] = s\n", " return C\n", "def matmul(A, B):\n", " # A: m x n, B: n x p\n", " m, n = len(A), len(A[0])\n", " n2, p = len(B), len(B[0])\n", " assert n == n2\n", " # Precompute columns of B\n", " Bcols = list(zip(*B))\n", " return [[sum(a*b for a,b in zip(row,c)) for c in Bcols] for row in A]\n", "None\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices using pure Python.\n", "\n", " Parameters\n", " ----------\n", " A : List[List[Number]]\n", " The first matrix (m × n).\n", " B : List[List[Number]]\n", " The second matrix (n × p).\n", "\n", " Returns\n", " -------\n", " List[List[Number]]\n", " The product matrix (m × p).\n", "\n", " Raises\n", " ------\n", " ValueError\n", " If the inner dimensions of A and B do not match.\n", " \"\"\"\n", " if not A or not B or not B[0]:\n", " return []\n", "\n", " m = len(A)\n", " n = len(A[0])\n", " if len(B) != n:\n", " raise ValueError(\"Inner dimensions must agree: A: {}, B: {}.\".format(n, len(B)))\n", "\n", " p = len(B[0])\n", " # column count of B must be p\n", " # Use list comprehensions for a compact, Python‑native implementation\n", " return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(p)] for i in range(m)]\n", "def matmul(A, B):\n", " ...\n", "None\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B): return ...\n", "def matmul(A, B):\n", " # matrix multiplication using only native Python\n", " ...\n", " return ...\n", "def matmul(A, B):\n", " n_rows = len(A)\n", " n_cols = len(B[0])\n", " k = len(B) # check match\n", " if any(len(row)!=k for row in A):\n", " raise ValueError(\"A dimensions 
don't match B\")\n", " # transpose B for better cache\n", " B_T = list(zip(*B)) # returns tuples but we can keep\n", " result = [[0]*n_cols for _ in range(n_rows)]\n", " for i in range(n_rows):\n", " Ai = A[i]\n", " Ri = result[i]\n", " for j, Bj in enumerate(B_T):\n", " s = 0\n", " for a,b in zip(Ai, Bj):\n", " s += a*b\n", " Ri[j] = s\n", " return result\n", "def matmul(A, B):\n", " n = len(A)\n", " m = len(B[0])\n", " p = len(B)\n", " # assert p == len(A[0])? we can compute\n", " result = [[0]*m for _ in range(n)]\n", " for i in range(n):\n", " ai = A[i]\n", " for k in range(p):\n", " aik = ai[k]\n", " if aik:\n", " bk = B[k]\n", " for j in range(m):\n", " result[i][j] += aik * bk[j]\n", " return result\n", "None\n", "None\n", "def matmul(A, B):\n", " # Ensure both are lists of lists, etc. \n", " return [[sum(a*b for a,b in zip(row,col)) for col in zip(*B)] for row in A]\n", "None\n", "None\n", "def matmul(A, B): return ...\n", "None\n", "def matmul(A, B): return ...\n", "None\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiplies two matrices A and B using pure Python lists.\n", " Expects A to be m x n and B to be n x p.\n", " Returns the resulting m x p matrix.\n", " \"\"\"\n", " return [[sum(a * b for a, b in zip(rowA, colB)) for colB in zip(*B)] for rowA in A]\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B using only native Python constructs.\n", " \n", " Parameters\n", " ----------\n", " A : list[list[Number]]\n", " Left‑hand matrix.\n", " B : list[list[Number]]\n", " Right‑hand matrix.\n", " \n", " Returns\n", " -------\n", " list[list[Number]]\n", " The product matrix A * B.\n", " \"\"\"\n", " # Transpose B to ease column access\n", " Bt = list(zip(*B))\n", " # Compute each entry as the dot product of a row of A and a column of B\n", " return [[sum(a * b for a, b in zip(row, col))\n", " for col in Bt] for row in A]\n", "None\n", "def matmul(A, B):\n", " # assume A: m x n, B: n x p\n", " n = len(A)\n", " m = len(A[0])\n", " p = 
len(B[0])\n", " # precompute columns of B via zip\n", " columns_B = list(zip(*B))\n", " result = [[sum(a*b for a,b in zip(row, col)) for col in columns_B] for row in A]\n", " return result\n", "None\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B using plain Python.\n", " A : list of m rows, each a list of p numbers\n", " B : list of p rows, each a list of n numbers\n", " Returns a list of m rows, each a list of n numbers\n", " \"\"\"\n", " m, p = len(A), len(A[0]) # dimensions of A\n", " assert p == len(B), \"Incompatible matrix dimensions\"\n", " n = len(B[0]) # number of columns of B\n", "\n", " # initialise result matrix with zeros\n", " C = [[0] * n for _ in range(m)]\n", "\n", " # iterate over the shared dimension first\n", " for k in range(p):\n", " B_row = B[k]\n", " for i in range(m):\n", " aik = A[i][k]\n", " if aik: # skip multiplies by zero\n", " C_row = C[i]\n", " # the inner loop that does the real work\n", " for j in range(n):\n", " C_row[j] += aik * B_row[j]\n", " return C\n", "def matmul(A, B):\n", " m = len(A); n = len(A[0]); p = len(B[0])\n", " # validate B has proper shape\n", " result = [[0]*p for _ in range(m)]\n", " for i in range(m):\n", " rowA = A[i]\n", " row_res = result[i]\n", " for k in range(n):\n", " aik = rowA[k]\n", " if aik:\n", " rowBk = B[k]\n", " for j in range(p):\n", " row_res[j] += aik * rowBk[j]\n", " return result\n", "None\n", "None\n", "def matmul(A, B):\n", " if not A or not B:\n", " return []\n", " n, m = len(A), len(A[0])\n", " p, q = len(B), len(B[0])\n", " if m != p:\n", " raise ValueError(\"Inner matrix dimensions must agree.\")\n", " # compute C with zeros\n", " C = [[0]*q for _ in range(n)]\n", " for i in range(n):\n", " Ai = A[i]\n", " Ci = C[i]\n", " for k in range(m):\n", " aik = Ai[k]\n", " Bk = B[k]\n", " for j in range(q):\n", " Ci[j] += aik * Bk[j]\n", " return C\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A (n × p) and B (p × m) using a cache‑friendly\n", " 
block‑style algorithm that reduces Python overhead compared to a\n", " straightforward triple loop.\n", " \"\"\"\n", " n = len(A) # rows of A\n", " p = len(B) # shared dimension\n", " m = len(B[0]) # columns of B\n", "\n", " # Result matrix initialised with zeros\n", " C = [[0] * m for _ in range(n)]\n", "\n", " # Optimised loop ordering: i → k → j\n", " for i in range(n):\n", " Ai = A[i]\n", " Ci = C[i]\n", " for k in range(p):\n", " aik = Ai[k]\n", " if aik: # skip zero multiplications\n", " Bk = B[k]\n", " for j in range(m):\n", " Ci[j] += aik * Bk[j]\n", " return C\n", "def matmul(A, B):\n", " # ensure convertible shapes\n", " m, n = len(A), len(A[0])\n", " p = len(B[0])\n", " # Precompute transpose of B to improve cache locality\n", " BT = list(zip(*B))\n", " result = [[0]*p for _ in range(m)]\n", " for i in range(m):\n", " Ai = A[i]\n", " Ri = result[i]\n", " for j, Bj in enumerate(BT):\n", " s = 0\n", " for k in range(n):\n", " s += Ai[k] * Bj[k]\n", " Ri[j] = s\n", " return result\n", "None\n", "def matmul(A, B): return ...\n", "None\n", "None\n", "None\n", "None\n", "None\n", "None\n", "def matmul(A, B):\n", " # ensure shape\n", " if not A: return []\n", " m, k = len(A), len(A[0])\n", " k2, n = len(B), len(B[0])\n", " assert k == k2\n", " # transpose B\n", " B_T = list(zip(*B))\n", " # compute\n", " return [[sum(a*b for a,b in zip(row, col)) for col in B_T] for row in A]\n", "None\n", "def matmul(A, B):\n", " # A is an m x n matrix, B is an n x p matrix\n", " m = len(A)\n", " n = len(A[0]) if A else 0\n", " p = len(B[0]) if B else 0\n", "\n", " # In case the dimensions are incompatible, raise an error\n", " if n != len(B):\n", " raise ValueError(\"Incompatible matrix dimensions: A has %d cols but B has %d rows.\" % (n, len(B)))\n", "\n", " # Prepare the result matrix filled with zeros\n", " C = [[0] * p for _ in range(m)]\n", "\n", " # Perform the multiplication using a cache-friendly triple loop\n", " for i in range(m):\n", " Ai = A[i]\n", " Ci = C[i]\n", 
" for k in range(n):\n", " aik = Ai[k]\n", " Bk = B[k]\n", " for j in range(p):\n", " Ci[j] += aik * Bk[j]\n", " return C\n", "None\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B using plain Python lists.\n", "\n", " Parameters\n", " ----------\n", " A : list[list[float]]\n", " Left matrix of shape (m, n).\n", " B : list[list[float]]\n", " Right matrix of shape (n, p).\n", "\n", " Returns\n", " -------\n", " list[list[float]]\n", " The product matrix of shape (m, p).\n", "\n", " Raises\n", " ------\n", " ValueError\n", " If the inner dimensions do not agree.\n", " \"\"\"\n", " # Validate input\n", " if not A or not B:\n", " return []\n", "\n", " m, n_A = len(A), len(A[0])\n", " n_B, p = len(B), len(B[0])\n", "\n", " if n_A != n_B:\n", " raise ValueError(\"Inner matrix dimensions must agree: \"\n", " f\"{n_A} != {n_B}\")\n", "\n", " # Pre‑allocate result matrix with zeros\n", " C = [[0.0 for _ in range(p)] for _ in range(m)]\n", "\n", " # Classic triple‑loop multiplication\n", " for i in range(m):\n", " Ai = A[i]\n", " Ci = C[i]\n", " for k in range(n_A):\n", " aik = Ai[k]\n", " Bk = B[k]\n", " # Unroll the inner loop over p\n", " for j in range(p):\n", " Ci[j] += aik * Bk[j]\n", "\n", " return C\n", "None\n", "None\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiplies two matrices A and B represented as lists of lists using\n", " plain Python code. 
Handles generic rectangular matrices.\n", " \"\"\"\n", " # Basic size validation\n", " if not A or not B or not A[0] or not B[0]:\n", " raise ValueError(\"Matrices cannot be empty\")\n", " rows_a = len(A)\n", " cols_a = len(A[0])\n", " rows_b = len(B)\n", " cols_b = len(B[0])\n", "\n", " if cols_a != rows_b:\n", " raise ValueError(\"Inner dimensions must match for multiplication\")\n", "\n", " # Compute the product using a straightforward triple loop\n", " result = [\n", " [sum(A[i][k] * B[k][j] for k in range(cols_a)) for j in range(cols_b)]\n", " for i in range(rows_a)\n", " ]\n", "\n", " return result\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices using a naïve algorithm implemented in pure Python.\n", " \"\"\"\n", " # Get dimensions\n", " na = len(A)\n", " ma = len(A[0]) if A else 0\n", " nb = len(B)\n", " mb = len(B[0]) if B else 0\n", " \n", " if ma != nb:\n", " raise ValueError(\"Inner dimensions must match for multiplication.\")\n", " \n", " # Create result matrix\n", " result = [[0 for _ in range(mb)] for _ in range(na)]\n", " \n", " # Standard triple‑loop multiplication\n", " for i in range(na):\n", " Ai = A[i]\n", " for j in range(mb):\n", " sum = 0\n", " for k in range(ma):\n", " sum += Ai[k] * B[k][j]\n", " result[i][j] = sum\n", " return result\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B and return the product.\n", "\n", " Parameters\n", " ----------\n", " A : list[list[Number]]\n", " Left matrix of shape (m, n).\n", " B : list[list[Number]]\n", " Right matrix of shape (n, p).\n", "\n", " Returns\n", " -------\n", " C : list[list[Number]]\n", " Matrix product of shape (m, p).\n", "\n", " Raises\n", " ------\n", " ValueError\n", " If the input matrices cannot be multiplied due to incompatible shapes.\n", " \"\"\"\n", " # Validate inputs\n", " if not A or not B:\n", " return []\n", "\n", " # Dimensions\n", " m, n = len(A), len(A[0]) if A else 0\n", " n_b, p = len(B), 
len(B[0]) if B else 0\n", "\n", " if n != n_b:\n", " raise ValueError(f\"Incompatible dimensions: A is {m}x{n}, B is {n_b}x{p}\")\n", "\n", " # Pre-allocate result matrix with zeros\n", " C = [[0] * p for _ in range(m)]\n", "\n", " # Main multiplication loop\n", " for i in range(m):\n", " a_row = A[i]\n", " c_row = C[i]\n", " for k in range(n):\n", " aik = a_row[k]\n", " if aik == 0:\n", " continue # Skip multiplication by zero\n", " b_row_k = B[k]\n", " # Unroll the inner loop for potential speed\n", " for j in range(p):\n", " c_row[j] += aik * b_row_k[j]\n", " return C\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " return ...\n", "def matmul(A, B):\n", " \"\"\"Multiply two matrices using only native Python lists.\"\"\"\n", " # Basic dimension checks\n", " if not A or not B or not A[0] or not B[0]:\n", " raise ValueError(\"Input matrices must be non-empty\")\n", " if len(A[0]) != len(B):\n", " raise ValueError(\"Number of columns in A must equal number of rows in B\")\n", "\n", " # Transpose B to get its columns efficiently\n", " BT = list(zip(*B)) # each is a tuple: column of B\n", "\n", " # Compute the product\n", " result = [[sum(a * b for a, b in zip(row, col)) for col in BT] for row in A]\n", " return result\n", "def matmul(A, B):\n", " m = len(A)\n", " p = len(A[0])\n", " if p==0:\n", " return [[] for _ in range(m)]\n", " n = len(B[0])\n", " # transpose B for cache\n", " Bt = [[B[i][j] for i in range(p)] for j in range(n)]\n", " res = [[0]*n for _ in range(m)]\n", " for i in range(m):\n", " Ai = A[i]\n", " Ri = res[i]\n", " for k in range(p):\n", " aik = Ai[k]\n", " Bk = Bt # not right; we need to use Bt[j][k].\n", "def matmul(A, B): return ...\n", "def matmul(A, B):\n", " ...\n", "None\n", "def matmul(A, B):\n", " # A: m x n, B: n x p\n", " m, n = len(A), len(A[0])\n", " assert n == len(B), \"Incompatible matrices\"\n", " p = len(B[0])\n", " # pre-transpose B to speed up\n", " B_T = list(map(list, zip(*B))) # p x n\n", " return 
[[sum(a*b for a, b in zip(row, col)) for col in B_T] for row in A]\n", "None\n", "def matmul(A, B): return ...\n", "def matmul(A, B):\n", " \"\"\"Return `A * B` for two 2‑D lists `A` and `B` using pure Python.\"\"\"\n", " # Ensure dimensions are compatible\n", " if not A or not B or not B[0]:\n", " return []\n", "\n", " # Transpose B once to obtain column access in O(1)\n", " B_t = list(zip(*B))\n", "\n", " # Compute the matrix product using list comprehensions\n", " return [\n", " [sum(a * b for a, b in zip(row, col))\n", " for col in B_t] # compute dot product of row with each column of B\n", " for row in A\n", " ]\n", "def matmul(A, B):\n", " \"\"\"\n", " Compute the matrix product C = A @ B with pure Python.\n", "\n", " Parameters\n", " ----------\n", " A : List[List[float]]\n", " Left matrix of size m × p.\n", " B : List[List[float]]\n", " Right matrix of size p × n.\n", "\n", " Returns\n", " -------\n", " List[List[float]]\n", " Resulting matrix of size m × n.\n", " \"\"\"\n", " m, pA = len(A), len(A[0])\n", " pB, n = len(B), len(B[0])\n", "\n", " if pA != pB:\n", " raise ValueError(\"Inner matrix dimensions must agree (got %d×%d and %d×%d) \"\n", " % (m, pA, pB, n))\n", "\n", " # Allocate result matrix\n", " C = [[0.0] * n for _ in range(m)]\n", "\n", " # Blocked matrix multiplication for better cache locality.\n", " # Choose a block size that works well on most systems\n", " block = 64\n", "\n", " for i0 in range(0, m, block):\n", " i_max = min(i0 + block, m)\n", " for k0 in range(0, pA, block):\n", " k_max = min(k0 + block, pA)\n", " for j0 in range(0, n, block):\n", " j_max = min(j0 + block, n)\n", "\n", " for i in range(i0, i_max):\n", " ai = A[i]\n", " ci = C[i]\n", " for k in range(k0, k_max):\n", " aik = ai[k]\n", " bk = B[k]\n", " for j in range(j0, j_max):\n", " ci[j] += aik * bk[j]\n", " return C\n", "None\n", "def matmul(A, B):\n", " if not A or not B:\n", " return []\n", " m = len(A)\n", " n = len(A[0])\n", " p = len(B[0])\n", " # check 
compatibility: len(B)==n\n", " # compute result matrix with nested loops\n", " res = [[0]*p for _ in range(m)]\n", " for i in range(m):\n", " Ai = A[i]\n", " for j in range(p):\n", " s = 0\n", " for k in range(n):\n", " s += Ai[k]*B[k][j]\n", " res[i][j] = s\n", " return res\n", "def matmul(A, B):\n", " # check dimensions\n", " ...\n", " # compute product\n", "def matmul(A, B):\n", " # Extract dimensions\n", " if not A or not B or not A[0] or not B[0]:\n", " return [] # or raise ValueError...\n", " # Determine dimensions\n", " p = len(A[0]) # columns of A, rows of B\n", " if p != len(B):\n", " raise ValueError(\"Incompatible matrix sizes.\")\n", " # Transpose B for efficient column access\n", " BT = list(zip(*B))\n", " # Perform multiplication\n", " result = []\n", " for row in A:\n", " new_row = []\n", " for col in BT:\n", " # Dot product\n", " sum_val = 0\n", " for a, b in zip(row, col):\n", " sum_val += a * b\n", " new_row.append(sum_val)\n", " result.append(new_row)\n", " return result\n", "def matmul(A, B):\n", " # optimize: transpose B\n", " BT = list(zip(*B)) # tuple of tuples (row of B^T)\n", " return [[sum(a*b for a, b in zip(rowA, colB)) for colB in BT] for rowA in A]\n", "def matmul(A, B):\n", " n = len(A)\n", " m = len(B[0])\n", " p = len(B)\n", " return [[sum(A[i][k]*B[k][j] for k in range(p)) for j in range(m)] for i in range(n)]\n", "def matmul(A, B):\n", " \"\"\"\n", " Multiply two matrices A and B using pure Python.\n", "\n", " Parameters\n", " ----------\n", " A : list[list[float | int | complex]]\n", " The first matrix of dimensions (m, p).\n", " B : list[list[float | int | complex]]\n", " The second matrix of dimensions (p, n).\n", "\n", " Returns\n", " -------\n", " list[list[float | int | complex]]\n", " The product matrix of dimensions (m, n).\n", "\n", " Raises\n", " ------\n", " ValueError\n", " If the inner dimensions do not match.\n", " \"\"\"\n", " # Basic sanity checks for empty/corrupt inputs\n", " if not A or not B:\n", " raise 
ValueError(\"Input matrices cannot be empty\")\n", "\n", " p = len(A[0])\n", " for row in A:\n", " if len(row) != p:\n", " raise ValueError(\"All rows in A must have the same length\")\n", " q = len(B)\n", " for row in B:\n", " if len(row) != len(B[0]):\n", " raise ValueError(\"All rows in B must have the same length\")\n", "\n", " # Ensure the inner dimensions match\n", " if p != q:\n", " raise ValueError(\"Inner matrix dimensions must agree: len(A[0]) != len(B)\")\n", "\n", " # Transpose B to avoid repeated look‑ups in the inner loop\n", " B_T = list(zip(*B)) # Each element is a tuple representing a column\n", "\n", " # Perform matrix multiplication\n", " return [[sum(a * b for a, b in zip(row, col)) for col in B_T] for row in A]\n", "def matmul(A, B):\n", " \"\"\"Return the matrix product of A and B using native Python.\"\"\"\n", " # Transpose B once to avoid repeated column lookups.\n", " B_T = list(zip(*B))\n", " # Compute each element of the product using a generator expression.\n", " return [[sum(a * b for a, b in zip(row, col)) for col in B_T] for row in A]\n", "None\n", "None\n", "None\n", "None\n", "def matmul(A, B):\n", " if not A or not B: \n", " return []\n", " nrowA = len(A)\n", " ncolA = len(A[0])\n", " nrowB = len(B)\n", " ncolB = len(B[0])\n", " if ncolA != nrowB:\n", " raise ValueError(\"Incompatible dimensions for matrix multiplication\")\n", " # Precompute columns of B\n", " colsB = list(zip(*B))\n", " return [\n", " [sum(a*b for a, b in zip(row, col)) for col in colsB]\n", " for row in A\n", " ]\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=100, training_loss=3.209321222243488e-05, metrics={'train_runtime': 5865.3354, 'train_samples_per_second': 0.034, 'train_steps_per_second': 0.017, 'total_flos': 0.0, 'train_loss': 3.209321222243488e-05})" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And now with the 
LoRA we just trained with GRPO in hand, let's save it first!" ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "model.save_pretrained(\"gemma_4_lora\") # Local saving\n", "tokenizer.save_pretrained(\"gemma_4_lora\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Verify that the LoRA is actually trained!" ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "from safetensors import safe_open\n", "\n", "# Read back the adapter saved above to \"gemma_4_lora\"\n", "with safe_open(\"gemma_4_lora/adapter_model.safetensors\", framework = \"pt\") as f:\n", "    # Verify that no LoRA A or B tensor is entirely zero\n", "    for key in f.keys():\n", "        tensor = f.get_tensor(key)\n", "        n_zeros = (tensor == 0).sum()\n", "        assert n_zeros.item() != tensor.numel(), f\"{key} is all zeros\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "# Inference\n", "Now let's try the model we just trained!" ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "text = tokenizer.apply_chat_template(\n", "    [{\"role\": \"user\", \"content\": prompt.strip()}],\n", "    tokenize = False,\n", "    add_generation_prompt = True,\n", ")\n", "\n", "from transformers import TextStreamer\n", "\n", "_ = model.generate(\n", "    **tokenizer(images = None, text = text, return_tensors = \"pt\").to(\"cuda\"),\n", "    temperature = 1.0, top_p = 0.95, top_k = 64,\n", "    max_new_tokens = 1024,\n", "    streamer = TextStreamer(tokenizer, skip_prompt = False),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### Saving to float16 for vLLM\n", "\n", "We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.
See [our docs](https://unsloth.ai/docs/basics/inference-and-deployment) for more deployment options." ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Merge to 16bit\n", "if False: model.save_pretrained_merged(\"gemma_4_finetune_16bit\", tokenizer, save_method = \"merged_16bit\",)\n", "if False: model.push_to_hub_merged(\"HF_USERNAME/gemma_4_finetune_16bit\", tokenizer, save_method = \"merged_16bit\", token = \"YOUR_HF_TOKEN\")\n", "\n", "# Merge to 4bit\n", "if False: model.save_pretrained_merged(\"gemma_4_finetune_4bit\", tokenizer, save_method = \"merged_4bit\",)\n", "if False: model.push_to_hub_merged(\"HF_USERNAME/gemma_4_finetune_4bit\", tokenizer, save_method = \"merged_4bit\", token = \"YOUR_HF_TOKEN\")\n", "\n", "# Just LoRA adapters\n", "if False:\n", "    model.save_pretrained(\"gemma_4_lora\")\n", "    tokenizer.save_pretrained(\"gemma_4_lora\")\n", "if False:\n", "    model.push_to_hub(\"HF_USERNAME/gemma_4_lora\", token = \"YOUR_HF_TOKEN\")\n", "    tokenizer.push_to_hub(\"HF_USERNAME/gemma_4_lora\", token = \"YOUR_HF_TOKEN\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GGUF / llama.cpp Conversion\n", "We support saving to `GGUF` / `llama.cpp` natively! We clone `llama.cpp` and save to `q8_0` by default, but all methods such as `q4_k_m` are supported. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.\n", "\n", "Some supported quant methods (full list on our [docs page](https://unsloth.ai/docs/basics/inference-and-deployment/saving-to-gguf)):\n", "* `q8_0` - Fast conversion. High resource use, but generally acceptable.\n", "* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.\n", "* `q5_k_m` - Recommended.
Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.\n", "\n", "[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)" ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "# Save to 8bit Q8_0\n", "if False: model.save_pretrained_gguf(\"gemma_4_finetune\", tokenizer,)\n", "# Remember to go to https://huggingface.co/settings/tokens for a token!\n", "# And replace HF_USERNAME with your username!\n", "if False: model.push_to_hub_gguf(\"HF_USERNAME/gemma_4_finetune\", tokenizer, token = \"YOUR_HF_TOKEN\")\n", "\n", "# Save to 16bit GGUF\n", "if False: model.save_pretrained_gguf(\"gemma_4_finetune\", tokenizer, quantization_method = \"f16\")\n", "if False: model.push_to_hub_gguf(\"HF_USERNAME/gemma_4_finetune\", tokenizer, quantization_method = \"f16\", token = \"YOUR_HF_TOKEN\")\n", "\n", "# Save to q4_k_m GGUF\n", "if False: model.save_pretrained_gguf(\"gemma_4_finetune\", tokenizer, quantization_method = \"q4_k_m\")\n", "if False: model.push_to_hub_gguf(\"HF_USERNAME/gemma_4_finetune\", tokenizer, quantization_method = \"q4_k_m\", token = \"YOUR_HF_TOKEN\")\n", "\n", "# Save to multiple GGUF options - much faster if you want multiple!\n", "if False:\n", "    model.push_to_hub_gguf(\n", "        \"HF_USERNAME/gemma_4_finetune\", # Replace HF_USERNAME with your username!\n", "        tokenizer,\n", "        quantization_method = [\"q4_k_m\", \"q8_0\", \"q5_k_m\",],\n", "        token = \"YOUR_HF_TOKEN\",\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, use the `gemma_4_finetune.Q8_0.gguf` file or `gemma_4_finetune.Q4_K_M.gguf` file in llama.cpp.\n", "\n", "And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel!
If you find any bugs, want to keep up with the latest LLM developments, need help, or want to join projects, feel free to join our Discord!\n", "\n", "Some other resources:\n", "1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)\n", "2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)\n", "3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)\n", "4. See notebooks for DPO, ORPO, continued pretraining, conversational finetuning and more on our [documentation](https://unsloth.ai/docs/get-started/unsloth-notebooks)!\n", "\n", "

\n", " \n", " \n", " \n", "\n", " Join Discord if you need help + ⭐️ Star us on Github ⭐️\n", "
\n", "\n", " This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {} } } }, "nbformat": 4, "nbformat_minor": 0 }