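Elementwise addition is the hello-world of GPU programming. This example walks through the full dispatch path: defining a kernel, compiling it to Metal, running it on the GPU with MetalRuntime, and checking the result against NumPy.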

Complete code

import enigma
import numpy as np

# --- kernel definition ---

@enigma.kernel
def vector_add(A: enigma.f32, B: enigma.f32, C: enigma.f32):
    tid = enigma.thread_position_in_grid
    C[tid] = A[tid] + B[tid]


# --- compile ---

compiled = enigma.compile(vector_add)
print("kernel:", compiled.kernel_name)


# --- dispatch ---

rt = enigma.MetalRuntime()
n = 1 << 20  # 1M elements

a = np.random.randn(n).astype(np.float32)
b = np.random.randn(n).astype(np.float32)

raw = rt.execute(
    compiled,
    inputs=[a, b],
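    output_size=n * 4,    # output buffer size in bytes: n float32 values
    grid=(n, 1, 1),       # total threads: one per element
    threads=(256, 1, 1),  # threads per threadgroup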
)
out = np.frombuffer(raw, dtype=np.float32)


# --- validate ---

np.testing.assert_allclose(out, a + b, rtol=1e-5, atol=1e-7)
print("pass")
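To make the one-thread-per-element model concrete, the sketch below emulates the dispatch on the CPU in plain Python. It loops over every tid in a 1-D grid, runs the kernel body once per position, and then applies the same elementwise test that np.testing.assert_allclose uses: |actual - desired| <= atol + rtol * |desired|. The names vector_add_body, emulate_dispatch and allclose are hypothetical and not part of enigma. This is a mental model only, not how the Metal runtime actually schedules threads.

```python
import random
from array import array


def vector_add_body(A, B, C, tid):
    # Same statement as the kernel body: each "thread" writes exactly one element.
    C[tid] = A[tid] + B[tid]


def emulate_dispatch(kernel_body, grid_x, *buffers):
    # Hypothetical CPU stand-in for a 1-D grid: one call per thread position.
    for tid in range(grid_x):
        kernel_body(*buffers, tid)


def allclose(actual, desired, rtol=1e-5, atol=1e-7):
    # Elementwise criterion used by np.testing.assert_allclose.
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(actual, desired))


n = 1 << 10  # much smaller than the 1M-element example so the Python loop stays fast
a = array("f", (random.gauss(0.0, 1.0) for _ in range(n)))  # float32 storage
b = array("f", (random.gauss(0.0, 1.0) for _ in range(n)))
c = array("f", bytes(4 * n))  # zero-filled float32 output, n * 4 bytes

emulate_dispatch(vector_add_body, n, a, b, c)
print("pass" if allclose(c, [x + y for x, y in zip(a, b)]) else "fail")
```

The loop order here is sequential, whereas real GPU threads run concurrently; the result is the same because no two threads touch the same output element.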

What happens at each step

Compilation

enigma.compile(vector_add) traces the Python function body once, passing symbolic IRValue placeholders instead of real data, and generates Metal source like the following:
kernel void enigma_kernel_vector_add(
    device const float* A [[buffer(0)]],
    device const float* B [[buffer(1)]],
    device float* C       [[buffer(2)]],
    uint tid [[thread_position_in_grid]]
) {
    C[tid] = A[tid] + B[tid];
}
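The tracing idea can be illustrated with a toy tracer. In this sketch each kernel argument is replaced by a placeholder whose __getitem__, __setitem__ and __add__ record Metal expressions as strings instead of computing values, so running the unchanged Python body once yields the statement seen in the generated kernel. It also records which buffers were stored to, which is one plausible way a compiler could decide that A and B become device const float* while C does not. Sym, Buffer, trace and the module-level thread_position_in_grid are hypothetical; enigma's real IRValue builds an IR rather than strings, and its internals may differ.

```python
import inspect


class Sym:
    """Toy placeholder: carries a Metal expression as text instead of a value."""

    def __init__(self, expr):
        self.expr = expr

    def __add__(self, other):
        # Record the addition symbolically rather than computing it.
        return Sym(f"{self.expr} + {other.expr}")


class Buffer:
    """Toy placeholder for a device buffer argument."""

    def __init__(self, name, log):
        self.name = name
        self.log = log
        self.written = False

    def __getitem__(self, index):
        # A load becomes an indexing expression, e.g. "A[tid]".
        return Sym(f"{self.name}[{index.expr}]")

    def __setitem__(self, index, value):
        # A store becomes a Metal statement appended to the trace.
        self.written = True
        self.log.append(f"{self.name}[{index.expr}] = {value.expr};")


# Stand-in for enigma.thread_position_in_grid inside this toy.
thread_position_in_grid = Sym("tid")


def vector_add(A, B, C):
    tid = thread_position_in_grid
    C[tid] = A[tid] + B[tid]


def trace(fn):
    """Run fn once with placeholders; return emitted statements and written buffers."""
    log = []
    buffers = [Buffer(name, log) for name in inspect.signature(fn).parameters]
    fn(*buffers)
    written = [b.name for b in buffers if b.written]
    return log, written


statements, written = trace(vector_add)
print(statements)  # ['C[tid] = A[tid] + B[tid];']
print(written)     # ['C']
```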

Dispatch parameters

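| Parameter | Value | Meaning |
|---|---|---|
| grid=(n, 1, 1) | 1M threads (2^20 = 1,048,576) | One thread per element |
| threads=(256, 1, 1) | 256 threads per threadgroup | 256 / 32 = 8 SIMD groups per threadgroup |
| output_size=n * 4 | 4 MiB (4,194,304 bytes) | 1M floats × 4 bytes each |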
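The values in the table follow from simple arithmetic. The helper below (a hypothetical name, not an enigma function) recomputes them, assuming a SIMD width of 32 as in the table and 4-byte float elements. It also reports how many 256-thread threadgroups the grid splits into, which the table leaves implicit.

```python
def dispatch_summary(n, threads_per_group, simd_width=32, elem_bytes=4):
    """Recompute the dispatch-parameter table for a 1-D, one-thread-per-element grid."""
    return {
        "threads": n,  # grid=(n, 1, 1): one thread per element
        # Integer ceiling division: threadgroups needed to cover n threads.
        "threadgroups": (n + threads_per_group - 1) // threads_per_group,
        # simd_width=32 matches the 256 / 32 = 8 row in the table.
        "simd_groups_per_threadgroup": threads_per_group // simd_width,
        "output_bytes": n * elem_bytes,  # float32 = 4 bytes
    }


summary = dispatch_summary(1 << 20, 256)
print(summary)
print(summary["output_bytes"] / 2**20, "MiB")  # 4.0 MiB
```

For n = 2^20 the grid divides evenly into 4096 threadgroups of 256 threads.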

Vectorized variant

Use vec_width=4 to process 4 elements per thread with float4 instructions:
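compiled_vec = enigma.compile(vector_add, vec_width=4)

raw = rt.execute(
    compiled_vec,
    inputs=[a, b],
    output_size=n * 4,     # still n floats, so the output size is unchanged
    grid=(n // 4, 1, 1),   # 4× fewer threads; n = 2^20 divides evenly by 4
    threads=(256, 1, 1),
)
out_vec = np.frombuffer(raw, dtype=np.float32)
np.testing.assert_allclose(out_vec, a + b, rtol=1e-5, atol=1e-7)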
The kernel body does not change; the compiler handles vec_width when it generates the Metal source.
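Why grid=(n // 4, 1, 1) still covers every element: if thread tid handles the four consecutive elements 4*tid through 4*tid + 3 (an assumption about the generated code; the compiler could use a different mapping), then n // 4 threads cover indices 0 to n - 1 exactly once whenever n is a multiple of 4, which holds for n = 1 << 20. The sketch below checks that in plain Python; elements_for_thread and covered_indices are hypothetical helpers.

```python
def elements_for_thread(tid, vec_width=4):
    # Assumed mapping: thread tid owns one float4 = elements [4*tid, 4*tid + 4).
    start = tid * vec_width
    return range(start, start + vec_width)


def covered_indices(n, vec_width=4):
    # This mapping has no tail handling, so require an exact multiple.
    if n % vec_width:
        raise ValueError("n must be a multiple of vec_width for this mapping")
    grid_x = n // vec_width  # matches grid=(n // 4, 1, 1)
    seen = []
    for tid in range(grid_x):
        seen.extend(elements_for_thread(tid, vec_width))
    return seen


n = 1 << 20
idx = covered_indices(n)
print(idx == list(range(n)))  # True: every element visited exactly once, in order
```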

Benchmarking

prepared = rt.prepare(compiled, inputs=[a, b], output_size=n * 4)

# Warm up
prepared.dispatch(grid=(n, 1, 1), threads=(256, 1, 1))

# Measure
times = [prepared.dispatch_timed(grid=(n, 1, 1), threads=(256, 1, 1)) for _ in range(20)]
times.sort()
print(f"median: {times[len(times)//2]:.3f} ms")
print(f"min:    {times[0]:.3f} ms")
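Elementwise addition does very little arithmetic per byte, so its timing is usually interpreted as memory bandwidth. Each element reads one float from A and one from B and writes one to C, i.e. 12 bytes, or 12 MiB per dispatch at n = 1 << 20. The hypothetical helper below turns per-dispatch times in milliseconds, like the times list above, into median and best-case effective bandwidth. It assumes dispatch_timed returns milliseconds, as the print format above suggests. It uses statistics.median, which averages the two middle samples for an even count, so it can differ slightly from times[len(times)//2].

```python
import statistics


def effective_bandwidth_gbs(n, ms, elem_bytes=4, buffers_touched=3):
    """Effective bandwidth in GB/s (10**9 bytes per second) for one vector_add dispatch."""
    # Per element: read A[i], read B[i], write C[i].
    total_bytes = n * elem_bytes * buffers_touched
    return total_bytes / (ms / 1e3) / 1e9


def summarize(times_ms, n):
    """Median and best-case time plus the matching effective bandwidth."""
    med = statistics.median(times_ms)  # averages the two middle samples for even counts
    best = min(times_ms)
    return {
        "median_ms": med,
        "min_ms": best,
        "median_GBps": effective_bandwidth_gbs(n, med),
        "best_GBps": effective_bandwidth_gbs(n, best),
    }


# Placeholder timings for illustration only, not measurements.
print(summarize([1.0, 2.0, 3.0, 4.0], 1 << 20))
```

Comparing the best-case figure with the device's advertised memory bandwidth gives a rough sense of how close the kernel gets to being purely memory-bound.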

See also