@enigma.jit
def tiled_gemm(mA: enigma.Tensor, mB: enigma.Tensor, mC: enigma.Tensor):
    """Tile mA, mB and mC with one thread-value layout and launch a kernel over them.

    The kernel is currently a simplified placeholder: each thread loads its
    fragments of A, B and C, then writes the C fragment back unchanged. No
    multiply-accumulate is performed yet.

    Args:
        mA: First input tensor.
        mB: Second input tensor.
        mC: Output tensor; its fragments are loaded and stored back in place.
            All three tensors are divided with the same tiler, so this
            presumably assumes compatible shapes -- TODO confirm.
    """
    # Thread-value layout: a (4, 64) thread layout (256 threads), each thread
    # owning a (4, 4) value tile.
    thr = enigma.make_ordered_layout((4, 64), order=(1, 0))
    val = enigma.make_ordered_layout((4, 4), order=(1, 0))
    tiler_mn, tv_layout = enigma.make_layout_tv(thr, val)

    # Partition each tensor into block tiles of shape tiler_mn.
    gA = enigma.tensor_zipped_divide(mA, tiler_mn)
    gB = enigma.tensor_zipped_divide(mB, tiler_mn)
    gC = enigma.tensor_zipped_divide(mC, tiler_mn)

    @enigma.kernel
    def inner(blkA, blkB, blkC, tv):
        # NOTE(review): this is the thread's position in the whole grid, while
        # the TV layout's thread mode only spans one block's threads. With more
        # than one block this presumably needs to be the in-block thread index,
        # and blkA/blkB/blkC need to be sliced by block index first -- confirm
        # against the enigma API.
        thread_idx = enigma.thread_position_in_grid

        # Per-thread fragments: compose each tile with the TV layout, then keep
        # this thread's slice (all of its values).
        thrA = enigma.tensor_composition(blkA, tv, tiler_mn)[(thread_idx, None)]
        thrB = enigma.tensor_composition(blkB, tv, tiler_mn)[(thread_idx, None)]
        thrC = enigma.tensor_composition(blkC, tv, tiler_mn)[(thread_idx, None)]

        # Load, compute, store.
        a_frag = thrA.load()
        b_frag = thrB.load()
        c_frag = thrC.load()
        # TODO: accumulate a_frag x b_frag into c_frag. Until then a_frag and
        # b_frag are loaded but unused, and C is written back unchanged.
        thrC.store(c_frag)

    # One block per tile of the output tensor. Mode 1 of a zipped-divided
    # tensor is used here as its tile count. Multiplying A's total tile count
    # by B's does not give C's tile count and would over-launch blocks.
    n_tiles_c = enigma.size(gC, mode=[1])
    # NOTE(review): inner's parameters (blkA, blkB, blkC, tv) are not supplied
    # at launch; they presumably need binding, e.g. inner(gA, gB, gC, tv_layout)
    # -- confirm the enigma launch convention.
    inner.launch(
        grid=(n_tiles_c, 1, 1),
        block=(enigma.size(tv_layout, mode=[0]), 1, 1),
    )