This page documents the decorators that mark a Python function for GPU compilation, and the built-in queries that return thread and grid indices. For the conceptual overview, see Kernels and JIT.

@enigma.kernel

Marks a function as a GPU compute kernel. The body is traced once at compile time, producing an IR graph that is lowered to Metal Shading Language.
import enigma

@enigma.kernel
def fn(A: enigma.f32, B: enigma.f32, C: enigma.f32):
    # One thread per element: each thread adds a single pair.
    tid = enigma.thread_position_in_grid
    C[tid] = A[tid] + B[tid]
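
To make "traced once" concrete, here is a toy tracer in plain Python. It sketches the general tracing technique and is not enigma's implementation: Sym, BufferRef, and trace are hypothetical names, and the nested tuples stand in for a real IR graph.

```python
class Sym:
    """Symbolic value: arithmetic records an IR node instead of computing."""

    def __init__(self, expr):
        self.expr = expr

    def __add__(self, other):
        return Sym(("add", self.expr, other.expr))


class BufferRef:
    """Symbolic buffer: indexing records loads, assignment records stores."""

    def __init__(self, graph, name):
        self.graph = graph
        self.name = name

    def __getitem__(self, idx):
        return Sym(("load", self.name, idx.expr))

    def __setitem__(self, idx, value):
        self.graph.append(("store", self.name, idx.expr, value.expr))


def trace(fn, buffer_names):
    """Run fn exactly once with symbolic arguments; return the recorded stores."""
    graph = []
    tid = Sym(("thread_position_in_grid", "x"))
    buffers = [BufferRef(graph, name) for name in buffer_names]
    fn(tid, *buffers)
    return graph


def add_body(tid, A, B, C):
    # Same body as the kernel above, with tid passed in explicitly.
    C[tid] = A[tid] + B[tid]


graph = trace(add_body, ["A", "B", "C"])
```

The Python body runs a single time and never touches real data; what survives is the recorded graph, which a backend could then lower to shader source.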

Parameter types

Each parameter must have a type annotation:
| Annotation | Meaning |
| --- | --- |
| enigma.f32, enigma.f16, enigma.bf16, enigma.i32, … | Typed device buffer (device T* in MSL) |
| enigma.Scalar(dtype) | Per-dispatch constant, lowered as a 1-element buffer auto-loaded at entry |
Parameters map to MSL [[buffer(N)]] bindings in declaration order, starting at index 0. See Data Types for the full type table.
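
The declaration-order rule can be modelled with the standard library alone. This is a sketch under assumptions, not enigma's code: buffer_bindings is a hypothetical helper, and the string annotations stand in for the real enigma types.

```python
import inspect


def buffer_bindings(fn):
    """Assign binding indices 0, 1, 2, ... in parameter declaration order."""
    bindings = {}
    params = inspect.signature(fn).parameters
    for index, (name, param) in enumerate(params.items()):
        # Mirror the rule that every parameter must be annotated.
        if param.annotation is inspect.Parameter.empty:
            raise TypeError(f"parameter {name!r} needs a type annotation")
        bindings[name] = index
    return bindings


def fn(A: "f32", B: "f32", C: "f32"):
    pass


bindings = buffer_bindings(fn)
```

Here A, B, and C receive indices 0, 1, and 2, so reordering the parameters changes which buffer slot each argument occupies.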

Returns

A KernelDef object. Do not call directly — pass to enigma.compile().

@enigma.jit

Marks a host-side function that runs at compile time. Use for layout algebra, tile partitioning, and multi-kernel orchestration.
@enigma.jit
def tiled_add(mA: enigma.Tensor, mB: enigma.Tensor, mC: enigma.Tensor):
    ...
Pass Tensor arguments to enigma.compile():
compiled = enigma.compile(tiled_add, A, B, C)
See Kernels and JIT for the full pattern.
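
As an illustration of the kind of host-side arithmetic meant by "tile partitioning", the sketch below splits a 1D range into fixed-size tiles with a shorter final tile. It is plain Python with a hypothetical helper name (partition) and does not use or represent the enigma layout API.

```python
def partition(n, tile):
    """Split range(n) into half-open (start, stop) tiles of at most `tile` items."""
    return [(start, min(start + tile, n)) for start in range(0, n, tile)]


tiles = partition(1000, 256)
```

With n = 1000 and tile = 256 this yields four tiles, the last covering only elements 768 to 999, which is the case a kernel typically has to guard with a bounds check.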

Thread & grid queries

These return an IRValue of dtype "uint" representing a thread or group index, or a group or grid size. They are valid only inside @enigma.kernel bodies.

Shorthand (x dimension)

tid = enigma.thread_position_in_grid   # property — returns x dimension

Per-dimension queries

Each query takes an optional dim argument: "x" (default), "y", or "z".
| Function | Metal equivalent | Description |
| --- | --- | --- |
| enigma.thread_position_in_grid_xyz(dim="x") | thread_position_in_grid.{x\|y\|z} | Global thread index |
| enigma.thread_position_in_threadgroup(dim="x") | thread_position_in_threadgroup.{x\|y\|z} | Index within threadgroup |
| enigma.threadgroup_position_in_grid(dim="x") | threadgroup_position_in_grid.{x\|y\|z} | Threadgroup index in grid |
| enigma.threads_per_threadgroup(dim="x") | threads_per_threadgroup.{x\|y\|z} | Threads per threadgroup |
| enigma.threads_per_grid(dim="x") | threads_per_grid.{x\|y\|z} | Total threads in grid |
| enigma.threadgroups_per_grid(dim="x") | threadgroups_per_grid.{x\|y\|z} | Threadgroups in grid |
| enigma.grid_size(dim="x") | grid_size.{x\|y\|z} | Alias for threadgroups_per_grid |
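
Along any one dimension these quantities are related by simple arithmetic. The sketch below models that relationship in plain Python, assuming every threadgroup is full-sized (uniform threadgroups); the function names are hypothetical and nothing here calls enigma.

```python
def global_position(group_pos, threads_per_group, local_pos):
    """thread_position_in_grid from the group index and the index within it."""
    return group_pos * threads_per_group + local_pos


def total_threads(groups_per_grid, threads_per_group):
    """threads_per_grid when every threadgroup is full-sized."""
    return groups_per_grid * threads_per_group


# Four threadgroups of 64 threads: thread 5 of group 2 has global index 133.
example = global_position(group_pos=2, threads_per_group=64, local_pos=5)
grid_threads = total_threads(groups_per_grid=4, threads_per_group=64)
```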

Flat queries (no dim parameter)

| Function | Metal equivalent | Description |
| --- | --- | --- |
| enigma.thread_index_in_threadgroup() | thread_index_in_threadgroup | Flattened 1D index within threadgroup |
| enigma.thread_index_in_simdgroup() | thread_index_in_simdgroup | Lane index within SIMD group (0–31) |
| enigma.simdgroup_index_in_threadgroup() | simdgroup_index_in_threadgroup | SIMD group index within threadgroup |
| enigma.threads_per_simdgroup() | threads_per_simdgroup | Threads per SIMD group (typically 32) |
| enigma.simdgroups_per_threadgroup() | simdgroups_per_threadgroup | SIMD groups per threadgroup |
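
The flat queries can be pictured as a linearization of the 3D position. The following plain-Python sketch makes two assumptions that are not stated on this page: x varies fastest in the flattening, and threads are assigned to 32-wide SIMD groups in flat-index order. Treat it as a mental model rather than a guarantee; the helper names are hypothetical.

```python
def flat_index(pos, size):
    """Flatten an (x, y, z) position within a threadgroup, x varying fastest."""
    x, y, z = pos
    width, height, _depth = size
    return z * width * height + y * width + x


def simd_split(flat, simd_width=32):
    """Return (SIMD group index, lane index), assuming linear assignment."""
    return flat // simd_width, flat % simd_width


def simdgroups_per_threadgroup(size, simd_width=32):
    """Number of SIMD groups needed to cover the threadgroup (ceiling division)."""
    width, height, depth = size
    threads = width * height * depth
    return (threads + simd_width - 1) // simd_width


size = (16, 8, 1)                    # 128 threads per threadgroup
flat = flat_index((3, 2, 0), size)   # 2 * 16 + 3
group, lane = simd_split(flat)
```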

Example: 2D grid

@enigma.kernel
def transpose(In: enigma.f32, Out: enigma.f32):
    # The hard-coded 64 is the row stride, so this assumes a 64×64 row-major matrix.
    row = enigma.thread_position_in_grid_xyz("y")
    col = enigma.thread_position_in_grid_xyz("x")
    Out[col * 64 + row] = In[row * 64 + col]
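
A CPU reference with the same index expressions is a convenient way to check results. This sketch is plain Python, independent of enigma; transpose_reference is a hypothetical helper that loops over the (row, col) pairs the grid would cover.

```python
def transpose_reference(src, n):
    """Transpose an n×n row-major matrix stored as a flat list."""
    dst = [0.0] * (n * n)
    for row in range(n):        # plays the role of the "y" thread position
        for col in range(n):    # plays the role of the "x" thread position
            dst[col * n + row] = src[row * n + col]
    return dst


N = 64
src = [float(i) for i in range(N * N)]
dst = transpose_reference(src, N)
```

Comparing the device output against dst element by element catches swapped row and column indices, which are the most common mistake in this kernel.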

Function constants

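Metal specialization constants, bound at pipeline creation time. Use these for values that must be compile-time constants inside the kernel but that you want to choose without editing the kernel source (e.g. tile sizes, fusion flags). Because the value is fixed when the pipeline is created, changing it means creating another specialized pipeline; for values that change on every dispatch, use enigma.Scalar instead.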

enigma.function_constant(dtype, index) -> IRValue

| Parameter | Type | Description |
| --- | --- | --- |
| dtype | str | "float", "int", "uint", "bool" |
| index | int | Function constant index, matched at pipeline creation |
@enigma.kernel
def scaled_add(A: enigma.f32, B: enigma.f32, C: enigma.f32):
    scale = enigma.function_constant("float", index=0)
    tid = enigma.thread_position_in_grid
    C[tid] = A[tid] * scale + B[tid]
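
The effect of binding a constant at pipeline creation can be modelled as a cache of specialized variants keyed by the constant values. This is a conceptual sketch in plain Python, not the enigma or Metal API: make_pipeline_cache and build_scaled_add are hypothetical, and a Python closure stands in for a specialized pipeline.

```python
def make_pipeline_cache(build):
    """Return a getter that builds one specialized variant per set of constants."""
    cache = {}

    def get(constants):
        key = tuple(sorted(constants.items()))
        if key not in cache:
            cache[key] = build(dict(constants))
        return cache[key]

    return get, cache


def build_scaled_add(constants):
    scale = constants[0]  # value for function constant index 0, fixed at build time

    def run(A, B):
        return [a * scale + b for a, b in zip(A, B)]

    return run


get_pipeline, cache = make_pipeline_cache(build_scaled_add)
double = get_pipeline({0: 2.0})
triple = get_pipeline({0: 3.0})
```

Requesting {0: 2.0} again returns the cached variant, while each new value adds an entry, which is why function constants suit a small set of configurations rather than data that changes per call.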

arch namespace

Hardware-feature gating helpers. Use these to write kernels that adapt to the active GPU family.
| Function | Returns | Description |
| --- | --- | --- |
| enigma.arch.is_apple_silicon() | bool | True on arm64 macOS |
| enigma.arch.gpu_family() | str | e.g. "apple9" |
| enigma.arch.supports_simdgroup_matrix() | bool | 8×8 matrix unit availability |
| enigma.arch.supports_async_copy() | bool | AIR async copy intrinsics (M3+) |
| enigma.arch.simdgroup_size() | int | SIMD group lane count |
These are host-side helpers (not callable inside @enigma.kernel). For the runtime-side equivalent, see MetalRuntime.device_capabilities() in Runtime.
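
For orientation, an "arm64 macOS" check and a simple host-side gate can be written with the standard library. This is only a sketch of the idea: is_arm64_macos and choose_path are hypothetical names, the path labels are made up for illustration, and nothing here shows how enigma.arch is implemented.

```python
import platform


def is_arm64_macos():
    """True when the interpreter reports macOS on an arm64 machine."""
    # A Python build running under Rosetta reports "x86_64" and returns False.
    return platform.system() == "Darwin" and platform.machine() == "arm64"


def choose_path(has_simdgroup_matrix):
    """Pick a kernel variant on the host, before any kernel is compiled."""
    return "simdgroup_matrix" if has_simdgroup_matrix else "scalar_fallback"


on_apple_silicon = is_arm64_macos()
path = choose_path(has_simdgroup_matrix=False)
```

In real code the boolean passed to choose_path would come from a helper such as enigma.arch.supports_simdgroup_matrix(), evaluated on the host.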