The enigma.ops module provides higher-level operations built on top of the core DSL primitives.

GEMM

enigma.gemm(A, B, C, *, M, N, K, transpose_A=False, transpose_B=False,
            accum_dtype="float", use_simdgroup=None)
Computes C += A @ B for an M x N output tile, reducing over K. Call it from inside an @enigma.kernel function.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| A, B | Tensor | required | Input tiles (threadgroup or device) |
| C | RegisterTensor or Tensor | required | Output accumulator |
| M, N, K | int | required | Tile dimensions |
| transpose_A, transpose_B | bool | False | Transpose inputs (scalar path only) |
| accum_dtype | str | "float" | Accumulator element type |
| use_simdgroup | bool or None | None | None = auto (simdgroup for 8x8x8, scalar otherwise). True = force simdgroup. False = force scalar. |
Two lowering paths:
  • Simdgroup MMA (8x8x8): uses simdgroup_matrix_load / simdgroup_multiply_accumulate / simdgroup_matrix_store
  • Scalar fallback (any size): triple for_range with RegisterTensor accumulator
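To make the semantics concrete, here is a minimal pure-Python reference model of what the scalar path computes, together with the documented auto-selection rule. It is a sketch, not enigma code: `select_path` and `gemm_reference` are hypothetical names, tiles are modeled as nested lists, and the transposed layout (A stored as K x M, B stored as N x K when the flag is set) is an assumption.

```python
def select_path(M, N, K, use_simdgroup=None):
    """Model of the documented path selection (hypothetical helper, not enigma API)."""
    if use_simdgroup is None:
        # Auto: simdgroup MMA only for the 8x8x8 tile shape, scalar otherwise.
        return "simdgroup" if (M, N, K) == (8, 8, 8) else "scalar"
    return "simdgroup" if use_simdgroup else "scalar"


def gemm_reference(A, B, C, *, M, N, K, transpose_A=False, transpose_B=False):
    """Accumulate A @ B into C in place with a triple loop over nested lists."""
    for i in range(M):
        for j in range(N):
            acc = float(C[i][j])  # accumulate in float, mirroring accum_dtype="float"
            for k in range(K):
                # Assumed layout: a transposed operand is stored with its axes swapped.
                a = A[k][i] if transpose_A else A[i][k]
                b = B[j][k] if transpose_B else B[k][j]
                acc += a * b
            C[i][j] = acc
    return C


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [[1.0, 0.0], [0.0, 1.0]]  # existing accumulator contents are kept (C += ...)
gemm_reference(A, B, C, M=2, N=2, K=2)
print(C)  # [[20.0, 22.0], [43.0, 51.0]]
```

A model like this can serve as a host-side check for kernel output, since it accumulates into C rather than overwriting it.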

Quantization helpers

Pack / Unpack

packed = enigma.pack_uint8x4(b0, b1, b2, b3)   # 4 x uint8 -> uint32
b0, b1, b2, b3 = enigma.unpack_uint8x4(packed)  # uint32 -> 4 x uint

packed = enigma.pack_int4x2(lo, hi)              # 2 x int4 -> uint8
lo, hi = enigma.unpack_int4x2(packed)             # uint8 -> 2 x int (sign-extended)
| Function | Description |
| --- | --- |
| enigma.pack_uint8x4(b0, b1, b2, b3) | Pack four 8-bit values into one uint (LSB-first) |
| enigma.unpack_uint8x4(packed) | Unpack uint into 4 uint lanes |
| enigma.pack_int4x2(lo, hi) | Pack two signed 4-bit ints into one byte |
| enigma.unpack_int4x2(packed) | Unpack byte into two sign-extended ints |
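The bit layouts can be illustrated with a plain-Python model of the four helpers. This is a sketch that reuses the function names for readability; it is not the enigma implementation. The LSB-first byte order for `pack_uint8x4` follows the description above, while placing `lo` in the low nibble and `hi` in the high nibble for `pack_int4x2` is an assumption.

```python
def pack_uint8x4(b0, b1, b2, b3):
    """Pack four bytes into a 32-bit value, b0 in the least significant byte."""
    return (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24)


def unpack_uint8x4(packed):
    """Split a 32-bit value into four unsigned byte lanes, lowest byte first."""
    return tuple((packed >> (8 * i)) & 0xFF for i in range(4))


def _sign_extend4(nibble):
    """Interpret a 4-bit pattern as a two's-complement value in [-8, 7]."""
    return (nibble ^ 0x8) - 0x8


def pack_int4x2(lo, hi):
    """Pack two signed 4-bit ints into one byte (assumed: lo in the low nibble)."""
    return (lo & 0xF) | ((hi & 0xF) << 4)


def unpack_int4x2(packed):
    """Unpack a byte into two sign-extended ints, low nibble first."""
    return _sign_extend4(packed & 0xF), _sign_extend4((packed >> 4) & 0xF)


print(hex(pack_uint8x4(0x11, 0x22, 0x33, 0x44)))  # 0x44332211
print(unpack_int4x2(pack_int4x2(-1, 3)))          # (-1, 3)
```

Every pair in the signed 4-bit range [-8, 7] round-trips through this model, so it can act as a reference when testing packed-weight kernels.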

Dequantize

enigma.dequantize_int8(x, scale, zero_point=0)
Returns scale * (x - zero_point) as float. Useful for fused-dequant GEMM kernels.
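As a small worked example of the formula, the following pure-Python sketch applies scale * (x - zero_point) and then uses it in a dot product, which is the inner step a fused-dequant GEMM repeats per output element. `dequant_dot` is a hypothetical helper introduced only for illustration, and the function here is a host-side model, not the enigma implementation.

```python
def dequantize_int8(x, scale, zero_point=0):
    """Reference model: map a quantized integer back to a float."""
    return float(scale) * (x - zero_point)


def dequant_dot(q_weights, activations, scale, zero_point=0):
    """Hypothetical helper: dequantize each weight on the fly, then accumulate."""
    return sum(
        dequantize_int8(q, scale, zero_point) * a
        for q, a in zip(q_weights, activations)
    )


print(dequantize_int8(10, 0.5, zero_point=2))       # 4.0
print(dequant_dot([10, -6], [1.0, 2.0], 0.5, 2))    # -4.0
```

Dequantizing inside the accumulation loop avoids materializing a full float copy of the weights, which is the usual motivation for fusing the two steps.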