Enigma provides two decorators for GPU code: @enigma.kernel for direct compute kernels and @enigma.jit for host-side functions that perform layout algebra and launch kernels.

@enigma.kernel

@enigma.kernel defines a GPU compute kernel. The decorated function is traced at compile time: the Python body runs once with symbolic IRValue placeholders, producing an IR graph that is then lowered to Metal.

Rules for kernel bodies

  • No native Python control flow — use enigma.for_range, enigma.if_, enigma.while_ instead.
  • No Python data types as values — all computation goes through IRValue objects.
  • No early returns — the tracer records the full body unconditionally.
  • Types are annotations, not runtime checks. For example, A: enigma.f32 declares the buffer’s element type.
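
These rules follow from how tracing works, which a toy tracer can illustrate. The sketch below is not Enigma's implementation: `Sym`, `trace`, and the tuple-based graph are hypothetical names used only for illustration. The body runs once with placeholders, each operator records a node instead of computing a number, and a native Python `if` on a placeholder fails because there is no concrete truth value to branch on.

```python
class Sym:
    """Hypothetical symbolic placeholder (a stand-in for an IRValue-like object)."""

    def __init__(self, graph, name):
        self.graph = graph
        self.name = name

    def _emit(self, op, other):
        # Record the operation instead of computing a number.
        rhs = other.name if isinstance(other, Sym) else repr(other)
        out = Sym(self.graph, f"v{len(self.graph)}")
        self.graph.append((out.name, op, self.name, rhs))
        return out

    def __add__(self, other):
        return self._emit("add", other)

    def __mul__(self, other):
        return self._emit("mul", other)

    def __bool__(self):
        # A native `if` needs a concrete bool, which does not exist at trace time.
        raise TypeError("symbolic value has no truth value during tracing")


def trace(fn, *arg_names):
    """Run fn once with placeholders and return the recorded op list."""
    graph = []
    fn(*(Sym(graph, name) for name in arg_names))
    return graph


def body(alpha, x, y):
    return alpha * x + y


for node in trace(body, "alpha", "x", "y"):
    print(node)
```

The recorded list contains one `mul` node followed by one `add` node. A real tracer would offer dedicated constructs for branches and loops so that they also become graph nodes, which is the role the text assigns to enigma.if_, enigma.for_range, and enigma.while_.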

Kernel parameters

Each parameter is one of:
  • enigma.f32, enigma.f16, enigma.bf16, enigma.i32, enigma.u32, etc. — typed buffer (pointer in Metal)
  • enigma.Scalar(dtype) — scalar constant, lowered as a 1-element buffer
@enigma.kernel
def saxpy(X: enigma.f32, alpha: enigma.Scalar(enigma.f32), Y: enigma.f32, Out: enigma.f32):
    tid = enigma.thread_position_in_grid
    Out[tid] = alpha * X[tid] + Y[tid]

Compiling a kernel

compiled = enigma.compile(saxpy)
compiled is a CompiledKernel. Dispatch it with MetalRuntime.execute(...):
rt = enigma.MetalRuntime()
raw = rt.execute(
    compiled,
    inputs=[x, alpha_val, y],   # numpy arrays in parameter order
    output_size=n * 4,
    grid=(n, 1, 1),
    threads=(256, 1, 1),
)
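
A host-side reference makes it easy to check what the kernel returns. The sketch below uses only NumPy and rests on assumptions the text does not confirm: that `output_size` is a byte count (the `n * 4` for float32 suggests so), that the scalar is passed as a 1-element array, and that `raw` can be reinterpreted as float32 data. `output_nbytes` and `saxpy_reference` are hypothetical helper names.

```python
import numpy as np


def output_nbytes(count, dtype=np.float32):
    """Bytes needed for `count` elements; equals n * 4 for float32."""
    return count * np.dtype(dtype).itemsize


def saxpy_reference(x, alpha, y):
    """NumPy version of Out[tid] = alpha * X[tid] + Y[tid]."""
    return (np.float32(alpha) * x + y).astype(np.float32)


n = 1024
rng = np.random.default_rng(0)
x = rng.standard_normal(n, dtype=np.float32)
y = rng.standard_normal(n, dtype=np.float32)
alpha_val = np.array([2.0], dtype=np.float32)  # assumed: 1-element array for the scalar

expected = saxpy_reference(x, alpha_val[0], y)
print(output_nbytes(n))

# Assumed comparison once `raw` is available (its exact type is not documented here):
# np.testing.assert_allclose(np.frombuffer(raw, dtype=np.float32), expected, rtol=1e-6)
```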

Repeated dispatch

For hot loops, pre-allocate buffers once:
prepared = rt.prepare(compiled, inputs=[x, alpha_val, y], output_size=n * 4)
# later:
prepared.dispatch(grid=(n, 1, 1), threads=(256, 1, 1))
result = prepared.read_output(dtype=np.float32)
For benchmarking:
elapsed_ms = prepared.dispatch_timed(grid=(n, 1, 1), threads=(256, 1, 1))
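
Single timings tend to be noisy, so a benchmark usually discards a few warm-up runs and reports a robust statistic. The sketch below wraps any callable that returns elapsed milliseconds, as `dispatch_timed` does according to the line above. `bench` and `saxpy_bandwidth_gbs` are hypothetical helpers, and a stub stands in for the GPU so the example runs anywhere.

```python
import statistics


def bench(timed_call, warmup=3, iters=20):
    """Median of `iters` timings in ms, after `warmup` discarded runs."""
    for _ in range(warmup):
        timed_call()
    return statistics.median(timed_call() for _ in range(iters))


def saxpy_bandwidth_gbs(n, elapsed_ms, itemsize=4):
    """Effective GB/s: saxpy reads X and Y and writes Out, so 3 * n elements move."""
    bytes_moved = 3 * n * itemsize
    return bytes_moved / (elapsed_ms * 1e-3) / 1e9


# Stub standing in for a real timed dispatch. With Enigma this would be something like
# lambda: prepared.dispatch_timed(grid=(n, 1, 1), threads=(256, 1, 1))
fake_timings = iter([9.0, 5.0, 4.0] + [2.0, 1.0, 3.0] * 7)
median_ms = bench(lambda: next(fake_timings), warmup=3, iters=20)
print(median_ms, saxpy_bandwidth_gbs(1_000_000, median_ms))
```

The median is used rather than the mean so that an occasional slow run does not skew the result.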

@enigma.jit

@enigma.jit marks a host-side function that runs at compile time. Use it when you need to:
  • Perform layout algebra (tiling, thread-value mapping) before generating kernels
  • Launch multiple kernels in sequence
  • Accept Tensor objects and compute block/thread partitioning

JIT function pattern

@enigma.jit
def tiled_add(mA: enigma.Tensor, mB: enigma.Tensor, mC: enigma.Tensor):
    # Layout arithmetic — runs at compile time
    thr = enigma.make_ordered_layout((4, 64), order=(1, 0))
    val = enigma.make_ordered_layout((4, 4), order=(1, 0))
    tiler_mn, tv_layout = enigma.make_layout_tv(thr, val)

    gA = enigma.tensor_zipped_divide(mA, tiler_mn)
    gB = enigma.tensor_zipped_divide(mB, tiler_mn)
    gC = enigma.tensor_zipped_divide(mC, tiler_mn)

    # Inner kernel — dispatched for each tile
    @enigma.kernel
    def inner(blkA, blkB, blkC, tv):
        thread_idx = enigma.thread_position_in_grid
        thrA = enigma.tensor_composition(blkA, tv, tiler_mn)[(thread_idx, None)]
        thrB = enigma.tensor_composition(blkB, tv, tiler_mn)[(thread_idx, None)]
        thrC = enigma.tensor_composition(blkC, tv, tiler_mn)[(thread_idx, None)]
        thrC.store(thrA.load() + thrB.load())

    inner.launch(grid=..., block=...)
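
The sizes involved in this layout arithmetic can be sanity-checked with plain integers. The sketch below assumes the common thread-value convention in which the tile shape is the elementwise product of the thread shape and the per-thread value shape. The text does not state that enigma.make_layout_tv behaves this way, so `tile_shape` and `launch_dims` are hypothetical helpers, not Enigma calls.

```python
from math import prod


def tile_shape(thr_shape, val_shape):
    """Tile covered by one block: threads per mode times values per thread."""
    return tuple(t * v for t, v in zip(thr_shape, val_shape))


def launch_dims(tensor_shape, thr_shape, val_shape):
    """Return (tiles per mode, threads per block), requiring exact divisibility."""
    tile = tile_shape(thr_shape, val_shape)
    if any(s % t for s, t in zip(tensor_shape, tile)):
        raise ValueError(f"shape {tensor_shape} is not divisible by tile {tile}")
    tiles = tuple(s // t for s, t in zip(tensor_shape, tile))
    return tiles, prod(thr_shape)


tiles, threads_per_block = launch_dims((256, 512), (4, 64), (4, 4))
print(tile_shape((4, 64), (4, 4)), tiles, threads_per_block)
```

Under that assumption, a (4, 64) thread layout with a (4, 4) value layout covers a 16 x 256 tile, so a 256 x 512 tensor splits into 16 x 2 tiles of 256 threads each.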

Compiling a JIT function

m, n = 256, 512
layout = enigma.Layout((m, n), (n, 1))
A = enigma.Tensor("A", 0, "float", layout)
B = enigma.Tensor("B", 1, "float", layout)
C = enigma.Tensor("C", 2, "float", layout)

compiled = enigma.compile(tiled_add, A, B, C)
# compiled.grid and compiled.block are set by inner.launch()
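
`enigma.Layout((m, n), (n, 1))` reads as a shape and stride pair, and strides of `(n, 1)` correspond to the row-major convention. Assuming the usual meaning, where the linear offset is the dot product of coordinate and stride, the mapping can be sketched with a hypothetical `offset` helper and checked against NumPy:

```python
import numpy as np


def offset(coord, stride):
    """Linear element offset of a coordinate under a stride tuple."""
    return sum(c * s for c, s in zip(coord, stride))


m, n = 256, 512
stride = (n, 1)  # row-major: moving one row skips n elements

flat = np.arange(m * n, dtype=np.float32)
grid2d = flat.reshape(m, n)  # NumPy's default C order is also row-major

assert offset((0, 0), stride) == 0
assert offset((3, 7), stride) == 3 * n + 7
assert flat[offset((3, 7), stride)] == grid2d[3, 7]
```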

Control flow

Enigma provides tracing-compatible control flow wrappers:

enigma.for_range

with enigma.for_range(0, n) as i:
    acc = acc + A[i]
enigma.for_range also supports IRValue bounds, a custom induction dtype, and carried state via init=[...]:
with enigma.for_range(0, n, init=[enigma.f32(0.0)]) as (i, acc):
    acc = acc + A[i]
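
When checking results, it helps to have an eager reference for what a carried-state loop computes. The sketch below is ordinary Python, not traced code. `fold_range` is a hypothetical helper that threads one carried value through each iteration, which is the role the text gives to `init=[...]`.

```python
import numpy as np


def fold_range(start, stop, init, step_fn):
    """Eager reference: thread a carried value through iterations start..stop-1."""
    acc = init
    for i in range(start, stop):
        acc = step_fn(i, acc)
    return acc


A = np.array([1.5, 2.0, -0.5, 4.0], dtype=np.float32)
total = fold_range(0, len(A), np.float32(0.0), lambda i, acc: acc + A[i])
print(total)
```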

enigma.if_

with enigma.if_(enigma.cmp_lt(tid, n)):
    C[tid] = A[tid] + B[tid]
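
A guard such as `tid < n` matters when the launch covers more threads than there are elements, for example when the grid is rounded up to a whole number of threadgroups. Whether Enigma requires that rounding is not stated here. The sketch below only shows the arithmetic, with hypothetical helper names.

```python
def ceil_div(a, b):
    """Smallest integer q with q * b >= a, for positive b."""
    return -(-a // b)


def padded_grid(n, threads_per_group=256):
    """Thread count rounded up to a whole number of threadgroups."""
    return ceil_div(n, threads_per_group) * threads_per_group


n = 1000
total_threads = padded_grid(n)
idle_threads = total_threads - n  # these threads must be masked by the guard
print(total_threads, idle_threads)
```

For n = 1000 and 256 threads per group, 1024 threads launch and the last 24 fall outside the buffers, so they must skip the store.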

enigma.while_

with enigma.while_(lambda: enigma.cmp_lt(i, n)):
    i = i + 1

Predicated load/store

To handle boundary conditions without full control flow, use predicated loads and stores:
mask = enigma.cmp_lt(tid, n)
val  = enigma.load_if(A, tid, mask, default=enigma.f32(0.0))
enigma.store_if(C, tid, val, mask)
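
The semantics of predicated access can be emulated on the host with NumPy, which is handy for testing. This sketch assumes that `load_if` yields `default` where the mask is false and that `store_if` leaves masked-off elements untouched. `load_if_ref` and `store_if_ref` are hypothetical reference functions, not Enigma calls.

```python
import numpy as np


def load_if_ref(buf, idx, mask, default):
    """Read buf[idx] where mask is true, otherwise use `default`."""
    safe_idx = np.where(mask, idx, 0)  # avoid out-of-range reads on masked lanes
    return np.where(mask, buf[safe_idx], default)


def store_if_ref(buf, idx, val, mask):
    """Write val into buf[idx] only on lanes where mask is true."""
    buf[idx[mask]] = val[mask]


n = 5
A = np.arange(n, dtype=np.float32)      # [0, 1, 2, 3, 4]
C = np.full(n, -1.0, dtype=np.float32)  # sentinel values to detect writes
tid = np.arange(8)                      # 8 threads, 3 of them past the end
mask = tid < n

val = load_if_ref(A, tid, mask, np.float32(0.0))
store_if_ref(C, tid, val, mask)
print(val, C)
```

The three out-of-range lanes read the default 0.0 and write nothing, so `C` ends up equal to `A`.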