# Kernels and JIT

> How @enigma.kernel and @enigma.jit work, and when to use each.

Enigma provides two decorators for GPU code: `@enigma.kernel` for direct compute kernels and `@enigma.jit` for host-side functions that perform layout algebra and launch kernels.

## `@enigma.kernel`

`@enigma.kernel` defines a GPU compute kernel. The decorated function is **traced** at compile time: the Python body runs once with symbolic `IRValue` placeholders, producing an IR graph that is then lowered to Metal.
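To make the tracing step concrete, here is a toy tracer in plain Python. It is not Enigma's implementation: `Sym` and `trace` are hypothetical names standing in for `IRValue` and the compiler entry point, and the textual op format is invented for illustration. The body runs once, and each arithmetic operator appends an operation to a graph instead of computing a number.

```python
class Sym:
    """Hypothetical stand-in for a symbolic placeholder such as IRValue."""

    def __init__(self, expr, graph):
        self.expr = expr      # textual name of this value
        self.graph = graph    # shared list of recorded operations

    def _binop(self, op, other):
        rhs = other.expr if isinstance(other, Sym) else repr(other)
        out = Sym(f"%{len(self.graph)}", self.graph)
        self.graph.append(f"{out.expr} = {op} {self.expr}, {rhs}")
        return out

    def __add__(self, other):
        return self._binop("add", other)

    def __mul__(self, other):
        return self._binop("mul", other)


def trace(fn, *arg_names):
    """Run fn exactly once with symbolic arguments; return the recorded ops."""
    graph = []
    fn(*(Sym(name, graph) for name in arg_names))
    return graph


def body(alpha, x, y):
    return alpha * x + y


ops = trace(body, "alpha", "x", "y")
print("\n".join(ops))
# %0 = mul alpha, x
# %1 = add %0, y
```

A real tracer would go on to lower such a graph to target code; this sketch stops at recording.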

### Rules for kernel bodies

* **No native Python control flow** — use `enigma.for_range`, `enigma.if_`, `enigma.while_` instead.
* **No Python data types as values** — all computation goes through `IRValue` objects.
* **No early returns** — the tracer records the full body unconditionally.
* **Types are annotations, not runtime checks** — `A: enigma.f32` declares the buffer's element type.
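The first rule follows from how tracing works in general. The sketch below uses a hypothetical `Placeholder` class (not an Enigma type) to show what happens when a native `if` meets a symbolic value: Python decides the branch once, during the single trace, so only one branch is ever recorded. Whether Enigma raises an error or silently picks a branch in this situation is not stated above; the sketch only illustrates why a native `if` cannot express a per-thread condition.

```python
class Placeholder:
    """Hypothetical symbolic value; it carries no concrete number."""


recorded = []


def body(cond):
    # Python evaluates this `if` during the one-and-only trace.
    # A plain object is truthy by default, so the else-branch is never seen.
    if cond:
        recorded.append("then-branch")
    else:
        recorded.append("else-branch")


body(Placeholder())
print(recorded)  # ['then-branch']
```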

### Kernel parameters

Each parameter is one of:

* `enigma.f32`, `enigma.f16`, `enigma.bf16`, `enigma.i32`, `enigma.u32`, etc. — typed buffer (pointer in Metal)
* `enigma.Scalar(dtype)` — scalar constant, lowered as a 1-element buffer

```python theme={null}
import enigma  # assumed import name, matching the `enigma.` prefix used throughout

@enigma.kernel
def saxpy(X: enigma.f32, alpha: enigma.Scalar(enigma.f32), Y: enigma.f32, Out: enigma.f32):
    tid = enigma.thread_position_in_grid
    Out[tid] = alpha * X[tid] + Y[tid]
```
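A plain-Python model of the same computation can serve as a reference when checking GPU results. This is a sketch, not part of Enigma: `saxpy_reference` is a hypothetical helper, and the loop variable plays the role of `thread_position_in_grid`, with one iteration per GPU thread.

```python
def saxpy_reference(x, alpha, y):
    """CPU model of the saxpy kernel: one loop iteration per GPU thread."""
    out = [0.0] * len(x)
    for tid in range(len(x)):
        out[tid] = alpha * x[tid] + y[tid]
    return out


expected = saxpy_reference([1.0, 2.0, 3.0], 2.0, [10.0, 20.0, 30.0])
print(expected)  # [12.0, 24.0, 36.0]
```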

### Compiling a kernel

```python theme={null}
compiled = enigma.compile(saxpy)
```

`compiled` is a `CompiledKernel`. Dispatch it with `MetalRuntime.execute(...)`:

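```python theme={null}
import numpy as np

n = 1024  # illustrative problem size
x = np.arange(n, dtype=np.float32)
y = np.ones(n, dtype=np.float32)
# Assumption: a Scalar parameter is passed as a 1-element array,
# since it is lowered as a 1-element buffer.
alpha_val = np.array([2.0], dtype=np.float32)

rt = enigma.MetalRuntime()
raw = rt.execute(
    compiled,
    inputs=[x, alpha_val, y],   # numpy arrays in parameter order
    output_size=n * 4,          # presumably bytes: n float32 values, 4 bytes each
    grid=(n, 1, 1),
    threads=(256, 1, 1),
)
```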
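The `output_size=n * 4` argument reads as a size in bytes (`n` float32 elements at 4 bytes each), although that is an inference from the example and not a documented guarantee. Under that assumption, the launch arithmetic can be sketched with a hypothetical helper, `launch_sizes`, which also reports how many 256-thread groups are needed to cover `n` threads. How Enigma handles an `n` that is not a multiple of the threadgroup width is not stated here.

```python
import math
import struct


def launch_sizes(n, fmt="<f", threads_per_group=256):
    """Return (output size in bytes, threadgroups needed) for n elements."""
    itemsize = struct.calcsize(fmt)              # 4 for a float32
    groups = math.ceil(n / threads_per_group)    # the last group may be partly idle
    return n * itemsize, groups


print(launch_sizes(1000))  # (4000, 4)
print(launch_sizes(1024))  # (4096, 4)
```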

### Repeated dispatch

For hot loops, pre-allocate buffers once with `rt.prepare(...)`, then dispatch as many times as needed:

```python theme={null}
prepared = rt.prepare(compiled, inputs=[x, alpha_val, y], output_size=n * 4)
# later:
prepared.dispatch(grid=(n, 1, 1), threads=(256, 1, 1))
result = prepared.read_output(dtype=np.float32)
```

For benchmarking:

```python theme={null}
elapsed_ms = prepared.dispatch_timed(grid=(n, 1, 1), threads=(256, 1, 1))
```
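A single timing is noisy, so benchmarks usually collect several samples, discard warm-up runs, and report a robust statistic. The sketch below assumes the samples come from repeated `dispatch_timed` calls and are in milliseconds; `summarize` and `effective_bandwidth_gbs` are hypothetical helpers, and the timings are made up for illustration. For `saxpy`, three arrays of `n` float32 values are touched (two read, one written).

```python
import statistics


def summarize(samples_ms, warmup=2):
    """Drop the first `warmup` runs and return the median of the rest."""
    return statistics.median(samples_ms[warmup:])


def effective_bandwidth_gbs(n, elapsed_ms, bytes_per_elem=4, streams=3):
    """GB/s moved by a kernel that touches `streams` arrays of n elements."""
    total_bytes = n * bytes_per_elem * streams
    return total_bytes / (elapsed_ms * 1e-3) / 1e9


# In practice: samples = [prepared.dispatch_timed(...) for _ in range(k)]
samples = [5.0, 3.0, 1.0, 1.2, 0.8]   # made-up values, not measurements
median_ms = summarize(samples)
print(median_ms, effective_bandwidth_gbs(1_000_000, median_ms))
```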

***

## `@enigma.jit`

`@enigma.jit` marks a host-side function that runs at compile time. Use it when you need to:

* Perform layout algebra (tiling, thread-value mapping) before generating kernels
* Launch multiple kernels in sequence
* Accept `Tensor` objects and compute block/thread partitioning

### JIT function pattern

```python theme={null}
@enigma.jit
def tiled_add(mA: enigma.Tensor, mB: enigma.Tensor, mC: enigma.Tensor):
    # Layout arithmetic — runs at compile time
    thr = enigma.make_ordered_layout((4, 64), order=(1, 0))
    val = enigma.make_ordered_layout((4, 4), order=(1, 0))
    tiler_mn, tv_layout = enigma.make_layout_tv(thr, val)

    gA = enigma.tensor_zipped_divide(mA, tiler_mn)
    gB = enigma.tensor_zipped_divide(mB, tiler_mn)
    gC = enigma.tensor_zipped_divide(mC, tiler_mn)

    # Inner kernel — dispatched for each tile
    @enigma.kernel
    def inner(blkA, blkB, blkC, tv):
        thread_idx = enigma.thread_position_in_grid
        thrA = enigma.tensor_composition(blkA, tv, tiler_mn)[(thread_idx, None)]
        thrB = enigma.tensor_composition(blkB, tv, tiler_mn)[(thread_idx, None)]
        thrC = enigma.tensor_composition(blkC, tv, tiler_mn)[(thread_idx, None)]
        thrC.store(thrA.load() + thrB.load())

    inner.launch(grid=..., block=...)
```
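The layout arithmetic in `tiled_add` can be followed by hand. The sketch below assumes `make_layout_tv` behaves like its CuTe-style counterpart, where the tile extent in each mode is the thread count times the values per thread; that semantics is an assumption, not something this page confirms. `tile_shape` and `tile_grid` are hypothetical helpers, and the 256 x 512 tensor shape is only an example.

```python
def tile_shape(thr_shape, val_shape):
    """Tile extent per mode: threads times values-per-thread."""
    return tuple(t * v for t, v in zip(thr_shape, val_shape))


def tile_grid(tensor_shape, tiler):
    """Tiles per mode, assuming the tiler divides the tensor evenly."""
    assert all(s % t == 0 for s, t in zip(tensor_shape, tiler))
    return tuple(s // t for s, t in zip(tensor_shape, tiler))


tiler_mn = tile_shape((4, 64), (4, 4))     # thr and val shapes from tiled_add
tiles = tile_grid((256, 512), tiler_mn)    # example tensor shape
threads_per_tile = 4 * 64
values_per_thread = 4 * 4

print(tiler_mn)           # (16, 256)
print(tiles)              # (16, 2)
print(threads_per_tile)   # 256
print(values_per_thread)  # 16
```

Under these assumptions each tile holds 16 x 256 = 4096 elements, which matches 256 threads handling 16 values each.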

### Compiling a JIT function

```python theme={null}
m, n = 256, 512
layout = enigma.Layout((m, n), (n, 1))
A = enigma.Tensor("A", 0, "float", layout)
B = enigma.Tensor("B", 1, "float", layout)
C = enigma.Tensor("C", 2, "float", layout)

compiled = enigma.compile(tiled_add, A, B, C)
# compiled.grid and compiled.block are set by inner.launch()
```
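The call `enigma.Layout((m, n), (n, 1))` appears to take a shape followed by strides, which would describe a row-major `m x n` tensor. Assuming that reading, a coordinate maps to a linear offset as the dot product of coordinate and stride. The `offset` function below is a hypothetical illustration, not an Enigma API.

```python
def offset(coord, stride):
    """Linear offset of a coordinate under a strided layout."""
    return sum(c * s for c, s in zip(coord, stride))


m, n = 256, 512
stride = (n, 1)                          # row-major: one row spans n elements

print(offset((0, 1), stride))            # 1
print(offset((1, 0), stride))            # 512
print(offset((m - 1, n - 1), stride))    # 131071
```

The last offset equals `m * n - 1`, so the layout covers the buffer without gaps.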

***

## Control flow

Enigma provides tracing-compatible control flow wrappers. Each one is used as a `with` block inside a kernel body.

### `enigma.for_range`

```python theme={null}
with enigma.for_range(0, n) as i:
    acc = acc + A[i]
```

`for_range` also accepts `IRValue` bounds and a custom induction dtype, and carries state across iterations via `init=[...]`:

```python theme={null}
with enigma.for_range(0, n, init=[enigma.f32(0.0)]) as (i, acc):
    acc = acc + A[i]
```
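The carried-state form behaves like a fold: each iteration receives the current state and produces the next one. The plain-Python model below sketches that behavior; `for_range_reference` is a hypothetical helper, and how Enigma hands the final value back after the `with` block is not shown on this page, so the sketch simply returns it.

```python
def for_range_reference(start, stop, init, body):
    """Model of a loop with carried state: body(i, *state) returns the new state."""
    state = list(init)
    for i in range(start, stop):
        state = list(body(i, *state))
    return state


A = [1.0, 2.0, 3.0, 4.0]
(total,) = for_range_reference(0, len(A), [0.0], lambda i, acc: (acc + A[i],))
print(total)  # 10.0
```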

### `enigma.if_`

```python theme={null}
with enigma.if_(enigma.cmp_lt(tid, n)):
    C[tid] = A[tid] + B[tid]
```

### `enigma.while_`

```python theme={null}
with enigma.while_(lambda: enigma.cmp_lt(i, n)):
    i = i + 1
```

### Predicated load/store

To handle boundary conditions without a full `enigma.if_` block, use predicated loads and stores:

```python theme={null}
mask = enigma.cmp_lt(tid, n)
val  = enigma.load_if(A, tid, mask, default=enigma.f32(0.0))
enigma.store_if(C, tid, val, mask)
```
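The intended behavior of predicated access can be modeled on the CPU. The functions below are plain-Python stand-ins that reuse the names `load_if` and `store_if` for readability; they are not the Enigma implementations, and the argument order simply mirrors the snippet above. The model runs a grid of 4 threads over 3 valid elements, so the last thread is out of bounds and must touch no memory.

```python
def load_if(buf, idx, mask, default):
    """Return buf[idx] when mask is true; otherwise return default without indexing."""
    return buf[idx] if mask else default


def store_if(buf, idx, value, mask):
    """Write value to buf[idx] only when mask is true."""
    if mask:
        buf[idx] = value


n = 3                        # number of valid elements
A = [1.0, 2.0, 3.0]
C = [0.0, 0.0, 0.0]

for tid in range(4):         # tid == 3 falls outside the buffers
    mask = tid < n
    val = load_if(A, tid, mask, 0.0)
    store_if(C, tid, val, mask)

print(C)  # [1.0, 2.0, 3.0]
```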
