Skip to main content
General matrix multiply (GEMM) is the canonical high-performance GPU kernel. This example shows how Enigma’s layout algebra replaces manual index arithmetic with composable layout transformations.

Problem

Compute C = A × B where:
  • A is (M, K), row-major
  • B is (K, N), row-major
  • C is (M, N), row-major

Approach: tiled outer product

Each threadgroup computes one tile of C. Threads within the threadgroup cooperate by loading shared tiles of A and B into threadgroup memory, then computing the partial dot products.
┌─────────────────────────────────┐
│ C tile (M_tile × N_tile)        │
│  computed by one threadgroup    │
│                                 │
│  Loop over K in K_tile steps:   │
│  - load A[block_m, k_block]     │  → threadgroup memory
│  - load B[k_block, block_n]     │  → threadgroup memory
│  - barrier                      │
│  - accumulate partial products  │
│  - barrier (before next load)   │
└─────────────────────────────────┘

Naive kernel (no shared memory)

A simpler starting point — each thread computes one output element:
import enigma
import numpy as np


@enigma.kernel
def naive_gemm(
    A: enigma.f32,   # (M, K) row-major
    B: enigma.f32,   # (K, N) row-major
    C: enigma.f32,   # (M, N) row-major
    K: enigma.Scalar(enigma.u32),
    N: enigma.Scalar(enigma.u32),
):
    # One thread per output element C[row, col]: grid y selects the row of C,
    # grid x selects the column. A, B and C are addressed as flat row-major
    # buffers, so element (i, j) of an (R, W) matrix lives at i * W + j.
    #
    # NOTE(review): there is no bounds check (M is not passed in at all), so
    # the dispatch grid must cover exactly N x M threads.
    row = enigma.thread_position_in_grid_xyz("y")
    col = enigma.thread_position_in_grid_xyz("x")

    # Loop-carried accumulator: `init=` supplies its starting value (0.0) and
    # `as (k, acc)` binds the loop index and the running sum, so no separate
    # pre-loop initialisation of `acc` is needed.
    with enigma.for_range(0, K, init=[enigma.f32(0.0)]) as (k, acc):
        a_val = A[row * K + k]   # A[row, k]
        b_val = B[k * N + col]   # B[k, col]
        acc = enigma.fma(a_val, b_val, acc)   # acc = a_val * b_val + acc

    C[row * N + col] = acc
Dispatch:
M, K_dim, N = 256, 256, 256

# Seeded generator so the inputs -- and therefore the tolerance check below --
# are reproducible from run to run.
rng = np.random.default_rng(0)
a = rng.standard_normal((M, K_dim), dtype=np.float32)
b = rng.standard_normal((K_dim, N), dtype=np.float32)

compiled = enigma.compile(naive_gemm)
rt = enigma.MetalRuntime()

raw = rt.execute(
    compiled,
    # A and B go in as flat row-major buffers; K and N as one-element u32
    # arrays matching the kernel's Scalar(u32) parameters.
    inputs=[a.ravel(), b.ravel(),
            np.array([K_dim], dtype=np.uint32),
            np.array([N], dtype=np.uint32)],
    # Size of C in bytes, derived from the element type rather than a literal 4.
    output_size=M * N * np.dtype(np.float32).itemsize,
    # x spans the N columns and y spans the M rows, matching the kernel's
    # thread_position_in_grid_xyz("x") / ("y"). Presumably `grid` counts
    # threads (not threadgroups), since the kernel has no bounds check.
    grid=(N, M, 1),
    threads=(16, 16, 1),
)
c_gpu = np.frombuffer(raw, dtype=np.float32).reshape(M, N)
c_cpu = a @ b
np.testing.assert_allclose(c_gpu, c_cpu, rtol=1e-4, atol=1e-4)
print("pass")

Tiled kernel with layout algebra

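For larger matrices, tile with layout algebra to enable threadgroup-memory reuse. The sketch below shows only the partitioning step: it divides A, B, and C into block tiles and gives each thread its fragment. The K-loop, threadgroup-memory staging, and accumulation are omitted.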
@enigma.jit
def tiled_gemm(mA: enigma.Tensor, mB: enigma.Tensor, mC: enigma.Tensor):
    # Partition mA, mB and mC into block tiles using a thread-value (TV)
    # layout, then launch `inner` so each thread touches its own fragment.
    #
    # NOTE(review): as written this is a partitioning sketch, not a working
    # GEMM. Items to confirm/fix before relying on it:
    #   * `inner` never multiplies A by B: it loads the C fragment and stores
    #     it back unchanged (see the "simplified here" comment below).
    #   * `inner.launch(...)` passes no tensor arguments even though `inner`
    #     takes (blkA, blkB, blkC, tv) -- TODO confirm how enigma binds kernel
    #     arguments at launch.
    #   * All three tensors are divided by the same `tiler_mn`; for GEMM, A is
    #     (M, K) and B is (K, N), so they generally need different tilers.
    # Thread-value layout: 256 threads, each handling a 4×4 tile
    thr = enigma.make_ordered_layout((4, 64), order=(1, 0))
    val = enigma.make_ordered_layout((4, 4), order=(1, 0))
    # Presumably (as in CuTe's make_layout_tv): tiler_mn is the tile shape
    # covered by all threads x values, and tv_layout maps (thread, value)
    # to a coordinate inside that tile -- confirm against enigma docs.
    tiler_mn, tv_layout = enigma.make_layout_tv(thr, val)

    # Partition tensors into block tiles
    gA = enigma.tensor_zipped_divide(mA, tiler_mn)
    gB = enigma.tensor_zipped_divide(mB, tiler_mn)
    gC = enigma.tensor_zipped_divide(mC, tiler_mn)

    @enigma.kernel
    def inner(blkA, blkB, blkC, tv):
        # NOTE(review): this is a grid-wide thread index, but it is used
        # below as the thread coordinate inside one tile, whose thread count
        # is size(tv_layout, mode=[0]). With more than one threadgroup a
        # per-threadgroup index is likely intended. Nothing here selects
        # *which* tile of gA/gB/gC this threadgroup works on -- confirm.
        thread_idx = enigma.thread_position_in_grid

        # Per-thread fragments
        thrA = enigma.tensor_composition(blkA, tv, tiler_mn)[(thread_idx, None)]
        thrB = enigma.tensor_composition(blkB, tv, tiler_mn)[(thread_idx, None)]
        thrC = enigma.tensor_composition(blkC, tv, tiler_mn)[(thread_idx, None)]

        # Load, compute, store
        a_frag = thrA.load()   # unused until the accumulation is written
        b_frag = thrB.load()   # unused until the accumulation is written
        c_frag = thrC.load()
        # (accumulate into c_frag — simplified here)
        thrC.store(c_frag)

    # NOTE(review): if size(..., mode=[1]) counts *all* tiles of a
    # zipped_divide result (as in CuTe), these are total tile counts rather
    # than per-dimension block counts, so their product over-launches;
    # enigma.size(gC, mode=[1]) alone would give one threadgroup per C tile.
    n_blocks_m = enigma.size(gA, mode=[1])
    n_blocks_n = enigma.size(gB, mode=[1])
    inner.launch(
        grid=(n_blocks_m * n_blocks_n, 1, 1),
        # One thread per thread-mode entry of the TV layout.
        block=(enigma.size(tv_layout, mode=[0]), 1, 1),
    )

Performance tips

  • Tile size: 16×16 or 32×32 threadgroup tiles are good starting points for cache and threadgroup-memory reuse; benchmark on your target shapes and device to choose between them.
  • vec_width=4: Vectorize loads with float4 to improve memory bandwidth.
  • Simdgroup matrix: On M3+ hardware, use simdgroup_multiply_accumulate for 8×8 hardware-accelerated matrix multiply.
# Simdgroup matrix variant (M3+ only)
# Fragment, not a standalone script: `rt` is the MetalRuntime created in the
# dispatch example, and buf_A / buf_B / buf_C are not defined in this snippet.
# NOTE(review): the simdgroup_* calls look like device-side operations and
# presumably belong inside an @enigma.kernel body, while the capability check
# runs on the host -- confirm against the enigma API.
caps = rt.device_capabilities()
# Presumably fails (raises) when the device lacks this feature -- confirm
# the exact failure mode.
caps.require_m3("simdgroup_matrix_multiply_accumulate")

# Load 8-wide tiles of A and B, start the accumulator at 0.0, compute
# matC = matA x matB + matC, then write the tile to buf_C.
matA = enigma.simdgroup_matrix_load(buf_A, elements_per_row=8)
matB = enigma.simdgroup_matrix_load(buf_B, elements_per_row=8)
matC = enigma.make_filled_simdgroup_matrix(0.0)
matC = enigma.simdgroup_multiply_accumulate(matA, matB, matC)
enigma.simdgroup_matrix_store(matC, buf_C, elements_per_row=8)

See also