The `enigma.ops` module provides higher-level operations built on top of the core DSL primitives.
GEMM
Computes `C += A @ B` over an `M x N` output tile, reducing over `K`. Call it inside an `@enigma.kernel` function.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `A`, `B` | Tensor | required | Input tiles (threadgroup or device memory) |
| `C` | RegisterTensor or Tensor | required | Output accumulator |
| `M`, `N`, `K` | int | required | Tile dimensions |
| `transpose_A`, `transpose_B` | bool | `False` | Transpose inputs (scalar path only) |
| `accum_dtype` | str | `"float"` | Accumulator element type |
| `use_simdgroup` | bool or None | `None` | `None`: auto (simdgroup for 8x8x8, scalar otherwise). `True`: force simdgroup. `False`: force scalar |
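The tri-state `use_simdgroup` rule can be modeled in a few lines of plain Python. `select_gemm_path` is a hypothetical helper, not part of `enigma`. It assumes "8x8x8" means an exact `M == N == K == 8` tile, and because the behavior of forcing `True` on other tile sizes is not described here, the sketch simply returns whichever path was forced.

```python
from typing import Optional


def select_gemm_path(M: int, N: int, K: int, use_simdgroup: Optional[bool] = None) -> str:
    """Hypothetical model of the use_simdgroup dispatch rule (not the enigma API)."""
    if use_simdgroup is None:
        # Auto mode: assumed to pick simdgroup MMA only for an exact 8x8x8 tile.
        return "simdgroup" if (M, N, K) == (8, 8, 8) else "scalar"
    # An explicit True/False overrides the automatic choice.
    return "simdgroup" if use_simdgroup else "scalar"


print(select_gemm_path(8, 8, 8))         # simdgroup
print(select_gemm_path(16, 8, 8))        # scalar
print(select_gemm_path(8, 8, 8, False))  # scalar
```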
- Simdgroup MMA (8x8x8): uses `simdgroup_matrix_load` / `simdgroup_multiply_accumulate` / `simdgroup_matrix_store`
- Scalar fallback (any size): triple `for_range` loop with a `RegisterTensor` accumulator
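As a reference for what the scalar path computes, here is a plain-Python sketch of the triple loop over nested lists. `gemm_reference` is hypothetical and runs on the host, not inside a kernel. It assumes `transpose_A=True` means `A` is stored `K x M` and `transpose_B=True` means `B` is stored `N x K`; the actual storage convention in `enigma` may differ.

```python
from typing import List

Matrix = List[List[float]]


def gemm_reference(A: Matrix, B: Matrix, C: Matrix, M: int, N: int, K: int,
                   transpose_A: bool = False, transpose_B: bool = False) -> None:
    """Hypothetical host-side model of the scalar path: C[m][n] += sum_k A[m][k] * B[k][n].

    Assumption: transpose_A means A is stored K x M; transpose_B means B is stored N x K.
    C is updated in place, playing the role of the accumulator.
    """
    for m in range(M):
        for n in range(N):
            for k in range(K):
                a = A[k][m] if transpose_A else A[m][k]
                b = B[n][k] if transpose_B else B[k][n]
                C[m][n] += a * b  # accumulate on top of existing C values


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = [[1, 0], [0, 1]]
gemm_reference(A, B, C, 2, 2, 2)
print(C)  # [[20, 22], [43, 51]]
```

Note that `C` starts as the identity here, so the result is `A @ B` plus the prior contents of `C`, matching the `+=` semantics.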
Quantization helpers
Pack / Unpack
| Function | Description |
|---|---|
| `enigma.pack_uint8x4(b0, b1, b2, b3)` | Pack four 8-bit values into one `uint` (LSB-first) |
| `enigma.unpack_uint8x4(packed)` | Unpack a `uint` into four `uint` lanes |
| `enigma.pack_int4x2(lo, hi)` | Pack two signed 4-bit ints into one byte |
| `enigma.unpack_int4x2(packed)` | Unpack a byte into two sign-extended ints |
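The bit layouts can be checked on the host with plain-Python equivalents. The `_ref` functions below are hypothetical models, not the DSL functions. LSB-first is taken to mean `b0` occupies bits 0-7; for the int4 pair it is assumed that `lo` sits in bits 0-3 and `hi` in bits 4-7, each stored as a two's-complement nibble.

```python
from typing import Tuple


def pack_uint8x4_ref(b0: int, b1: int, b2: int, b3: int) -> int:
    # LSB-first: b0 -> bits 0-7, b1 -> bits 8-15, b2 -> bits 16-23, b3 -> bits 24-31.
    return (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24)


def unpack_uint8x4_ref(packed: int) -> Tuple[int, int, int, int]:
    # Extract each byte lane, lowest byte first.
    return tuple((packed >> (8 * i)) & 0xFF for i in range(4))


def pack_int4x2_ref(lo: int, hi: int) -> int:
    # Assumption: lo -> bits 0-3, hi -> bits 4-7; masking keeps the two's-complement nibble.
    return (lo & 0xF) | ((hi & 0xF) << 4)


def _sign_extend4(nibble: int) -> int:
    # A nibble with bit 3 set represents a negative value in [-8, -1].
    return nibble - 16 if nibble & 0x8 else nibble


def unpack_int4x2_ref(packed: int) -> Tuple[int, int]:
    return _sign_extend4(packed & 0xF), _sign_extend4((packed >> 4) & 0xF)


print(hex(pack_uint8x4_ref(1, 2, 3, 4)))  # 0x4030201
print(unpack_int4x2_ref(pack_int4x2_ref(-3, 7)))  # (-3, 7)
```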
Dequantize
Computes `scale * (x - zero_point)` as a `float`. Useful for fused-dequant GEMM kernels.
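The formula and the idea behind fusing it into a GEMM can be sketched on the host. Both functions are hypothetical, and a single per-tensor `scale` and `zero_point` is an assumption made for brevity; real kernels often use per-channel or per-group parameters. The fused version dequantizes each weight inside the reduction instead of first materializing a float copy of the quantized data.

```python
from typing import List


def dequantize_ref(x: int, scale: float, zero_point: int) -> float:
    """Hypothetical model of scale * (x - zero_point), returned as float."""
    return float(scale * (x - zero_point))


def dot_fused_dequant(a: List[float], q: List[int], scale: float, zero_point: int) -> float:
    """Dot product that dequantizes each quantized weight on the fly inside the loop."""
    acc = 0.0
    for a_k, q_k in zip(a, q):
        acc += a_k * dequantize_ref(q_k, scale, zero_point)
    return acc


print(dequantize_ref(10, 0.5, 8))                               # 1.0
print(dot_fused_dequant([1.0, 2.0, 3.0], [8, 10, 6], 0.5, 8))   # -1.0
```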