# Tiled GEMM

> Matrix multiply using layout algebra and the @enigma.jit tiling pattern.

General matrix multiply (GEMM) is the canonical high-performance GPU kernel. This example shows how Enigma's layout algebra replaces manual index arithmetic with composable layout transformations.

## Problem

Compute `C = A × B` where:

* `A` is `(M, K)`, row-major
* `B` is `(K, N)`, row-major
* `C` is `(M, N)`, row-major
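Because all three matrices are row-major, element `(i, j)` of a matrix with `n_cols` columns sits at flat offset `i * n_cols + j`. The kernels on this page index flat buffers this way. The short NumPy check below is independent of Enigma, and `row_major_offset` is a hypothetical helper:

```python theme={null}
import numpy as np

M, K = 3, 4
A = np.arange(M * K, dtype=np.float32).reshape(M, K)
flat_A = A.ravel()  # row-major (C-order) flattening

def row_major_offset(i, j, n_cols):
    """Flat element offset of (i, j) in a row-major matrix with n_cols columns."""
    return i * n_cols + j

# Every (i, k) coordinate reaches the same value through the flat buffer.
for i in range(M):
    for k in range(K):
        assert flat_A[row_major_offset(i, k, K)] == A[i, k]

# NumPy records the same layout as byte strides: (K * 4, 4) for float32.
print(A.strides)  # (16, 4)
```

The stride pair `(K, 1)` (in elements) is exactly what a layout encodes, which is why layout algebra can replace the hand-written `row * K + k` arithmetic.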

## Approach: tiled outer product

Each threadgroup computes one tile of `C`. For each step along `K`, its threads cooperatively load one tile of `A` and one tile of `B` into threadgroup (shared) memory, wait at a barrier, and then accumulate partial dot products from the staged tiles.

```
┌─────────────────────────────────┐
│ C tile (M_tile × N_tile)        │
│  computed by one threadgroup    │
│                                 │
│  Loop over K in K_tile steps:   │
│  - load A[block_m, k_block]     │  → shared memory
│  - load B[k_block, block_n]     │  → shared memory
│  - barrier                      │
│  - accumulate partial products  │
│  - barrier (before next load)   │
└─────────────────────────────────┘
```
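The loop structure in the diagram can be modeled on the CPU. In this NumPy sketch (`tiled_matmul` is a hypothetical helper, not part of Enigma), each iteration of the two outer loops plays the role of one threadgroup, slicing stands in for the loads into threadgroup memory, and a comment marks where the GPU barrier would sit. NumPy clamps slices at the matrix edge, so partial edge tiles need no special case:

```python theme={null}
import numpy as np

def tiled_matmul(A, B, tile_m=16, tile_n=16, tile_k=16):
    """CPU model of the tiled loop: one (tile_m x tile_n) C tile per 'threadgroup'."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((M, N), dtype=np.float32)
    for m0 in range(0, M, tile_m):          # block_m
        for n0 in range(0, N, tile_n):      # block_n
            acc = np.zeros((min(tile_m, M - m0), min(tile_n, N - n0)), dtype=np.float32)
            for k0 in range(0, K, tile_k):  # loop over K in tile_k steps
                # "Load" the A and B tiles; on the GPU these go to threadgroup memory.
                a_tile = A[m0:m0 + tile_m, k0:k0 + tile_k]
                b_tile = B[k0:k0 + tile_k, n0:n0 + tile_n]
                # GPU: barrier here so the whole tile is loaded before use.
                acc += a_tile @ b_tile      # accumulate partial products
                # GPU: barrier here before the next iteration overwrites the tiles.
            C[m0:m0 + tile_m, n0:n0 + tile_n] = acc
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((40, 24), dtype=np.float32)
B = rng.standard_normal((24, 36), dtype=np.float32)
C = tiled_matmul(A, B)
print(np.allclose(C, A @ B, rtol=1e-4, atol=1e-4))
```

The shapes are deliberately not multiples of 16, so the last row and column of tiles exercise the edge case a GPU kernel would need bounds checks for.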

## Naive kernel (no shared memory)

A simpler starting point — each thread computes one output element:

```python theme={null}
import enigma
import numpy as np


@enigma.kernel
def naive_gemm(
    A: enigma.f32,   # (M, K) row-major
    B: enigma.f32,   # (K, N) row-major
    C: enigma.f32,   # (M, N) row-major
    K: enigma.Scalar(enigma.u32),
    N: enigma.Scalar(enigma.u32),
):
    row = enigma.thread_position_in_grid_xyz("y")
    col = enigma.thread_position_in_grid_xyz("x")

    with enigma.for_range(0, K, init=[enigma.f32(0.0)]) as (k, acc):
        a_val = A[row * K + k]
        b_val = B[k * N + col]
        acc = enigma.fma(a_val, b_val, acc)

    C[row * N + col] = acc
```
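Each thread runs the kernel body once for its own `(row, col)`. The plain NumPy transcription below makes that explicit and models `for_range(..., init=...)` as an ordinary loop-carried accumulator. `naive_gemm_thread` and `emulate_naive_gemm` are hypothetical helpers, not Enigma APIs. One difference to keep in mind: `enigma.fma` presumably rounds once, like a hardware fused multiply-add, while `a_val * b_val + acc` here rounds twice, so results can differ in the last bits.

```python theme={null}
import numpy as np

def naive_gemm_thread(A_flat, B_flat, row, col, K, N):
    """What one thread computes: C[row, col] via a loop-carried accumulator."""
    acc = np.float32(0.0)              # plays the role of init=[enigma.f32(0.0)]
    for k in range(K):
        a_val = A_flat[row * K + k]
        b_val = B_flat[k * N + col]
        acc = a_val * b_val + acc      # two roundings here; a true fma rounds once
    return acc

def emulate_naive_gemm(A, B):
    """Run the per-thread body at every grid position (x -> col, y -> row)."""
    M, K = A.shape
    N = B.shape[1]
    A_flat, B_flat = A.ravel(), B.ravel()
    C_flat = np.empty(M * N, dtype=np.float32)
    for row in range(M):               # thread position in grid, y
        for col in range(N):           # thread position in grid, x
            C_flat[row * N + col] = naive_gemm_thread(A_flat, B_flat, row, col, K, N)
    return C_flat.reshape(M, N)

rng = np.random.default_rng(1)
a = rng.standard_normal((8, 5), dtype=np.float32)
b = rng.standard_normal((5, 6), dtype=np.float32)
print(np.allclose(emulate_naive_gemm(a, b), a @ b, rtol=1e-5, atol=1e-5))
```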

Dispatch:

```python theme={null}
M, K_dim, N = 256, 256, 256
a = np.random.randn(M, K_dim).astype(np.float32)
b = np.random.randn(K_dim, N).astype(np.float32)

compiled = enigma.compile(naive_gemm)
rt = enigma.MetalRuntime()

raw = rt.execute(
    compiled,
    inputs=[a.ravel(), b.ravel(),
            np.array([K_dim], dtype=np.uint32),
            np.array([N], dtype=np.uint32)],
    output_size=M * N * 4,
    grid=(N, M, 1),
    threads=(16, 16, 1),
)
c_gpu = np.frombuffer(raw, dtype=np.float32).reshape(M, N)
c_cpu = a @ b
np.testing.assert_allclose(c_gpu, c_cpu, rtol=1e-4, atol=1e-4)
print("pass")
```
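The dispatch launches one thread per output element. Assuming `grid` counts threads (which the kernel's one-element-per-thread indexing implies) and `output_size` is in bytes, the launch arithmetic for the 256×256 case works out as follows:

```python theme={null}
import numpy as np

M, K_dim, N = 256, 256, 256
grid = (N, M, 1)        # total threads per axis: x -> col, y -> row
threads = (16, 16, 1)   # threads per threadgroup

# Threadgroups per axis, using ceiling division.
groups = tuple(-(-g // t) for g, t in zip(grid, threads))
print(groups)           # (16, 16, 1)

# One float32 output element per thread, 4 bytes each.
output_size = M * N * np.dtype(np.float32).itemsize
print(output_size)      # 262144
```

Because 256 is a multiple of 16, the 16×16 threadgroups tile the output exactly, with no threads left over.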

## Tiled kernel with layout algebra

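For larger matrices, tiling with layout algebra lets each threadgroup reuse data staged in threadgroup memory. The skeleton below shows only the partitioning step: `tensor_zipped_divide` splits each tensor into block tiles, and `tensor_composition` with the thread-value layout carves each tile into per-thread fragments. The `K` loop, the threadgroup-memory staging, and the accumulation into `c_frag` are omitted: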

```python theme={null}
@enigma.jit
def tiled_gemm(mA: enigma.Tensor, mB: enigma.Tensor, mC: enigma.Tensor):
    # Thread-value layout: 256 threads, each handling a 4×4 tile
    thr = enigma.make_ordered_layout((4, 64), order=(1, 0))
    val = enigma.make_ordered_layout((4, 4), order=(1, 0))
    tiler_mn, tv_layout = enigma.make_layout_tv(thr, val)

    # Partition tensors into block tiles
    gA = enigma.tensor_zipped_divide(mA, tiler_mn)
    gB = enigma.tensor_zipped_divide(mB, tiler_mn)
    gC = enigma.tensor_zipped_divide(mC, tiler_mn)

    @enigma.kernel
    def inner(blkA, blkB, blkC, tv):
        thread_idx = enigma.thread_position_in_grid

        # Per-thread fragments
        thrA = enigma.tensor_composition(blkA, tv, tiler_mn)[(thread_idx, None)]
        thrB = enigma.tensor_composition(blkB, tv, tiler_mn)[(thread_idx, None)]
        thrC = enigma.tensor_composition(blkC, tv, tiler_mn)[(thread_idx, None)]

        # Load, compute, store
        a_frag = thrA.load()
        b_frag = thrB.load()
        c_frag = thrC.load()
        # Simplified: the K loop and the accumulation into c_frag are omitted
        thrC.store(c_frag)

    # One threadgroup per block tile of C (mode 1 of the zipped divide counts
    # the tiles); one thread per entry in the thread mode of tv_layout.
    n_tiles_c = enigma.size(gC, mode=[1])
    inner.launch(
        grid=(n_tiles_c, 1, 1),
        block=(enigma.size(tv_layout, mode=[0]), 1, 1),
    )
```
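Enigma's layout internals are not shown on this page. The NumPy sketch below models two of the ideas, assuming CuTe-style semantics. First, `make_ordered_layout` gives stride 1 to the mode with the lowest `order` value. Second, `make_layout_tv` produces a tiler whose extent in each mode is the thread extent times the value extent. `ordered_strides` and `divided_offset` are hypothetical helpers, not Enigma functions:

```python theme={null}
import numpy as np

def ordered_strides(shape, order):
    """Compact strides where the mode with the smallest `order` value gets stride 1.
    Hypothetical helper; assumes make_ordered_layout follows CuTe's convention."""
    strides = [0] * len(shape)
    step = 1
    for mode in sorted(range(len(shape)), key=lambda m: order[m]):
        strides[mode] = step
        step *= shape[mode]
    return tuple(strides)

def divided_offset(tile_coord, block_coord, tiler, row_stride):
    """Flat offset of element `tile_coord` inside block tile `block_coord` of a
    row-major matrix cut into `tiler`-shaped tiles: the (tile, rest) coordinate
    pair of a zipped divide, mapped back to one offset. Hypothetical helper."""
    (r, c), (bm, bn), (tm, tn) = tile_coord, block_coord, tiler
    return (bm * tm + r) * row_stride + (bn * tn + c)

print(ordered_strides((4, 64), order=(1, 0)))  # (64, 1): threads laid out row-major
print(ordered_strides((4, 4), order=(1, 0)))   # (4, 1)

# Assumed: tiler extent per mode = thread extent x value extent.
tiler_mn = (4 * 4, 64 * 4)                     # (16, 256) = 256 threads x 16 values

M, N = 32, 512
flat = np.arange(M * N)                        # stand-in for a flat row-major buffer
mat = flat.reshape(M, N)
blocks = (M // tiler_mn[0], N // tiler_mn[1])  # block tiles along each mode

# Element (3, 7) of block tile (1, 1) agrees with direct NumPy slicing.
tile = mat[16:32, 256:512]
assert tile[3, 7] == flat[divided_offset((3, 7), (1, 1), tiler_mn, N)]
print(blocks[0] * blocks[1])                   # 4 tiles -> 4 threadgroups
```

Counting the block tiles of `C` this way gives the number of threadgroups to launch, one per tile.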

## Performance tips

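* **Tile size**: 16×16 or 32×32 threadgroup tiles are common starting points. Larger tiles reuse each staged element across more multiply-adds but need more threadgroup memory and registers, so the best size is device-dependent.
* **vec\_width=4**: Vectorize loads with `float4` so each load instruction moves four `float` values, improving effective memory bandwidth.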
* **Simdgroup matrix**: On M3+ hardware, use `simdgroup_multiply_accumulate` for 8×8 hardware-accelerated matrix multiply.
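The tile-size and vectorization tips can be quantified with a rough model. It assumes square tiles with `tile_k` equal to the tile size and `float32` data. Each `K` step stages one `A` tile and one `B` tile, and every staged element feeds several multiply-adds. The naive kernel, by contrast, performs one FMA per two elements loaded (0.5). `tile_stats` is a hypothetical helper, and real occupancy and cache behavior depend on the device:

```python theme={null}
import numpy as np

F32 = np.dtype(np.float32).itemsize  # 4 bytes

def tile_stats(tile, tile_k=None):
    """Per-K-step cost of a square (tile x tile) C tile; tile_k defaults to tile."""
    tile_k = tile if tile_k is None else tile_k
    loaded = tile * tile_k + tile_k * tile  # elements staged: A tile + B tile
    fmas = tile * tile * tile_k             # multiply-adds performed on them
    return {
        "threadgroup_bytes": loaded * F32,
        "fma_per_loaded_element": fmas / loaded,
        "float4_loads_per_row": tile_k // 4,  # per row of the A tile
    }

for t in (16, 32):
    print(t, tile_stats(t))
# 16 {'threadgroup_bytes': 2048, 'fma_per_loaded_element': 8.0, 'float4_loads_per_row': 4}
# 32 {'threadgroup_bytes': 8192, 'fma_per_loaded_element': 16.0, 'float4_loads_per_row': 8}
```

Doubling the tile edge doubles the arithmetic per loaded element but quadruples the threadgroup memory footprint, which is the trade-off behind the tile-size tip.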

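```python theme={null}
# Simdgroup matrix variant (M3+ only)
# Host side: `rt` is the MetalRuntime created in the dispatch example above.
caps = rt.device_capabilities()
caps.require_m3("simdgroup_matrix_multiply_accumulate")

# Kernel side (inside an @enigma.kernel body). Assumed: buf_A, buf_B, and buf_C
# are f32 buffers each holding one 8×8 row-major tile (8 elements per row).
matA = enigma.simdgroup_matrix_load(buf_A, elements_per_row=8)
matB = enigma.simdgroup_matrix_load(buf_B, elements_per_row=8)
matC = enigma.make_filled_simdgroup_matrix(0.0)  # 8×8 accumulator, zeroed
matC = enigma.simdgroup_multiply_accumulate(matA, matB, matC)  # matA·matB + matC
enigma.simdgroup_matrix_store(matC, buf_C, elements_per_row=8)
```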
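`simdgroup_multiply_accumulate` works on one 8×8 tile at a time, so a larger GEMM loops over 8×8 output tiles and accumulates along `K` in steps of 8. This NumPy reference models that composition on the CPU and can serve as a correctness check for a GPU version. `mma_8x8` and `gemm_from_8x8_tiles` are hypothetical helpers standing in for the simdgroup calls:

```python theme={null}
import numpy as np

def mma_8x8(a, b, c):
    """Reference for one simdgroup multiply-accumulate on 8x8 tiles: a @ b + c."""
    assert a.shape == b.shape == c.shape == (8, 8)
    return a @ b + c

def gemm_from_8x8_tiles(A, B):
    """Build C = A @ B from 8x8 multiply-accumulates (all dims multiples of 8)."""
    M, K = A.shape
    N = B.shape[1]
    assert M % 8 == 0 and K % 8 == 0 and N % 8 == 0
    C = np.zeros((M, N), dtype=np.float32)
    for i in range(0, M, 8):
        for j in range(0, N, 8):
            acc = np.zeros((8, 8), dtype=np.float32)  # like make_filled_simdgroup_matrix(0.0)
            for k in range(0, K, 8):                  # accumulate along K in 8-wide steps
                acc = mma_8x8(A[i:i + 8, k:k + 8], B[k:k + 8, j:j + 8], acc)
            C[i:i + 8, j:j + 8] = acc                 # like simdgroup_matrix_store
    return C

rng = np.random.default_rng(2)
A = rng.standard_normal((16, 24), dtype=np.float32)
B = rng.standard_normal((24, 32), dtype=np.float32)
print(np.allclose(gemm_from_8x8_tiles(A, B), A @ B, rtol=1e-5, atol=1e-5))
```

In the Metal Shading Language, `simdgroup_load` treats `elements_per_row` as the source row stride, so loading an 8×8 tile out of a wider matrix passes that matrix's row length. Whether `enigma.simdgroup_matrix_load` follows the same convention is an assumption.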

## See also

* [Layout Algebra](/programming-guide/layout-algebra) — tiling concepts
* [Memory Model](/concepts/memory-model) — threadgroup shared memory
* [SIMD Group Operations](/programming-guide/simd-group-ops)
* [API Reference: Layout Functions](/api-reference/layout-functions)
