import enigma
import numpy as np
# --- kernel definition ---
# Element-wise vector addition: each thread handles one index and computes
# C[tid] = A[tid] + B[tid]; C is the output written by the kernel.
# NOTE(review): this function is handed to enigma.compile below, so it is
# presumably translated by the enigma kernel compiler rather than executed as
# ordinary Python. Only comments are added here on purpose — a docstring or a
# restructured body may not be accepted by that compiler; confirm before
# changing.
# NOTE(review): A, B and C are annotated as enigma.f32 and indexed by tid, so
# they presumably denote float32 buffers rather than scalars — confirm against
# enigma's documentation.
@enigma.kernel
def vector_add(A: enigma.f32, B: enigma.f32, C: enigma.f32):
    # Presumably this thread's linear index in the dispatch grid (mirroring
    # Metal's [[thread_position_in_grid]] attribute) — TODO confirm.
    tid = enigma.thread_position_in_grid
    # No bounds check: the caller must not launch more threads than there are
    # elements, otherwise C (and the reads of A and B) would go out of range.
    C[tid] = A[tid] + B[tid]
# --- compile ---
# Turn the decorated Python kernel into a compiled object that the runtime can
# dispatch, then report the entry-point name the compiler assigned to it.
compiled = enigma.compile(
    vector_add,
)
print(
    "kernel:",
    compiled.kernel_name,
    sep=" ",
)
# --- dispatch ---
rt = enigma.MetalRuntime()
n = 1 << 20  # 1M elements
THREADS_PER_GROUP = 256  # threadgroup width used for the launch
SEED = 0  # fixed seed so any validation failure is reproducible

# vector_add writes C[tid] with no bounds check, so keep the element count an
# exact multiple of the threadgroup width. Then no launched thread can fall
# past the end of the buffers, however the runtime rounds the grid into
# threadgroups (NOTE(review): rounding behaviour of rt.execute not confirmed).
if n % THREADS_PER_GROUP:
    raise ValueError(
        f"n={n} must be a multiple of the threadgroup size {THREADS_PER_GROUP}"
    )

# Seeded Generator: draws float32 directly (no float64 temporary + astype copy)
# and makes the inputs identical on every run.
rng = np.random.default_rng(SEED)
a = rng.standard_normal(n, dtype=np.float32)
b = rng.standard_normal(n, dtype=np.float32)

raw = rt.execute(
    compiled,
    inputs=[a, b],
    # The output has one float32 per input element, so size it from the input
    # array itself instead of a hard-coded 4 bytes per element.
    output_size=a.nbytes,
    grid=(n, 1, 1),
    threads=(THREADS_PER_GROUP, 1, 1),
)
# Reinterpret the returned bytes as float32 values (np.frombuffer does not copy).
out = np.frombuffer(raw, dtype=np.float32)
# --- validate ---
# CPU reference computed from the very same float32 inputs that were sent to
# the kernel; the GPU result must match it within the given tolerances.
expected = a + b
np.testing.assert_allclose(
    actual=out,
    desired=expected,
    rtol=1e-5,
    atol=1e-7,
)
print("pass")