This page walks through writing, compiling, and running a vector-add kernel, the minimal end-to-end path in Enigma.

Step 1: define the kernel

Decorate a Python function with @enigma.kernel. The function body is traced, not executed: each expression becomes an IR node, not a Python value.
import enigma

@enigma.kernel
def vector_add(A: enigma.f32, B: enigma.f32, C: enigma.f32):
    tid = enigma.thread_position_in_grid
    C[tid] = A[tid] + B[tid]
What this does:
  • A, B, C are typed buffer parameters (f32 = float in Metal)
  • thread_position_in_grid evaluates to the calling thread's global ID as an IRValue (a traced value, not a Python int)
  • Indexing A[tid] emits a load; assigning C[tid] = ... emits a store
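To make "traced, not executed" concrete, here is a toy tracer in plain Python. It is not Enigma's implementation. Graph, ToyValue, Buffer, trace and vector_add_body are hypothetical stand-ins that show how operator overloading can turn the same kernel body into a list of load, add and store nodes instead of computing numbers.

```python
from dataclasses import dataclass, field


@dataclass
class Graph:
    """Hypothetical IR container: a flat list of (op, operands) tuples."""
    nodes: list = field(default_factory=list)

    def emit(self, op, *operands):
        self.nodes.append((op, operands))
        return ToyValue(self, len(self.nodes) - 1)


@dataclass
class ToyValue:
    """Hypothetical traced value: just a reference to the node that defines it."""
    graph: Graph
    index: int

    def __add__(self, other):
        # No arithmetic happens here; an "add" node is recorded instead.
        return self.graph.emit("add", self.index, other.index)


@dataclass
class Buffer:
    """Hypothetical buffer parameter: indexing records loads and stores."""
    graph: Graph
    name: str

    def __getitem__(self, idx):
        return self.graph.emit("load", self.name, idx.index)

    def __setitem__(self, idx, value):
        self.graph.emit("store", self.name, idx.index, value.index)


def trace(fn, *buffer_names):
    g = Graph()
    tid = g.emit("thread_position_in_grid")
    fn(tid, *(Buffer(g, name) for name in buffer_names))
    return g


def vector_add_body(tid, A, B, C):
    C[tid] = A[tid] + B[tid]


g = trace(vector_add_body, "A", "B", "C")
for i, node in enumerate(g.nodes):
    print(i, node)
```

Python evaluates the right-hand side first, so the recorded order is: thread ID, load A, load B, add, then the store into C.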

Step 2: compile to a Metal library

compiled = enigma.compile(vector_add)
enigma.compile runs the full pipeline:
  1. Traces the Python function to an IR
  2. Lowers the IR to the Enigma MLIR dialect
  3. Emits Metal Shading Language (MSL) source
  4. Invokes xcrun metal and xcrun metallib
  5. Returns a CompiledKernel with all artifacts
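Step 4 uses Apple's offline Metal toolchain: `metal` compiles .metal source into an intermediate .air file, and `metallib` packages .air files into a .metallib library. The sketch below builds those two command lines using the flags from Apple's command-line build documentation. The exact flags Enigma passes are not documented on this page, and metal_build_commands and the file paths are hypothetical.

```python
import shlex
from pathlib import Path


def metal_build_commands(metal_file: Path, sdk: str = "macosx"):
    """Return the two xcrun invocations that turn .metal source into a .metallib.

    Hypothetical helper; it only builds argument lists and runs nothing.
    """
    air_file = metal_file.with_suffix(".air")       # intermediate representation
    lib_file = metal_file.with_suffix(".metallib")  # loadable Metal library
    compile_cmd = ["xcrun", "-sdk", sdk, "metal", "-c", str(metal_file), "-o", str(air_file)]
    link_cmd = ["xcrun", "-sdk", sdk, "metallib", str(air_file), "-o", str(lib_file)]
    return compile_cmd, link_cmd


compile_cmd, link_cmd = metal_build_commands(Path("build/enigma/vector_add.metal"))
print(shlex.join(compile_cmd))
print(shlex.join(link_cmd))
# On a Mac with Xcode installed, each could be run with subprocess.run(cmd, check=True).
```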

Inspect the output

print(compiled.kernel_name)        # enigma_kernel_vector_add
print(compiled.metal_source[:500]) # generated .metal source

Compile with verbose output

compiled = enigma.compile(
    vector_add,
    dump_ir=True,
    dump_mlir=True,
    keep_metal_source=True,
    work_dir="./build/enigma",
)

Step 3: dispatch

import numpy as np

rt = enigma.MetalRuntime()
n = 4096

a = np.random.randn(n).astype(np.float32)
b = np.random.randn(n).astype(np.float32)

raw = rt.execute(
    compiled,
    inputs=[a, b],
    output_size=n * 4,       # bytes: 4096 elements × 4 bytes
    grid=(n, 1, 1),          # one thread per element
    threads=(256, 1, 1),     # threadgroup size
)
out = np.frombuffer(raw, dtype=np.float32)
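Every size argument above comes from n and the element type. The hypothetical helper below (not an Enigma API) makes that arithmetic explicit, including how many threadgroups of 256 are needed to cover the data. When n is not a multiple of the threadgroup width, the last threadgroup is only partly needed. This page does not say whether MetalRuntime launches the surplus threads, and vector_add has no bounds check, so keeping n a multiple of the threadgroup width is the conservative choice.

```python
import numpy as np


def dispatch_params(n: int, dtype=np.float32, threadgroup_width: int = 256):
    """Derive the size arguments for a 1-D, one-thread-per-element dispatch."""
    itemsize = np.dtype(dtype).itemsize                          # 4 for float32
    output_size = n * itemsize                                    # bytes, not elements
    grid = (n, 1, 1)                                              # total threads = elements
    threads = (threadgroup_width, 1, 1)                           # threads per threadgroup
    n_groups = (n + threadgroup_width - 1) // threadgroup_width   # ceil(n / width)
    return output_size, grid, threads, n_groups


print(dispatch_params(4096))  # (16384, (4096, 1, 1), (256, 1, 1), 16)
print(dispatch_params(1000))  # (4000, (1000, 1, 1), (256, 1, 1), 4)
```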

Key parameters

| Parameter | Meaning |
|---|---|
| inputs | List of NumPy arrays passed as read-only device buffers |
| output_size | Size of the output buffer, in bytes |
| grid | (gx, gy, gz): total threads in each dimension |
| threads | (tx, ty, tz): threads per threadgroup |
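To see how grid and threads combine, here is a CPU emulation of a 1-D dispatch in plain Python. It is a hypothetical teaching aid, not part of Enigma. It assumes Metal's standard 1-D relationship thread_position_in_grid = threadgroup index × threads per threadgroup + thread index within the threadgroup, and it skips threads past the end of the grid, as happens when the final threadgroup is only partly used.

```python
import numpy as np


def emulate_dispatch(kernel, grid, threads, *buffers):
    """Run `kernel` once per thread on the CPU, Metal-style (x dimension only)."""
    gx, tx = grid[0], threads[0]
    n_groups = (gx + tx - 1) // tx            # threadgroups needed to cover gx threads
    for group in range(n_groups):
        for local in range(tx):
            tid = group * tx + local          # global thread ID in the grid
            if tid < gx:                      # partial final threadgroup
                kernel(tid, *buffers)


def vector_add_cpu(tid, A, B, C):
    C[tid] = A[tid] + B[tid]


n = 1000
a = np.arange(n, dtype=np.float32)
b = np.ones(n, dtype=np.float32)
c = np.zeros(n, dtype=np.float32)
emulate_dispatch(vector_add_cpu, (n, 1, 1), (256, 1, 1), a, b, c)
assert np.array_equal(c, a + b)

# A grid smaller than the data domain leaves the tail unwritten.
c2 = np.zeros(n, dtype=np.float32)
emulate_dispatch(vector_add_cpu, (512, 1, 1), (256, 1, 1), a, b, c2)
print(c2[510:514])  # first two entries written, last two still zero
```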

Step 4: validate

np.testing.assert_allclose(out, a + b, rtol=1e-5, atol=1e-7)
print("pass")
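When the check fails, it helps to know how many elements fail and by how much. assert_allclose passes when |actual - desired| <= atol + rtol × |desired| holds elementwise. The hypothetical helper below (not part of NumPy or Enigma) applies that same criterion and reports the failing count and the worst absolute error instead of raising.

```python
import numpy as np


def allclose_report(actual, desired, rtol=1e-5, atol=1e-7):
    """Apply assert_allclose's criterion elementwise; return (num_failing, max_abs_err)."""
    err = np.abs(actual - desired)
    limit = atol + rtol * np.abs(desired)
    failing = int(np.count_nonzero(err > limit))
    return failing, float(err.max())


rng = np.random.default_rng(0)
a = rng.standard_normal(4096).astype(np.float32)
b = rng.standard_normal(4096).astype(np.float32)
ref = a + b

print(allclose_report(ref, ref))                     # (0, 0.0)
print(allclose_report(ref + np.float32(1e-3), ref))  # every element exceeds the tolerance
```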

Common first-run mistakes

| Mistake | Symptom |
|---|---|
| output_size too small | Partial result or garbage at the tail |
| grid smaller than the data domain | Some elements not written |
| Wrong dtype in np.frombuffer | Nonsensical values |
| Running on a non-Metal machine | MetalRuntime dispatch exception |
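The first two mistakes can be caught before dispatch, and the dtype mistake can be reproduced in pure NumPy. check_dispatch is a hypothetical helper, not an Enigma API, and it assumes a 1-D kernel like vector_add that indexes only the x coordinate of the grid.

```python
import numpy as np


def check_dispatch(n, dtype, output_size, grid):
    """Return a list of problems matching the common first-run mistakes (1-D kernels)."""
    problems = []
    needed = n * np.dtype(dtype).itemsize
    if output_size < needed:
        problems.append(f"output_size={output_size} bytes, need {needed}")
    if grid[0] < n:  # a 1-D kernel only sees the x coordinate
        problems.append(f"grid launches {grid[0]} threads in x for {n} elements")
    return problems


# output_size passed as an element count instead of a byte count
print(check_dispatch(4096, np.float32, 4096, (4096, 1, 1)))
# grid covers only half the data
print(check_dispatch(4096, np.float32, 4096 * 4, (2048, 1, 1)))

# Wrong dtype: reading float32 bytes as float64 halves the length and scrambles values.
raw = np.arange(4, dtype=np.float32).tobytes()
print(np.frombuffer(raw, dtype=np.float64).shape)  # (2,)
print(np.frombuffer(raw, dtype=np.float32))        # [0. 1. 2. 3.]
```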

Next steps