
Reproducing Warp Decode: What Happens When You Flip MoE Parallelism on Blackwell

Cursor published Warp Decode, its technique for small-batch MoE decode, last week. No code. No independent reproductions. Just a blog post describing a different way to parallelize expert dispatch on a GPU, which is what prompted this reproduction. The claims: 1.84x throughput on B200 with better numerical accuracy than the standard approach.

I built it from scratch on a DGX Spark GB10 (Blackwell, sm_121) and tested it on two models. This post covers the reproduction, the results, why it works better on some architectures than others, and the one thing I tried that failed.

What Warp Decode Does

Standard MoE inference organizes GPU work around experts. For each token, a router picks top-k experts. The implementation gathers all tokens assigned to expert 0, runs that expert's FFN, scatters results back, then moves to expert 1. Repeat for all active experts. This scatter-gather loop has overhead: Python dispatch, dynamic indexing, synchronization between expert batches.
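For reference, the expert-centric loop can be sketched in a few lines of NumPy. This is a simplified stand-in, not HuggingFace's actual code: the function name, the `[E, I, H]` weight layout, and the tanh-GELU gated FFN are assumptions chosen to resemble Gemma 4's MoE block.

```python
import numpy as np

def gelu_tanh(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def moe_expert_loop(x, topk_ids, topk_w, w_gate, w_up, w_down):
    """Expert-centric MoE forward: gather, run one expert, scatter, repeat.

    x:            [T, H] token hidden states
    topk_ids:     [T, K] experts chosen by the router for each token
    topk_w:       [T, K] routing weights
    w_gate, w_up: [E, I, H]; w_down: [E, H, I]  (hypothetical layout)
    """
    out = np.zeros_like(x)
    for e in np.unique(topk_ids):                 # one iteration per *active* expert
        tok, slot = np.nonzero(topk_ids == e)     # gather: which tokens picked expert e
        xe = x[tok]                               # [n_e, H]
        h = gelu_tanh(xe @ w_gate[e].T) * (xe @ w_up[e].T)   # [n_e, I]
        y = h @ w_down[e].T                       # [n_e, H]
        np.add.at(out, tok, topk_w[tok, slot][:, None] * y)  # scatter-add back
    return out
```

Every trip through that loop is a separate round of indexing and small matmuls, which is the per-expert overhead Warp Decode removes.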

Warp Decode flips the axis. Instead of "which tokens go to this expert?", each GPU thread asks "what's my output value?" Each thread (or warp) owns one output dimension and pulls from whatever experts it needs. No gather. No scatter. No synchronization between experts. Every thread is independent.
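The flipped question, "what's my output value?", can be written as a function that produces a single element of the output. A minimal NumPy sketch, assuming a plain gated FFN with SiLU (Phi-3.5-MoE style) and the same hypothetical weight layout as a standard MoE block; `output_value` and `moe_output_centric` are illustrative names. Visiting the outputs in a random order and still getting the right answer is the point: no output depends on another.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def output_value(t, d, x, topk_ids, topk_w, w_gate, w_up, w_down):
    """One "thread": compute out[t, d] by pulling from token t's routed experts.

    Reads shared inputs only and writes nothing shared, so any (t, d) pair
    can run in any order, or all of them at once.
    """
    acc = 0.0
    for k in range(topk_ids.shape[1]):
        e = topk_ids[t, k]
        h = silu(w_gate[e] @ x[t]) * (w_up[e] @ x[t])   # expert e's intermediate
        acc += topk_w[t, k] * (w_down[e, d] @ h)        # only row d of the down proj
    return acc

def moe_output_centric(x, topk_ids, topk_w, w_gate, w_up, w_down, seed=0):
    T, H = x.shape
    out = np.empty_like(x)
    # Shuffle the visiting order to show that no element depends on another.
    for flat in np.random.default_rng(seed).permutation(T * H):
        t, d = divmod(int(flat), H)
        out[t, d] = output_value(t, d, x, topk_ids, topk_w, w_gate, w_up, w_down)
    return out
```

This version recomputes each expert's intermediate activation for every output element, which is deliberately naive. The two-kernel split below avoids that by materializing the intermediate once; the fusion experiment later shows what that redundancy costs on real hardware.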

Concretely, I implemented two Triton kernels:

Kernel 1 (gate_up): Each program instance owns a block of intermediate neurons for one token-expert pair. It streams over the hidden dimension, accumulating gate and up dot products in FP32 registers, applies the gated activation (act(gate) * up), and writes one chunk of intermediate output.

Kernel 2 (down): Each program instance owns a block of output dimensions for one token. It loops through all top-k routed experts, loading intermediate activations and down-projection weights, folding routing weights into an FP32 accumulator. One write at the end.

The key property: no shared mutable state between programs. The GPU scheduler sees a flat namespace of independent work.
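To make the decomposition concrete, here is a NumPy emulation of the two kernels, with Python loops standing in for the GPU launch grid. It is a sketch under assumptions, not the Triton source: the block sizes are guesses (BLOCK_I = 64 matches the "once per 64 intermediate dimensions" figure quoted in the fusion section; BLOCK_H and BLOCK_D are arbitrary), and it uses SiLU where Gemma 4 would use GELU.

```python
import numpy as np

BLOCK_I = 64   # intermediate neurons per gate_up program (assumed)
BLOCK_H = 128  # hidden-dim chunk streamed per step (assumed)
BLOCK_D = 64   # output dims per down program (assumed)

def silu(x):
    return x / (1.0 + np.exp(-x))

def gate_up_program(t, k, i0, x, topk_ids, w_gate, w_up, inter):
    """One gate_up program: token t, routed slot k, neurons [i0, i0 + BLOCK_I)."""
    e = topk_ids[t, k]
    i1 = min(i0 + BLOCK_I, w_gate.shape[1])
    acc_g = np.zeros(i1 - i0, dtype=np.float32)   # FP32 "registers"
    acc_u = np.zeros(i1 - i0, dtype=np.float32)
    for h0 in range(0, x.shape[1], BLOCK_H):      # stream over the hidden dimension
        xs = x[t, h0:h0 + BLOCK_H].astype(np.float32)
        acc_g += w_gate[e, i0:i1, h0:h0 + BLOCK_H].astype(np.float32) @ xs
        acc_u += w_up[e, i0:i1, h0:h0 + BLOCK_H].astype(np.float32) @ xs
    inter[t, k, i0:i1] = (silu(acc_g) * acc_u).astype(inter.dtype)  # write one chunk

def down_program(t, d0, inter, topk_ids, topk_w, w_down, out):
    """One down program: token t, output dims [d0, d0 + BLOCK_D)."""
    d1 = min(d0 + BLOCK_D, w_down.shape[1])
    acc = np.zeros(d1 - d0, dtype=np.float32)
    for k in range(topk_ids.shape[1]):            # loop over all routed experts
        e = topk_ids[t, k]
        acc += np.float32(topk_w[t, k]) * (
            w_down[e, d0:d1].astype(np.float32) @ inter[t, k].astype(np.float32))
    out[t, d0:d1] = acc.astype(out.dtype)         # single write at the end

def warp_decode_moe(x, topk_ids, topk_w, w_gate, w_up, w_down):
    T, K = topk_ids.shape
    I, H = w_gate.shape[1], x.shape[1]
    inter = np.empty((T, K, I), dtype=x.dtype)    # the only intermediate buffer
    out = np.empty_like(x)
    # "Launch" 1: grid over (token, slot, intermediate block); programs are independent.
    for t in range(T):
        for k in range(K):
            for i0 in range(0, I, BLOCK_I):
                gate_up_program(t, k, i0, x, topk_ids, w_gate, w_up, inter)
    # "Launch" 2: grid over (token, output block); no shared mutable state.
    for t in range(T):
        for d0 in range(0, H, BLOCK_D):
            down_program(t, d0, inter, topk_ids, topk_w, w_down, out)
    return out
```

On a GPU, every iteration of those launch loops would be a separate program scheduled in any order; the only coupling between the two kernels is the `inter` buffer.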

Results: Gemma 4 26B-A4B

128 experts, top-8 routing, 704-dim expert FFN, GELU activation. Brand new model (one week old at the time of testing).

Layer-level benchmarks with error bars (200 iterations, fixed seeds):

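| Batch | HF Eager (ms) | Warp Decode (ms) | Speedup | Cosine Sim |
|---|---|---|---|---|
| 1 | 2.39 ± 0.93 | 0.54 ± 0.09 | 4.42x | 0.99999976 |
| 2 | 5.30 ± 1.82 | 1.01 ± 0.12 | 5.23x | 0.99999970 |
| 4 | 8.66 ± 3.04 | 1.88 ± 0.15 | 4.60x | 0.99999976 |
| 8 | 14.16 ± 4.29 | 3.01 ± 0.13 | 4.71x | 0.99999970 |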

Average 4.7x at the layer level. Notice the variance: HF eager has a coefficient of variation (CV, standard deviation divided by mean) of 0.30-0.39, because the Python per-expert loop is noisy; Warp Decode has CV 0.04-0.17 (Triton kernels are consistent).
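The CV ranges follow directly from the table. Recomputing them from the reported mean and standard deviation as a quick check:

```python
# (mean ms, std ms) pairs copied from the Gemma 4 layer-level table.
hf_eager = {1: (2.39, 0.93), 2: (5.30, 1.82), 4: (8.66, 3.04), 8: (14.16, 4.29)}
warp = {1: (0.54, 0.09), 2: (1.01, 0.12), 4: (1.88, 0.15), 8: (3.01, 0.13)}

def cv(mean, std):
    """Coefficient of variation: standard deviation relative to the mean."""
    return std / mean

for b in sorted(hf_eager):
    print(f"B={b}: HF eager CV={cv(*hf_eager[b]):.2f}, Warp Decode CV={cv(*warp[b]):.2f}")
```

The HF eager values land between 0.30 and 0.39 and the Warp Decode values between 0.04 and 0.17, matching the ranges quoted.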

Against HF's batched matmul (a stronger baseline that uses torch.bmm with gathered expert weights): 2.6-3.2x speedup.

I also tried torch.compile() on the eager path. It didn't help. Dynamic shapes in the per-expert loop defeat the compiler. At batch=8 it was actually slower than uncompiled.

End-to-end generation (full model, all 30 MoE layers patched, 64 tokens):

| Config | Tokens/sec |
|---|---|
| Stock HuggingFace (eager) | 11.89 |
| Warp Decode (30 layers patched) | 16.39 |
| Speedup | 1.38x |

The 4.7x layer-level speedup compresses to 1.38x end-to-end because MoE expert dispatch is only roughly 25% of the total forward pass on this model. Attention, dense MLP, embedding, and normalization are unchanged.
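Amdahl's law makes the compression explicit. Taking the 4.7x layer speedup at face value, a 25% share would cap the end-to-end gain at about 1.25x; inverting the law, the measured 1.38x corresponds to a share closer to 35%. The 25% figure is therefore best read as a rough estimate. The helper names below are illustrative.

```python
def amdahl(fraction, local_speedup):
    """Overall speedup when `fraction` of the runtime becomes `local_speedup` times faster."""
    return 1.0 / ((1.0 - fraction) + fraction / local_speedup)

def implied_fraction(overall, local_speedup):
    """Invert Amdahl's law: the runtime fraction that explains an observed overall speedup."""
    return (1.0 - 1.0 / overall) / (1.0 - 1.0 / local_speedup)

layer_speedup = 4.7
measured = 16.39 / 11.89   # end-to-end tokens/sec ratio from the table (~1.38x)

print(f"{amdahl(0.25, layer_speedup):.2f}x")               # -> 1.25x
print(f"{implied_fraction(measured, layer_speedup):.2f}")  # -> 0.35
```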

Results: Phi-3.5-MoE

16 experts, top-2 routing, 6400-dim expert FFN, SiLU activation. This model tests a very different architecture: fewer experts, lower top-k, much larger intermediate dimensions.

| Batch | HF Eager (ms) | Warp Decode (ms) | Speedup |
|---|---|---|---|
| 1 | 2.19 | 1.70 | 1.29x |
| 2 | 5.05 | 3.03 | 1.67x |
| 4 | 7.69 | 5.38 | 1.43x |
| 8 | 11.76 | 8.25 | 1.43x |

Smaller speedup. Top-2 routing means only 2 expert dispatches per token. The overhead Warp Decode eliminates is proportionally smaller. The technique still helps, but modestly.

The Fusion Experiment

Cursor describes keeping "the eight intermediate results" entirely in registers, never writing them to global memory. I tried this: a single fused kernel that computes gate_up intermediate values in registers and immediately uses them for the down projection. No intermediate buffer.

It was slower. On both models. By a lot.

| Model | Intermediate | Unfused (ms) | Fused (ms) | Fused Speed vs Unfused |
|---|---|---|---|---|
| Gemma 4 | 704-dim, 11 KB/token | 0.54 | 5.10 | 0.10x |
| Phi-3.5-MoE | 6400-dim, 25 KB/token | 1.70 | 11.22 | 0.15x |
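The KB/token figures are the size of the gate_up output that one token carries into the down projection: top-k × intermediate dim × bytes per element. Assuming 16-bit activations (2 bytes per element):

```python
def intermediate_bytes_per_token(top_k, intermediate_dim, bytes_per_elem=2):
    """Size of the gate_up output one token carries into the down projection.
    Assumes FP16/BF16 storage (2 bytes per element)."""
    return top_k * intermediate_dim * bytes_per_elem

for name, k, inter in [("Gemma 4", 8, 704), ("Phi-3.5-MoE", 2, 6400)]:
    b = intermediate_bytes_per_token(k, inter)
    print(f"{name}: {b} bytes = {b / 1024:.1f} KiB per token")
# -> Gemma 4: 11264 bytes = 11.0 KiB per token
# -> Phi-3.5-MoE: 25600 bytes = 25.0 KiB per token
```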

I tested with 6400-dim intermediates (Phi) specifically to see if larger buffers would tip the balance toward fusion. They didn't.

I initially thought this was a Triton limitation: Triton programs are thread blocks, not individual warps, so maybe the granularity was wrong. To test this, I wrote the same fused kernel in raw CUDA with explicit __shfl_xor_sync warp-level butterfly reductions. One warp per output dimension, full hardware control, compiled for sm_121.

The CUDA fused kernel was 700x slower than the unfused Triton approach.

The diagnosis: data reuse. The CUDA fused kernel loads hidden states once per intermediate dimension (5,632 times per output dimension across all experts). The Triton unfused kernel loads hidden states once per 64 intermediate dimensions (11 times per expert). That's 64x more redundant memory traffic. For Gemma 4's dimensions, this works out to ~90 GB of redundant reads at ~200 GB/s effective bandwidth, predicting ~450ms. The measured 387ms matches.
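The ~90 GB / ~450 ms estimate can be reproduced from the model shapes. The sketch assumes Gemma 4's hidden size of 2816 (the dot-product length mentioned in the V2 discussion), 16-bit hidden states, batch 1, one fused program per output dimension, and the ~200 GB/s effective bandwidth quoted above. It ignores caching, so it is an order-of-magnitude estimate rather than a model of the kernel.

```python
HIDDEN = 2816   # Gemma 4 hidden size = number of output dims of the down projection
INTER = 704     # expert intermediate dim
TOP_K = 8
BYTES = 2       # assumed FP16/BF16 hidden states
BW = 200e9      # ~200 GB/s effective bandwidth, as quoted in the text

loads_per_output = INTER * TOP_K                     # fused V1: one hidden-state load per intermediate dim
bytes_per_load = HIDDEN * BYTES                      # one full hidden-state vector
total = HIDDEN * loads_per_output * bytes_per_load   # repeated for every output dim

print(loads_per_output, f"{total / 1e9:.1f} GB", f"{total / BW * 1e3:.0f} ms")
# -> 5632 89.3 GB 447 ms
```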

The fix would be shared memory tiling: load hidden states once per block, reuse across intermediate dimensions. But that's what Triton's tile-based programming model already does automatically. You'd be reimplementing Triton's data reuse in CUDA, then adding fusion on top.

I then wrote a V2 with shared memory tiling: hidden states loaded once per block, reused across intermediate dimensions. Still 900ms. The hidden state reuse was fixed, but each thread still serializes over 704 intermediate dimensions, computing full dot products of length 2816. The inner loop has no parallelism across intermediate dims.

Triton's unfused approach splits intermediate dimensions across 64 threads working simultaneously. Getting that same parallelism inside a fused CUDA kernel while maintaining the register-local intermediate values requires multi-level tiling: shared memory for hidden states and weight tiles, cooperative thread groups for the dot products, register-level accumulation for the intermediate activations. That's serious kernel engineering, not a weekend experiment.

I iterated through 9 CUDA kernel versions, systematically fixing each bottleneck. The journey from V1 (387ms, 700x slower) to V9 (0.99ms, 1.85x slower) was instructive:

  • 280x came from getting the algorithm right: match Triton's parallelism decomposition (V1 to V6)
  • 1.4x came from vectorized loads: half2, float4, __ldg cache hints (V6 to V9)
  • 1.85x remains: Triton's compiler generates better instruction scheduling and register allocation
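The stage-by-stage gains multiply out to the overall V1-to-V9 improvement. Checking against the quoted latencies (recomputing the remaining gap from rounded values gives about 1.83x, consistent with the quoted 1.85x):

```python
v1_ms, v9_ms, triton_ms = 387.0, 0.99, 0.54   # V1, V9 and unfused Triton latencies from the text
algorithm, vectorization = 280.0, 1.4         # per-stage gains quoted in the list

total_gain = v1_ms / v9_ms                    # end-to-end improvement V1 -> V9
print(f"V1->V9: {total_gain:.0f}x  (280 x 1.4 = {algorithm * vectorization:.0f}x)")
print(f"V9 vs Triton: {v9_ms / triton_ms:.2f}x slower")
```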

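One final finding closed the loop. NVIDIA's wmma tensor core API only supports fixed fragment shapes (for FP16: m16n16k16, m32n8k16, or m8n32k16), so no dimension can be smaller than 8. At batch=1 decode, the computation is matrix-vector (N=1): using tensor cores would mean padding N up to a full fragment and wasting most of the work, so tensor cores can't help. Triton's tl.dot uses vectorized FMA for this shape, not tensor cores. The remaining gap is compiler quality, not hardware utilization.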

That's the pragmatic answer: for small-batch decode, keep the unfused Triton kernels.
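To put a number on the padding cost: CUDA's wmma API only exposes a few fixed fragment shapes, and for FP16 they are m16n16k16, m32n8k16 and m8n32k16. The sketch below counts the fraction of multiply-accumulate slots doing useful work after padding a GEMM to whole fragments. It is an idealized count that ignores memory bandwidth entirely, and it uses one Gemma 4 expert's down projection as an example shape; `best_utilization` is a hypothetical helper.

```python
# FP16 wmma fragment shapes (M, N, K) exposed by CUDA's nvcuda::wmma API.
WMMA_FP16_SHAPES = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]

def ceil_to(n, multiple):
    """Round n up to the next multiple of `multiple`."""
    return -(-n // multiple) * multiple

def best_utilization(m, n, k, shapes=WMMA_FP16_SHAPES):
    """Useful fraction of tensor-core work after padding (m, n, k) to whole
    fragments, for the best available fragment shape."""
    return max(
        (m * n * k) / (ceil_to(m, fm) * ceil_to(n, fn) * ceil_to(k, fk))
        for fm, fn, fk in shapes
    )

# One expert's down projection: [2816 x 704] weights times a batch of N vectors.
print(f"{best_utilization(2816, 1, 704):.3f}")    # -> 0.125 (N=1 padded to 8)
print(f"{best_utilization(2816, 16, 704):.3f}")   # -> 1.000 at batch 16
```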

What Determines Whether Warp Decode Helps

It's the top-k routing count relative to expert count.

| Model | Experts | Top-k | Top-k : Experts | Layer Speedup |
|---|---|---|---|---|
| Gemma 4 | 128 | 8 | 1:16 | 4.7x |
| Phi-3.5-MoE | 16 | 2 | 1:8 | 1.4x |

High top-k with many experts means more dispatch overhead per token. Warp Decode eliminates that overhead. Low top-k with few experts means the dispatch is already cheap. Less room to improve.

Intermediate dimension doesn't drive it. Model size doesn't drive it. Routing sparsity does.
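One rough proxy for dispatch overhead is the number of distinct experts a batch touches, since the expert-centric loop does one round of gather, compute, and scatter per active expert, while Warp Decode launches its two kernels regardless. A sketch assuming each token picks a uniformly random set of top-k experts, independently of other tokens (real routers are skewed, which would generally lower these counts):

```python
def expected_active_experts(num_experts, top_k, batch):
    """Expected number of distinct experts hit by a batch, assuming each token
    picks a uniformly random set of top_k experts, independently of the others."""
    p_untouched = (1.0 - top_k / num_experts) ** batch   # an expert no token picked
    return num_experts * (1.0 - p_untouched)

for name, e, k in [("Gemma 4", 128, 8), ("Phi-3.5-MoE", 16, 2)]:
    row = ", ".join(f"B={b}: {expected_active_experts(e, k, b):.1f}" for b in (1, 2, 4, 8))
    print(f"{name}: {row}")
```

At batch 8 this gives roughly 52 loop iterations per layer for Gemma 4 versus about 10 for Phi-3.5-MoE under the uniform assumption.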

The Industry Direction

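| Model | Year | Experts | Intermediate | Top-k |
|---|---|---|---|---|
| Mixtral 8x7B | 2023 | 8 | 14,336 | 2 |
| DeepSeek-V3 | 2024 | 256 | 2,048 | 8 |
| Gemma 4 26B-A4B | 2026 | 128 | 704 | 8 |
| Qwen3.5-35B | 2026 | 256 | 512 | 8 |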

The trend: many small experts, high top-k, small intermediate dimensions. Both 2026 models in the table follow this pattern. Fine-grained routing with more experts gives better specialization and load balancing.

This is exactly the architecture profile where Warp Decode wins. The technique isn't a novelty for one model generation. It's aligned with where MoE design is heading.

Where These Batch Sizes Actually Matter

Before the vLLM comparison, some context on what batch sizes show up in production. This determines whether our results are relevant or academic.

Decode dominates interactive sessions. For a typical chat interaction, prefill (processing the prompt) takes 10-20% of wall-clock time, and decode (generating each output token sequentially) takes 80-90%. Every decode step is memory-bandwidth-bound: it streams the weights it touches (for an MoE, the dense layers plus only the routed experts) from GPU memory. That's where MoE dispatch overhead lives, and where Warp Decode operates.

Production batch sizes by workload:

| Workload | Batch Size | Time Pressure | Our Speedup |
|---|---|---|---|
| Code completion (Cursor) | 1-4 | Sub-100ms latency | 1.08-1.18x vs vLLM kernel |
| Interactive chat | 8-16 | TTFT-sensitive | 1.08-1.15x vs vLLM kernel |
| Multi-tenant serving | 32-64 | Throughput | vLLM kernel wins |
| Batch/offline | 128-2048 | Cost per token | vLLM kernel wins decisively |

Sources: vLLM production configs typically set --max-num-seqs 256 but effective decode batch per GPU is 32-64. Google TPU Gemini serving uses batch 30 for decode. Anthropic offers a "low-batch-size" fast mode at 6x cost premium, which tells you the latency-sensitive market is real.

The split isn't even. Roughly 15-25% of production MoE workloads run at B=1-32 (latency-sensitive). The rest optimize for throughput. But that 15-25% is where users directly feel response time, and it's where the willingness to pay for speed is highest.

vs vLLM's Triton Kernel (Fair Comparison)

I extracted vLLM 0.19.0's actual fused_moe_kernel (the @triton.jit function, copied verbatim) and benchmarked kernel-to-kernel, excluding preprocessing. First attempt included my Python reimplementation of moe_align_block_size (4-9ms overhead), which made vLLM look terrible. That was wrong. Isolating just kernel execution time gives the honest picture:

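| Batch | Use Case | vLLM kernel (ms) | Warp Decode (ms) | Winner |
|---|---|---|---|---|
| 1 | Code completion | 0.40 | 0.37 | Warp (1.08x) |
| 4 | Small chat | 1.53 | 1.30 | Warp (1.18x) |
| 8 | Chat serving | 2.29 | 1.99 | Warp (1.15x) |
| 16 | Chat serving | 3.65 | 3.39 | Warp (1.08x) |
| 32 | High concurrency | 4.55 | 5.44 | vLLM (0.84x) |
| 64 | Batch inference | 5.58 | 9.57 | vLLM (0.58x) |
| 128 | Prefill | 5.57 | 17.04 | vLLM (0.33x) |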

The crossover falls between batch 16 and 32 (roughly batch 20-24). Warp Decode wins the decode path (B=1-16, which covers code completion and interactive chat), and vLLM's blocked matmul wins the throughput path (B=32+, which covers batch inference and prefill).
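The table brackets the crossover between B=16 and B=32 without pinning it down. Linearly interpolating the latency gap between those two measured points lands near batch 20. The true crossover depends on batch sizes that weren't sampled, so treat this as approximate; `crossover` is an illustrative helper.

```python
batch = [1, 4, 8, 16, 32, 64, 128]
vllm_ms = [0.40, 1.53, 2.29, 3.65, 4.55, 5.58, 5.57]
warp_ms = [0.37, 1.30, 1.99, 3.39, 5.44, 9.57, 17.04]

def crossover(xs, a, b):
    """Linearly interpolate the first x where b - a turns positive
    (here: where Warp Decode stops being faster than the vLLM kernel)."""
    for i in range(len(xs) - 1):
        d0, d1 = b[i] - a[i], b[i + 1] - a[i + 1]
        if d0 <= 0 < d1:
            return xs[i] + (xs[i + 1] - xs[i]) * (-d0) / (d1 - d0)
    return None

print(f"crossover ~ batch {crossover(batch, vllm_ms, warp_ms):.0f}")   # -> 20
```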

This matches Cursor's original framing: "Warp Decode optimizes exclusively for small-batch decode scenarios." It also explains why they built it. Cursor runs code completion (batch 1-4). That's exactly where this technique wins.

For the full MoE forward pass (including Python dispatch overhead), Warp Decode wins at every batch size tested up to 128, with 1.9-5.1x speedup over HF eager. But the kernel-to-kernel comparison is the honest one.
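The fix for the first, misleading comparison was to build each kernel's inputs once, outside the timed region, and time only kernel execution. A minimal harness along those lines, defaulting to the 200 iterations used for the layer benchmarks; `time_kernel_only` is a hypothetical helper, not part of vLLM or Triton. On a GPU, `sync` must actually wait for the device (for example `torch.cuda.synchronize`), otherwise the timer only measures launch overhead.

```python
import statistics
import time

def time_kernel_only(kernel, prepare, iters=200, warmup=10, sync=lambda: None):
    """Time `kernel(*args)` in isolation and return (mean_ms, std_ms).

    `prepare()` builds the kernel's inputs (routing metadata, block alignment,
    weight layout) once, outside the timed region. `sync` should block until
    the device is idle; the default no-op is only correct for synchronous CPU code.
    """
    args = prepare()
    for _ in range(warmup):
        kernel(*args)
    sync()
    samples_ms = []
    for _ in range(iters):
        t0 = time.perf_counter()
        kernel(*args)
        sync()
        samples_ms.append((time.perf_counter() - t0) * 1e3)
    return statistics.fmean(samples_ms), statistics.stdev(samples_ms)

if __name__ == "__main__":
    mean, std = time_kernel_only(lambda xs: sum(v * v for v in xs),
                                 lambda: (list(range(10_000)),), iters=20)
    print(f"{mean:.3f} ± {std:.3f} ms")
```

Keeping preprocessing out of the loop is a deliberate choice for a kernel-to-kernel comparison; in a real serving stack that work still has to happen somewhere.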

What Would Change for Production

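Triton to CUDA. We used Triton for rapid prototyping. Some production inference engines (TensorRT-LLM in particular) favor compiled CUDA kernels, although vLLM's fused_moe_kernel benchmarked above is itself written in Triton. The algorithm translates directly. The warp-level fusion that Cursor describes maps more naturally to CUDA than to Triton, but in the fusion experiment above even the best CUDA fused kernel (V9) still trailed the unfused Triton path by 1.85x.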

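Batch size scaling. The layer-level and end-to-end benchmarks cover batch 1-8 (single-token decode); only the kernel-to-kernel vLLM comparison extends to 128. Serving-scale batches need full end-to-end benchmarks, and the parallelism story changes at 32+ tokens, where vLLM's blocked matmul already wins.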

Quantization interaction. INT4/FP8 expert weights might shift the compute-vs-bandwidth balance. Worth testing.

Full framework integration. We monkey-patched HuggingFace's expert dispatch. A proper integration handles the router, recombination, and memory management together.

Reproducibility

38 correctness tests. Cosine similarity 0.99999976. Max absolute error 3.91e-03. Two models tested. Error bars on all measurements. Fixed random seeds. Code available.
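The two correctness metrics can be computed as follows. This is a sketch: whether the tests compute cosine similarity per token or over the whole flattened output isn't stated, and this version flattens. Reporting max absolute error alongside matters because a cosine similarity near 1.0 can hide a few large element-wise errors.

```python
import numpy as np

def compare_outputs(ref, test):
    """Cosine similarity and max absolute error between a reference and a
    candidate output, computed in float64 over the flattened tensors."""
    r = np.asarray(ref, dtype=np.float64).ravel()
    t = np.asarray(test, dtype=np.float64).ravel()
    cos = float(r @ t / (np.linalg.norm(r) * np.linalg.norm(t)))
    max_abs = float(np.max(np.abs(r - t)))
    return cos, max_abs

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    ref = rng.standard_normal((8, 2816)).astype(np.float32)
    # Rounding to FP16 as a stand-in for a lower-precision candidate output.
    print(compare_outputs(ref, ref.astype(np.float16)))
```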

What we don't have: comparison against Cursor's original implementation (no code available), testing on B200 hardware, or integration into a serving framework.

The reproduction validates the core idea and maps the architectural conditions where it helps. It doesn't claim to have matched or beaten the original authors' numbers on their hardware.


Footnote: The Bandwidth Wall

After completing the dispatch optimization work, I ran an automated research loop (52 experiments, two parallel Claude-in-the-loop optimization agents iterating for ~45 minutes) to see what else could be squeezed out.

Both agents independently converged on the same finding: these kernels are bandwidth-bound, not compute-bound. The expert weights dominate memory traffic. Halving precision to INT8 (per-row absmax symmetric quantization) gives a clean 2x speedup with cosine similarity 0.9999. INT4 with group quantization reaches 2.5x but accuracy drops to 0.991.

| Precision | B=1 Layer Latency (ms) | Speedup | Accuracy (cosine sim) |
|---|---|---|---|
| FP16 (Warp Decode) | 0.57 | 1.0x | 1.0000 |
| INT8 | 0.28 | 2.01x | 0.9999 |
| INT4 g128 + INT8 mixed | 0.21 | 2.58x | 0.9913 |
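A minimal NumPy sketch of the two weight-quantization schemes named above: per-row absmax symmetric INT8, and INT4 with one scale per group of 128 input columns (g128). It only shows the numerics (quantize, dequantize, check cosine similarity on a matvec), not an optimized kernel. The symmetric [-7, 7] INT4 range is an assumption, and which tensors got INT8 versus INT4 in the "mixed" row isn't specified, so that configuration isn't reproduced here.

```python
import numpy as np

def quantize_int8_per_row(w):
    """Symmetric per-row absmax INT8: one scale per output row."""
    scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)          # guard all-zero rows
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def quantize_int4_grouped(w, group=128):
    """Symmetric INT4 with one scale per `group` consecutive input columns.
    Values sit in an int8 container; real kernels pack two per byte."""
    rows, cols = w.shape
    assert cols % group == 0, "sketch assumes cols divisible by group"
    g = w.reshape(rows, cols // group, group)
    scale = np.abs(g).max(axis=2, keepdims=True) / 7.0
    scale = np.where(scale == 0, 1.0, scale)
    q = np.clip(np.round(g / scale), -7, 7).astype(np.int8)
    return q, scale.astype(np.float32)

def dequant_int8(q, scale):
    return q.astype(np.float32) * scale

def dequant_int4(q, scale):
    return (q.astype(np.float32) * scale).reshape(q.shape[0], -1)

def cosine(a, b):
    a, b = a.ravel().astype(np.float64), b.ravel().astype(np.float64)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w = rng.standard_normal((256, 1024)).astype(np.float32)
    x = rng.standard_normal(1024).astype(np.float32)
    y = w @ x
    print("int8:", cosine(y, dequant_int8(*quantize_int8_per_row(w)) @ x))
    print("int4 g128:", cosine(y, dequant_int4(*quantize_int4_grouped(w)) @ x))
```

Quantization halves or quarters the bytes streamed per expert, which is why it helps a bandwidth-bound kernel independently of how dispatch is organized.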

This isn't novel. vLLM and TensorRT-LLM already support INT8-quantized MoE expert weights. The interesting part is the interaction: Warp Decode's dispatch optimization and weight quantization are complementary. The dispatch pattern reduces overhead from expert routing. Quantization reduces bandwidth from weight loading. They target different bottlenecks and compound.

For practitioners: if your MoE decode is slow, quantize the expert weights first (2x with no engineering). Then look at dispatch patterns (1.1-1.2x with custom kernels). Both help. Neither obsoletes the other.


Related reading on this site: Finding the Right Stack on a DGX Spark for the Blackwell GB10 serving setup this ran on, the Gemma 4 diffusion-drafter build log for more low-level work on the same model family, and AutoResearch on Blackwell: 151 Experiments Overnight for the automated optimization loop behind the bandwidth-wall footnote above.
