compiled_pass1 = enigma.compile(reduce_threadgroup)
compiled_pass2 = enigma.compile(reduce_final)
rt = enigma.MetalRuntime()
n = 1 << 20  # 1M elements
threads_per_tg = 256
float_bytes = np.dtype(np.float32).itemsize  # bytes per float32 output slot
a = np.random.randn(n).astype(np.float32)


def _pad_to_threadgroup(values):
    """Zero-pad a float32 array to a multiple of ``threads_per_tg``.

    Zeros are the identity for summation, so padding never changes the
    reduced result. It guarantees every dispatched threadgroup is full,
    instead of relying on how the kernel treats a partial threadgroup.
    """
    remainder = len(values) % threads_per_tg
    if remainder == 0:
        return values
    pad = np.zeros(threads_per_tg - remainder, dtype=np.float32)
    return np.concatenate([values, pad])


def _dispatch_threadgroup_pass(values):
    """Run ``compiled_pass1`` over *values* and return the raw output buffer.

    *values* must already be a multiple of ``threads_per_tg`` long. The
    output holds one float32 per threadgroup, matching the original pass-1
    dispatch (grid = number of elements, threadgroup = ``threads_per_tg``).
    """
    groups = len(values) // threads_per_tg
    return rt.execute(
        compiled_pass1,
        inputs=[values],
        output_size=groups * float_bytes,
        grid=(len(values), 1, 1),
        threads=(threads_per_tg, 1, 1),
    )


# Pass 1: threadgroup reductions (one partial sum per threadgroup).
a_padded = _pad_to_threadgroup(a)
n_threadgroups = len(a_padded) // threads_per_tg
partial_raw = _dispatch_threadgroup_pass(a_padded)
partial = np.frombuffer(partial_raw, dtype=np.float32).copy()

# Metal caps the threads in a single threadgroup (maxTotalThreadsPerThreadgroup,
# at most 1024), so the final single-threadgroup pass cannot cover
# n_threadgroups threads directly (4096 for n = 1M). Keep collapsing the
# partials with the threadgroup kernel until one threadgroup suffices.
level = partial
while len(level) > threads_per_tg:
    level = _pad_to_threadgroup(level)
    level = np.frombuffer(_dispatch_threadgroup_pass(level), dtype=np.float32).copy()

# Pass 2: final reduction in a single threadgroup. As in the original dispatch,
# grid == threads == number of inputs; zero padding fixes that count at
# threads_per_tg, which stays within the hardware limit.
final_input = _pad_to_threadgroup(level)
result_raw = rt.execute(
    compiled_pass2,
    inputs=[final_input],
    output_size=float_bytes,
    grid=(len(final_input), 1, 1),
    threads=(len(final_input), 1, 1),
)
result = np.frombuffer(result_raw, dtype=np.float32)[0]

# Validate against a float64-accumulated reference so the reported error
# reflects the GPU result rather than float32 rounding in the reference.
expected = float(a.astype(np.float64).sum())
print(f"GPU: {result:.4f} CPU: {expected:.4f} err: {abs(result - expected):.4f}")