Skip to main content
This example reduces a large array to a single sum using a two-level approach: SIMD group reductions within each threadgroup, followed by a second pass to combine threadgroup results.

Strategy

Input (N elements)

    ├── Threadgroup 0: 256 threads × simd_sum → 8 partial sums → 1 threadgroup sum
    ├── Threadgroup 1: 256 threads × simd_sum → 8 partial sums → 1 threadgroup sum
    └── ...

    └── Second pass: sum all threadgroup results → final scalar

Level 1: threadgroup reduction kernel

import enigma
import numpy as np
import math


# Level-1 kernel: each threadgroup sums its slice of A and writes one value
# to PartialOut[<threadgroup index>].
#
# Launch requirements visible from the code below:
#   * one thread per element of A (every thread does an unguarded A[tid] load);
#   * a threadgroup size that is a multiple of 32 and at most 32 * 32 = 1024,
#     because the scratch array holds one slot per 32-wide simd group.
#
# Documentation is kept in `#` comments rather than a docstring so the
# function body contains nothing but kernel statements.
@enigma.kernel
def reduce_threadgroup(
    A:          enigma.f32,
    PartialOut: enigma.f32,
):
    # Called with an axis like the sibling position accessors below; the
    # uncalled accessor would have been used as the index in A[tid].
    # NOTE(review): confirm the accessor signature against the DSL reference.
    tid       = enigma.thread_position_in_grid("x")
    local_id  = enigma.thread_position_in_threadgroup("x")
    simd_idx  = enigma.thread_index_in_simdgroup()
    simd_gid  = enigma.simdgroup_index_in_threadgroup()
    tg_size   = enigma.threads_per_threadgroup()

    # Threadgroup scratch: one partial sum per simd group (up to 32 groups).
    scratch = enigma.threadgroup_alloc("float", 32)

    # Each thread loads exactly one element.
    val = A[tid]

    # Step 1: sum across the lanes of this simd group.
    simd_total = enigma.simd_sum(val)

    # Step 2: lane 0 of each simd group publishes its subtotal.
    with enigma.if_(enigma.cmp_eq(simd_idx, 0)):
        scratch[simd_gid] = simd_total

    # All subtotals must be written before any thread reads them back.
    enigma.barrier(mem_flags="mem_threadgroup")

    # Step 3: the first n_simd_groups threads (all inside simd group 0, since
    # n_simd_groups <= 32) each pick up one subtotal and sum them together.
    n_simd_groups = tg_size / 32
    with enigma.if_(enigma.cmp_lt(local_id, n_simd_groups)):
        sub = scratch[local_id]
        tg_sum = enigma.simd_sum(sub)
        # Only one thread per threadgroup writes the result.
        with enigma.if_(enigma.cmp_eq(local_id, 0)):
            group_id = enigma.threadgroup_position_in_grid("x")
            PartialOut[group_id] = tg_sum

Level 2: final reduction

Once all threadgroup sums are in PartialOut, run a second pass. For small PartialOut arrays (at most 256 elements, one per thread), a single threadgroup is sufficient; larger arrays must first be shrunk by running the level-1 kernel again:
# Level-2 kernel: a single threadgroup sums Partial and writes the scalar
# result to Result[0].
#
# Launch requirements visible from the code below:
#   * exactly one threadgroup, with one thread per element of Partial
#     (every thread does an unguarded Partial[local_id] load);
#   * a threadgroup size that is a multiple of 32 and at most 32 * 32 = 1024,
#     because the scratch array holds one slot per 32-wide simd group.
#
# The number of simd groups is derived from the launch size instead of being
# hard-coded, so only scratch slots that were actually written are read.
#
# Documentation is kept in `#` comments rather than a docstring so the
# function body contains nothing but kernel statements.
@enigma.kernel
def reduce_final(
    Partial: enigma.f32,
    Result:  enigma.f32,
):
    local_id = enigma.thread_position_in_threadgroup("x")
    simd_idx = enigma.thread_index_in_simdgroup()
    simd_gid = enigma.simdgroup_index_in_threadgroup()
    tg_size  = enigma.threads_per_threadgroup()

    # Threadgroup scratch: one partial sum per simd group (up to 32 groups).
    scratch = enigma.threadgroup_alloc("float", 32)

    # Step 1: each thread loads one partial sum; reduce within the simd group.
    val = Partial[local_id]
    simd_total = enigma.simd_sum(val)

    # Step 2: lane 0 of each simd group publishes its subtotal.
    with enigma.if_(enigma.cmp_eq(simd_idx, 0)):
        scratch[simd_gid] = simd_total

    # All subtotals must be written before any thread reads them back.
    enigma.barrier(mem_flags="mem_threadgroup")

    # Step 3: the first n_simd_groups threads (all inside simd group 0) each
    # pick up one subtotal and combine them with a second simd_sum -- the same
    # pattern the level-1 kernel uses.
    n_simd_groups = tg_size / 32
    with enigma.if_(enigma.cmp_lt(local_id, n_simd_groups)):
        sub = scratch[local_id]
        total = enigma.simd_sum(sub)
        # Only thread 0 writes the final scalar.
        with enigma.if_(enigma.cmp_eq(local_id, 0)):
            Result[0] = total

Full dispatch

# End-to-end driver: compile both kernels, reduce 1M random floats on the GPU
# and compare against NumPy.
compiled_pass1 = enigma.compile(reduce_threadgroup)
compiled_pass2 = enigma.compile(reduce_final)

rt = enigma.MetalRuntime()

n = 1 << 20        # 1M elements
# The final pass is launched with exactly this many threads, i.e.
# 256 / 32 = 8 simd groups in a single threadgroup.
threads_per_tg = 256

a = np.random.randn(n).astype(np.float32)

# Pass 1: threadgroup reductions.
# One dispatch shrinks the data by a factor of threads_per_tg.  The final pass
# runs in a single threadgroup, so keep applying pass 1 until at most
# threads_per_tg partial sums remain (1M -> 4096 -> 16 for the sizes above).
partial = a
while True:
    # The kernel loads one element per thread without a bounds check, so pad
    # the input with zeros up to a whole number of threadgroups.  Zeros do not
    # change the sum.
    padded_len = -(-partial.size // threads_per_tg) * threads_per_tg
    if padded_len != partial.size:
        tail = np.zeros(padded_len - partial.size, dtype=np.float32)
        partial = np.concatenate([partial, tail])

    n_threadgroups = padded_len // threads_per_tg
    partial_raw = rt.execute(
        compiled_pass1,
        inputs=[partial],
        output_size=n_threadgroups * 4,     # one float32 per threadgroup
        grid=(padded_len, 1, 1),            # one thread per element
        threads=(threads_per_tg, 1, 1),
    )
    partial = np.frombuffer(partial_raw, dtype=np.float32).copy()

    if partial.size <= threads_per_tg:
        break

# Pass 2: final reduction (single threadgroup).
# Always launch a full threads_per_tg-wide threadgroup and zero-fill the
# unused inputs, so every simd-group slot the kernel reads has been written.
final_in = np.zeros(threads_per_tg, dtype=np.float32)
final_in[:partial.size] = partial
result_raw = rt.execute(
    compiled_pass2,
    inputs=[final_in],
    output_size=4,
    grid=(threads_per_tg, 1, 1),
    threads=(threads_per_tg, 1, 1),
)
result = np.frombuffer(result_raw, dtype=np.float32)[0]

# Validate
expected = float(a.sum())
print(f"GPU: {result:.4f}  CPU: {expected:.4f}  err: {abs(result - expected):.4f}")

Notes

  • SIMD group size is hardware-fixed at 32. The threadgroup size should be a multiple of 32.
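  • Each level-1 dispatch requires an input length that is a multiple of threads_per_tg, and the final pass can only combine as many partial sums as fit in a single threadgroup; for larger N, repeat the level-1 pass until they do.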
  • For inputs whose length is not a multiple of threads_per_tg, use load_if with a bounds mask in the first pass.

See also