# Teaching a Small Model to Draft Like DFlash: A Build Log on Gemma 4 E2B as a Diffusion Drafter
> [!tip] TLDR
> Adapt Gemma 4 E2B from a causal LM into a masked-diffusion drafter via LoRA, with bidirectional attention inside the denoising block. Bidirectional attention is what carries the gains, not longer training.
>
> Reproducible setup prompt for Claude Code or any coding assistant: *"On a single GPU box, load `google/gemma-4-E2B` via `Gemma4ForConditionalGeneration`, add a `<|diffusion_mask|>` special token, attach LoRA (r=16, α=32) to q/k/v/o + MLP projections under `model.language_model.layers.*`. Train masked-token cross-entropy at masked positions only (don't pass labels to the HF forward, compute loss yourself). Compare two runs on FineWeb-Edu, 500 steps each: causal-only attention vs a 4D float attention mask that's bidirectional inside a denoising block and causal on the prefix. Build the mask in `next(model.parameters()).dtype` to avoid an SDPA dtype error. Eval reconstruction at 15/25/50% mask rates and generation under both confidence and cosine reveal schedules."*
>
> Read on for what worked, what failed, and what the artifacts told me.
---
## Why I'm doing this
Last week z-lab dropped **Qwen3.6-27B-DFlash**: Qwen3 with diffusion-based speculative decoding. The pitch is simple. Speculative decoding speeds up LLMs by having a small model draft tokens and a big model verify them in parallel. DFlash replaces the typical autoregressive drafter with a diffusion process. Cleaner draft generation, better acceptance rates, faster end-to-end inference.
The 27B-parameter version is a research artifact. I wanted to know: **can the same idea work on a small model I can train in an afternoon?**
This is the build log. I wrote it as the experiments ran, and the lessons accumulated as the data came in.
---
## The setup
One DGX Spark. Gemma-4-E2B (E2B = ~2B effective params, the lighter Gemma-4 sibling). LoRA adapters, r=16, ~24M trainable params on top of 5B frozen. FineWeb-Edu as the corpus. 500 training steps to start. Two configurations:
- **Level 1**: causal-only masked denoising. Baseline. Mask 25% of tokens with span masking, predict them, only attend leftward.
- **Level 2**: blockwise bidirectional 4D attention. Same data, same budget, same hyperparameters. Only difference: the denoising block can see both directions.
The two configurations test one specific thing. **Does bidirectional context during denoising actually matter, or do we just need more training?**
If L2 doesn't beat L1, then the bidirectional-attention story behind DFlash doesn't hold up, at least at this scale, and the speedup has to be coming from somewhere else. If L2 wins, we know the attention bias is doing real work.
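The only thing pinned down about Level 1's span masking is its 25% rate. Below is a minimal sketch of one way to do it. The geometric span-length distribution (mean 3) and the helper name `span_mask` are my assumptions for illustration, not the exact recipe used in these runs.

```python
from typing import Optional, Tuple

import torch


def span_mask(input_ids: torch.Tensor, mask_token_id: int, mask_rate: float = 0.25,
              mean_span: float = 3.0,
              generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Mask round(mask_rate * L) positions of a 1D sequence in contiguous spans.

    Returns (corrupted_ids, loss_mask); loss_mask is True where a token was masked.
    Span lengths are geometric with the given mean (an assumption).
    """
    seq_len = input_ids.shape[0]
    target = min(seq_len, max(1, round(mask_rate * seq_len)))
    loss_mask = torch.zeros(seq_len, dtype=torch.bool)
    p = 1.0 / mean_span
    while int(loss_mask.sum()) < target:
        # Geometric length >= 1: keep extending while a Bernoulli(p) trial fails.
        span = 1
        while torch.rand(1, generator=generator).item() > p:
            span += 1
        # Never overshoot the target; overlapping spans just add fewer new positions.
        span = min(span, target - int(loss_mask.sum()))
        start = int(torch.randint(0, seq_len - span + 1, (1,), generator=generator))
        loss_mask[start:start + span] = True
    corrupted = input_ids.clone()
    corrupted[loss_mask] = mask_token_id
    return corrupted, loss_mask
```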
---
## What I had to figure out before training
### 1. Does Gemma-4 even accept a 4D float attention mask?
This is the question that decides whether the experiment is possible at all. Most causal LMs build their attention mask internally and don't take a custom one. If Gemma-4 ignores my bidirectional mask, I'm back to subclassing attention modules (a much bigger project).
I wrote a probe script. Pass two custom 4D masks (one fully bidirectional, one blockwise causal-then-bidirectional) and see whether the model uses them. Both passed. Gemma-4 accepts custom 4D masks via SDPA without monkeypatching.
**Lesson:** Always probe the model's actual interface before building on assumptions. This probably saved me two days.
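The probe itself is small. Below is a minimal sketch of the idea: run the same input under two boolean "may attend" masks and check which positions' outputs change. A toy single-head SDPA layer stands in for Gemma so the check runs anywhere. Against the real model you would pass the masks through `attention_mask=` and compare logits instead. `block_mask`, `toy_forward`, and `mask_changes_output` are hypothetical helpers.

```python
import torch
import torch.nn.functional as F


def causal_mask(seq_len: int) -> torch.Tensor:
    """True where query i may attend to key j (j <= i)."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))


def block_mask(prefix_len: int, block_len: int) -> torch.Tensor:
    """Causal over the prefix, bidirectional inside the trailing denoising block.

    Block positions see the whole prefix and every other block position;
    prefix positions never see the block.
    """
    allowed = causal_mask(prefix_len + block_len)
    allowed[prefix_len:, prefix_len:] = True
    return allowed


def toy_forward(x: torch.Tensor, allowed: torch.Tensor) -> torch.Tensor:
    # Stand-in for the model: one attention head with q = k = v = x.
    # For a boolean attn_mask, SDPA treats True as "takes part in attention".
    q = x[None, None]  # [1, 1, L, D]
    return F.scaled_dot_product_attention(q, q, q, attn_mask=allowed[None, None])[0, 0]


def mask_changes_output(forward, x, mask_a, mask_b, atol: float = 1e-5) -> torch.Tensor:
    """Per-position flag: did swapping mask_a for mask_b change that position's output?"""
    return ~torch.isclose(forward(x, mask_a), forward(x, mask_b), atol=atol).all(dim=-1)
```

If the blockwise mask is honoured, prefix positions stay unchanged and block positions change. The exception is the last block position, which stays unchanged because it already saw everything under the causal mask. A model that silently ignored the custom mask would show no changes anywhere.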
### 2. The dtype trap
First training run died on a dtype mismatch. The model runs in bf16. My masks were fp32. SDPA refused them.
Fix: build the mask in `next(model.parameters()).dtype`. One-liner. But the error message is buried in C++ stack traces, so it took an embarrassing amount of time to diagnose.
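Here is the fix in context, as a minimal sketch. I'm assuming the model takes an additive float mask shaped `[batch, 1, query, key]`, with 0.0 where attention is allowed and the dtype's most negative finite value where it is blocked. That matches the convention Hugging Face's own mask helpers use. `to_additive_mask` is a hypothetical helper.

```python
import torch


def to_additive_mask(allowed: torch.Tensor, dtype: torch.dtype,
                     batch_size: int = 1) -> torch.Tensor:
    """Turn a boolean [L, L] "may attend" matrix into an additive [B, 1, L, L] mask.

    0.0 where attention is allowed, finfo(dtype).min where it is blocked.
    Pass dtype=next(model.parameters()).dtype so the mask matches the bf16 weights.
    """
    blocked_value = torch.finfo(dtype).min
    mask = torch.zeros(allowed.shape, dtype=dtype, device=allowed.device)
    mask = mask.masked_fill(~allowed, blocked_value)
    # [L, L] -> [B, 1, L, L]: one mask shared across heads (and here across the batch).
    return mask[None, None].expand(batch_size, 1, *allowed.shape)
```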
### 3. The multimodal weight-loading surprise
Loading `Gemma4ForCausalLM` showed a flood of MISSING weights. Turns out Gemma-4's checkpoint is structured for `Gemma4ForConditionalGeneration` (multimodal), with weights under `model.language_model.layers.*` instead of `model.layers.*`. Loading via the conditional class fixed it.
Related gotcha: a naive LoRA target spec like `["q_proj", "k_proj", ...]` would attach LoRA to `vision_tower` and `audio_tower` projections too. I had to enumerate the full module paths under `model.language_model.layers.*` so the LoRA stays in the language model.
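One way to do the enumeration, as a minimal sketch: filter `model.named_modules()` down to full paths under the language model. The per-layer projection names (`self_attn.q_proj` through `mlp.down_proj`) are an assumption based on earlier Gemma layouts, so check them against the real checkpoint before trusting the regex.

```python
import re
from typing import Iterable, List

# Assumed projection names inside each decoder layer; verify with model.named_modules().
_LM_PROJ = re.compile(
    r"^model\.language_model\.layers\.\d+\."
    r"(self_attn\.(q_proj|k_proj|v_proj|o_proj)|mlp\.(gate_proj|up_proj|down_proj))$"
)


def language_model_lora_targets(module_names: Iterable[str]) -> List[str]:
    """Keep only full module paths under model.language_model.layers.*"""
    return [name for name in module_names if _LM_PROJ.match(name)]
```

Usage would look like `LoraConfig(r=16, lora_alpha=32, target_modules=language_model_lora_targets(n for n, _ in model.named_modules()))`. PEFT also accepts a single regex string for `target_modules`, matched against the full module name, and that works equally well.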
### 4. The masked CE loss had to be recomputed
If you pass `labels=` to a Hugging Face causal LM, it computes shifted next-token loss. That's not what we want for masked denoising. I needed cross-entropy at *masked positions*, not over a left-shifted sequence.
Fix: don't pass labels. Compute the loss yourself over the `loss_mask`.
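Below is a minimal sketch of that loss. It assumes logits come from a forward pass on the corrupted ids with no `labels=` argument, and that `loss_mask` marks the positions replaced by the mask token. `masked_denoising_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def masked_denoising_loss(logits: torch.Tensor, target_ids: torch.Tensor,
                          loss_mask: torch.Tensor) -> torch.Tensor:
    """Cross-entropy at masked positions only, with no left shift.

    logits:     [B, L, V] from model(input_ids=corrupted, attention_mask=...)
    target_ids: [B, L] original, uncorrupted tokens
    loss_mask:  [B, L] True where the input was replaced by the mask token
    """
    # Position i's logits are scored against token i itself, unlike HF's
    # causal-LM loss, which scores logits[:, i] against labels[:, i + 1].
    selected_logits = logits[loss_mask].float()  # [N, V]; upcast from bf16 for the softmax
    selected_targets = target_ids[loss_mask]     # [N]
    return F.cross_entropy(selected_logits, selected_targets)
```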
These four pre-flight items took a morning. Without them, the experiment doesn't even start. None of them were covered in any tutorial I found.
---
## What L1 told me
L1 trained for 500 steps, 90 minutes. Final loss 4.14, masked accuracy 28.8%.
Reconstruction held up reasonably:
| Mask rate | Accuracy | NLL |
|-----------|----------|-----|
| 15% | 44.2% | 2.77 |
| 25% | 35.3% | 3.42 |
| 50% | 22.2% | 4.66 |
Generation was the surprise. I gave it four prompts (a narrative, a code task, a science infill, a domain question). It recovered the topic on every one. K-Pg extinction prompt? It came back with Yucatan, K-Pg boundary, impact crater, dinosaurs. All correct, all there.
Then it produced this:
> the impact of asteroid that caused the the the Yucatan peninsula, Mexico, about 6. million the extinction of the dinosaurs., which the impact of dinosaurs, the impact crater the extinction and the the Cretaceous.
The semantic recall is real. The syntax is destroyed. Repeated articles. Missing predicates. Broken parentheticals. Every prompt produced the same pattern.
The "the the the" artifact is interesting. It's not random. It's the confidence sampler's failure mode: when the model is uncertain about most tokens but very confident articles belong somewhere, it greedily reveals every article position simultaneously. The result is determiner soup.
This is exactly what you'd predict if the causal constraint is the bottleneck. Article choice depends on the noun that follows (a vs an, the vs a). With only left context, the model can't disambiguate, so it defaults to "the" everywhere and fills neighboring high-confidence slots with the same fallback.
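Here is the reveal rule I mean, as a minimal sketch of one step of confidence-ordered unmasking. I'm assuming "confidence" is the top softmax probability at each still-masked position, with a fixed number `k` of reveals per step. `confidence_reveal_step` is a hypothetical name, not the exact sampler from these runs. Nothing in the ranking looks at position or at what the neighbours will become, so every slot where "the" is a safe bet can win the same step.

```python
from typing import Optional, Tuple

import torch


def confidence_reveal_step(logits: torch.Tensor, is_masked: torch.Tensor, k: int,
                           temperature: float = 1.0, sample: bool = True,
                           generator: Optional[torch.Generator] = None
                           ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the k still-masked positions with the highest max-probability and fill them.

    logits:    [L, V] for one sequence
    is_masked: [L] bool, True where the token is still the mask token
    Returns (positions, tokens) to write back into the sequence.
    """
    probs = torch.softmax(logits.float() / temperature, dim=-1)
    confidence = probs.max(dim=-1).values
    confidence = confidence.masked_fill(~is_masked, -1.0)  # never re-pick revealed slots
    k = min(k, int(is_masked.sum()))
    positions = confidence.topk(k).indices
    if sample:
        tokens = torch.multinomial(probs[positions], 1, generator=generator).squeeze(-1)
    else:
        tokens = probs[positions].argmax(dim=-1)
    return positions, tokens
```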
---
## What L2 told me
L2 trained for 500 steps, 89 minutes. Same wall-clock as L1 to within a minute.
Final training metrics:
| Metric | L1 | L2 |
|--------|------|------|
| Wall time | 90.4 min | 89.1 min |
| Final loss | 4.14 | **3.19** |
| Final accuracy | 28.8% | **39.5%** |
| Best accuracy | 33.4% (s320) | **43.1% (s440)** |
The training curves diverged from step one. L2's starting loss was 7.07 vs L1's 9.09. Same model, same init, same data, same hyperparameters. The only difference is the attention mask. That gap closes briefly around step 50, then widens for the rest of training. By step 100, L2 is at 28% accuracy while L1 is still at 20%. By step 500, L2 is at 39.5% while L1 finished at 28.8%.
Then I ran the reconstruction eval. This is where it stopped being interesting and started being decisive.
| Mask rate | L1 accuracy | L2 accuracy | Absolute Δ (pts) | Relative Δ |
|-----------|-------------|-------------|------------------|------------|
| 15% | 44.2% | **65.9%** | +21.8 | +49.3% |
| 25% | 35.3% | **54.8%** | +19.4 | +55.0% |
| 50% | 22.2% | **34.9%** | +12.6 | +56.8% |
NLL drops by 22-47% (2.77 → 1.48 at 15% masking, 3.42 → 2.16 at 25%, 4.66 → 3.62 at 50%). Span recovery roughly doubles. The relative accuracy improvement is *larger* at 50% masking than at 15%, which matches the prediction: bidirectional context is most valuable when the left side alone is uninformative.
Read the table again. **L2 at 50% masking (34.9%) is roughly the same as L1 at 25% masking (35.3%).** Bidirectional attention buys you an entire mask-rate jump: L2 recovers tokens from a half-masked sequence about as accurately as L1 does from a quarter-masked one.
This is bigger than I expected. I went in thinking +5 points of accuracy would validate the hypothesis. Getting +13 to +22 points is a different story.
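Here is the NLL side of the comparison as plain arithmetic. It uses the L1 numbers from the reconstruction table and the L2-500 NLL values reported alongside the 2000-step results (1.48 / 2.16 / 3.62). Unlike accuracy, the relative NLL gain shrinks as the mask rate rises.

```python
# Reported reconstruction NLL per mask rate (lower is better).
l1_nll = {0.15: 2.77, 0.25: 3.42, 0.50: 4.66}
l2_nll = {0.15: 1.48, 0.25: 2.16, 0.50: 3.62}


def relative_drop(before: float, after: float) -> float:
    """Fractional reduction going from `before` to `after`."""
    return (before - after) / before


drops = {rate: round(100 * relative_drop(l1_nll[rate], l2_nll[rate]), 1) for rate in l1_nll}
print(drops)  # → {0.15: 46.6, 0.25: 36.8, 0.5: 22.3}
```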
---
## Why this matters for drafting
The reason reconstruction accuracy is the right metric for a drafter isn't obvious. Speculative decoding works like this: the drafter proposes K tokens, the verifier checks them, accepted tokens stay, rejected tokens are replaced. The acceptance rate is what determines the speedup.
**Higher reconstruction accuracy should translate into a higher acceptance rate.** A drafter that gets 65% of masked tokens right is going to get more candidates accepted than one that gets 44% right, all else equal.
DFlash's whole pitch is that diffusion-based drafting is competitive with autoregressive drafting because the diffusion process is parallelizable (faster) and the bidirectional context produces better proposals. My L2 results say: yes, the bidirectional-context part is real. A 50% relative improvement on reconstruction is the kind of gap that should show up as a meaningful speedup in a real speculative decoding pipeline.
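To make "acceptance rate determines the speedup" concrete, here is a minimal sketch of greedy verification: accept the longest prefix where draft and verifier argmax agree, then take the verifier's token at the first mismatch. It also includes the standard expected-tokens-per-verifier-call formula from the speculative decoding literature, which assumes each draft token is accepted independently with probability α. Both are textbook simplifications, not DFlash's implementation, and `greedy_verify` and `expected_tokens_per_call` are hypothetical names.

```python
from typing import List, Sequence


def greedy_verify(draft: Sequence[int], verifier_argmax: Sequence[int]) -> List[int]:
    """One greedy speculative-decoding step.

    draft:           K proposed tokens
    verifier_argmax: K + 1 verifier predictions; entry i is the verifier's choice
                     at draft position i, and the last entry is a bonus token
    Returns the tokens committed this step (always at least one).
    """
    if len(verifier_argmax) != len(draft) + 1:
        raise ValueError("verifier_argmax must have exactly len(draft) + 1 entries")
    committed: List[int] = []
    for proposed, target in zip(draft, verifier_argmax):
        if proposed != target:
            committed.append(target)  # the verifier's token replaces the first miss
            return committed
        committed.append(proposed)
    committed.append(verifier_argmax[-1])  # all K accepted: keep the bonus token
    return committed


def expected_tokens_per_call(alpha: float, k: int) -> float:
    """E[tokens per verifier call] = (1 - alpha^(k+1)) / (1 - alpha), i.i.d. acceptance."""
    if alpha >= 1.0:
        return float(k + 1)
    return (1 - alpha ** (k + 1)) / (1 - alpha)
```

Here is a purely illustrative reading, treating reconstruction accuracy as if it were α (it isn't). With K=4, α=0.44 gives about 1.76 tokens per verifier call and α=0.66 gives about 2.57. The gain compounds because every extra accepted token also keeps the next one in play.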
---
## What L2 generation actually looks like
The reconstruction story was clean. Generation is messier.
The "the the the" artifact is gone. That part of the hypothesis held. But a new failure mode showed up.
Here's L2 on the K-Pg extinction prompt:
> the impact an asteroid caused to the extinction in at 065 million ago and only a effect about 5,0% destruction of the mass of asteroid
Compare to L1 on the same prompt:
> the impact of asteroid that caused the the the Yucatan peninsula, Mexico, about 6. million the extinction of the dinosaurs., which the impact of dinosaurs, the impact crater the extinction and the the Cretaceous.
L2 produces a coherent first clause. "The impact an asteroid caused to the extinction." Articles are placed correctly. The number is recovered as "065 million ago" (the number 65 is right, the formatting is broken). It even tries to attach the percentage from the prompt: "5,0% destruction" (it should be 75%, but it remembered there is a percentage, and tried to write one).
Then it falls apart. The remaining 60 tokens are stray digits, commas, and whitespace.
The other three prompts show the same pattern:
- **Narrative:** "was everybody in the, the little child. There, little,,,,,,..." opens grammatically, then collapses into commas
- **Code:** "def fib(n): n00n): return / fib" better Python structure than L1, then total whitespace collapse
- **Science (doublets):** "it affects the quality of the data. For ,000 cells000 50 cells..." first six tokens are publishable, then digit and comma soup
So L2 fixed one artifact and revealed another. The model produces a few legitimate tokens at the start of the generation block, then the iterative confidence sampler collapses into a different fallback: commas, digits, whitespace. The pattern shifted from "high-frequency articles" to "high-frequency punctuation and digits."
This tells me the failure isn't really about the attention mask. It's about the **iterative confidence sampler** itself. When the model is uncertain about most positions but very confident a few high-frequency tokens belong somewhere, the sampler greedily reveals all of them. With L1, those were articles. With L2, the article problem is solved by bidirectional context, but the sampler still has to commit on uncertain positions, and now it falls back to even-higher-frequency tokens (commas, digits, spaces).
**The real lesson:** Bidirectional attention fixes the *prediction* problem (validated by reconstruction). The remaining issue is *decoding*. The confidence sampler is too greedy with low-information tokens, and that's a separate problem.
---
## The decoder ablation: it was the sampler, not the model
L2 generation looked broken. The K-Pg sample produced six coherent tokens then collapsed into commas. I had two hypotheses:
1. The model is undertrained. Need more steps.
2. The iterative confidence sampler is committing to high-frequency filler tokens too eagerly.
The cheap test: hold L2 fixed and sweep nine decoder configurations. Under ninety seconds of compute per variant, about twelve minutes total. Same model, same checkpoint, just different ways of revealing masked positions.
The variants:
| Label | Temp | top_p | Schedule | Steps | Sampling |
|-------|------|-------|----------|-------|----------|
| A_baseline | 1.0 | 1.0 | confidence | default | yes |
| B_temp_low | 0.7 | 1.0 | confidence | default | yes |
| C_topp_90 | 1.0 | 0.9 | confidence | default | yes |
| D_temp07_topp90 | 0.7 | 0.9 | confidence | default | yes |
| E_linear_reveal | 1.0 | 1.0 | linear | default | yes |
| F_cosine_reveal | 1.0 | 1.0 | cosine | default | yes |
| G_8_steps | 1.0 | 1.0 | confidence | 8 | yes |
| H_32_steps | 1.0 | 1.0 | confidence | 32 | yes |
| I_argmax | 1.0 | 1.0 | confidence | default | no (argmax) |
The "schedule" column is the one that matters. Confidence-based reveal sorts masked positions by model confidence and reveals the easy ones first. Linear and cosine reveal positions in their natural left-to-right order (cosine just front-loads more of them in early steps).
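Below is a minimal sketch of the two positional schedules. I'm assuming "linear" reveals the same number of positions each step, and "cosine" makes the per-step reveal count proportional to cos(πt/2T), so its cumulative fraction is sin(πt/2T) and it front-loads reveals as described above. MaskGIT-style cosine schedules are often applied the other way round, so treat the exact curve as an assumption. Positions are always taken leftmost-first among those still masked. `reveal_counts` and `positional_reveal` are hypothetical names.

```python
import math
from typing import List

import torch


def reveal_counts(num_masked: int, num_steps: int, schedule: str = "cosine") -> List[int]:
    """How many positions to reveal at each step; the counts sum to num_masked."""
    def revealed_fraction(t: int) -> float:
        x = t / num_steps
        if schedule == "linear":
            return x
        if schedule == "cosine":
            # Per-step reveals proportional to cos(pi*x/2), so cumulative = sin(pi*x/2).
            return math.sin(math.pi * x / 2)
        raise ValueError(f"unknown schedule: {schedule}")

    cumulative = [round(num_masked * revealed_fraction(t)) for t in range(num_steps + 1)]
    return [cumulative[t + 1] - cumulative[t] for t in range(num_steps)]


def positional_reveal(is_masked: torch.Tensor, count: int) -> torch.Tensor:
    """Indices of the leftmost `count` still-masked positions; confidence is ignored."""
    return torch.nonzero(is_masked, as_tuple=False).flatten()[:count]
```

With 96 masked positions and 8 steps, this cosine variant reveals 19 positions on the first step and 2 on the last, while the linear one reveals 12 every step.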
I scored each variant by **filler rate**: the fraction of output characters that are commas, digits, or whitespace. The collapse signature.
| Variant | Avg filler rate |
|---------|-----------------|
| A_baseline (confidence) | 0.665 |
| B_temp_low | 0.709 |
| C_topp_90 | 0.581 |
| D_temp07_topp90 | 0.735 |
| **E_linear_reveal** | **0.216** |
| **F_cosine_reveal** | **0.157** |
| G_8_steps | 0.623 |
| H_32_steps | 0.648 |
| I_argmax | 0.686 |
Switching from confidence-based reveal to cosine reveal dropped the average filler rate from 0.665 to 0.157, a 76% relative reduction. No model retraining. Just a one-line change in the sampler.
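The filler-rate metric behind the table above is easy to reproduce. Below is a minimal sketch of it as described. The exact character classes in the original script are an assumption; Python's `str.isdigit` and `str.isspace` also count non-ASCII digits and spaces. Every space between words counts as filler, so even clean prose scores well above zero.

```python
def filler_rate(text: str) -> float:
    """Fraction of characters that are commas, digits, or whitespace."""
    if not text:
        return 0.0
    filler = sum(1 for ch in text if ch == "," or ch.isdigit() or ch.isspace())
    return filler / len(text)
```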
Same K-Pg prompt under cosine reveal:
> a large meteorite caused asteroid impact to strike Earth collision mass Texas7 mileskilometer miles accelerating±meters Formation energy the critical creation the kinetic energy quad the conservation
Still not good prose. But it's recognizably text. The collapse to commas and digits is gone. The model is producing content tokens for the entire 96-token window.
C_topp_90 (confidence reveal + top-p filtering) was the most semantically coherent variant on the K-Pg prompt:
> a that asteroid or asteroid- struck on the Yucatán peninsula, 5000 0 00 of the Earth, and the release of a huge cloud of of dust and, the vapor heat of the impact
Yucatán peninsula. Cloud of dust. Heat of impact. Real words attaching to real concepts. Still has a filler tail, but the front of the generation is doing actual work.
I_argmax (greedy) was the most useful negative result: it produced "the the the the the the" twenty times in a row on the K-Pg prompt, with a 92.6% bigram repetition rate. Greedy decoding plus confidence reveal is a recipe for the exact L1 failure mode, even on the L2 model. **This confirms the artifact is decoder-driven, not model-driven.**
### Why confidence reveal fails
The confidence sampler reveals masked positions in order of how certain the model is about them. Early in sampling, almost everything is masked. The model attends to the prompt and assigns high probability to "obvious" positions: commas after natural clause breaks, digits in numeric contexts, articles before nouns. The sampler reveals those first.
By the time the harder positions are being revealed, the easy positions have already committed to filler tokens. The model now has to fill in content positions surrounded by commas and digits, which are not informative context. Predictions become low-confidence, and the model defaults to the lowest-cost token: more commas.
Linear and cosine reveal break this loop. They commit to positions in order, regardless of confidence. The model has to produce a content token at position five before it can produce one at position six. The natural "filler token first" pathology never gets a foothold.
**The decoder fix is bigger than the model fix.** Bidirectional attention added ~50% relative reconstruction accuracy. Cosine reveal cut the filler rate in standalone generation by 76%. Both matter, but they fix different things.
---
## What 2000 steps bought us
After the L1/L2/decoder triad, the open question was whether L2 was undertrained. The reconstruction curve was still climbing at step 500. So I trained L2 to 2000 steps (1500 beyond the original budget) with bidirectional attention, and made cosine reveal the eval default. Total wall-clock for the 2000-step run: 5h 56m on a single DGX Spark.
Results:
| Mask rate | L2-500 acc | L2-2k acc | Absolute Δ (pts) | Relative Δ |
|-----------|------------|-----------|------------------|------------|
| 15% | 65.9% | **70.9%** | +5.0 | +7.6% |
| 25% | 54.8% | **61.2%** | +6.5 | +11.8% |
| 50% | 34.9% | **39.5%** | +4.6 | +13.2% |
NLL dropped further too: 1.48 → 1.25 at 15% mask, 2.16 → 1.85 at 25%, 3.62 → 3.24 at 50%. Span recovery climbed by 5-7 percentage points across all rates. The model is still learning at step 2000. Not plateaued.
But here's the part that matters for the broader story. **Generation under cosine reveal didn't get qualitatively better.** Filler control is still good. The samples are recognizably text. But coherent prose isn't there yet:
L2-2k on the K-Pg prompt under cosine reveal:
> asteroid impact which caused which caused and dinosaur for or succession the the dinosaurs into extinction Periodmostly of fossil left and was
Topic words are correct (asteroid, impact, dinosaur, extinction, Period, fossil). Syntax is still broken. Compare to a strict autoregressive baseline on the same prompt:
> that an asteroid or comet impact, with a diameter of 10-20 kilometers, wiped out the dinosaurs and the mass of life on Earth
The AR baseline produces fluent prose. The diffusion drafter produces a topic-coherent token salad. **Reconstruction climbed by another 4.6-6.5 points absolute. Generation didn't cross the line into coherent prose.** That's a meaningful update on what 4× the compute buys you with this recipe.
The honest read: at this scale and budget, more training is improving the *prediction* layer (reconstruction) but not unlocking the *generation* layer (sampling produces text but not sentences). Either we need substantially more training, or the architecture (causal base + LoRA + bidirectional mask) caps out short of fluent generation. I'd guess the latter without a proper bidirectional pretraining stage, but that's beyond this POC.
For drafting, this is fine. The verifier doesn't need fluent prose; it needs accurate token proposals. Reconstruction at 71/61/40% is the right metric for that, and L2-2k is a stronger drafter than L2-500.
---
## What this means for drafting
If I were going to use this as an actual speculative drafter, I'd care about three things:
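1. **Acceptance rate of proposals:** L2-2k's reconstruction accuracy (71/61/40% across mask rates) is my best proxy for how often a verifier would accept its tokens, and probably an optimistic one. Reconstruction sees unmasked tokens on both sides of each gap, while a draft block starts fully masked. That's plausibly a useful drafter today, with a clear path to better via more training.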
2. **Coherent drafts:** Less important than acceptance rate. The verifier doesn't care if the draft reads well, only if it's correct often enough.
3. **Decoding strategy:** The confidence sampler's failure mode in standalone generation is a problem for using L2 as a *standalone* generator, but probably less of a problem inside a speculative decoding loop where the verifier corrects errors as it goes.
So the takeaway is more nuanced than I expected. L2-2k isn't a good standalone generator. But it might already be a useful drafter, and the path to better generation runs through fixing the decoder, not the model.
---
## Where this fits relative to DFlash
DFlash is a production-quality 27B drafter trained at scale, integrated with flash attention kernels, and benchmarked against high-throughput baselines. What I'm doing is the toy version: a 5B model with a 24M-param adapter, trained for 6 hours on FineWeb-Edu, evaluated on 16 sequences and four prompts.
The point isn't to compete. The point is to understand the components. If I can't make a small bidirectional diffusion drafter produce plausible draft tokens in an afternoon, then DFlash is doing something I don't understand yet. If I can, I have a working baseline I can scale up and instrument.
The fastest way to understand a system is to build the smallest version that exhibits the behavior you care about, then watch it fail.
---
## Things I want to remember
- **Probe the model's interface before assuming.** The 4D mask probe was 30 minutes that saved days. Most causal LMs build attention masks internally and ignore custom ones.
- **Pre-flight checks for dtype, weight loading, and loss computation are non-negotiable.** None of them are documented anywhere. Multimodal weight layouts are a particular trap.
- **LoRA target enumeration matters.** Leaf-name targets will adapt the wrong modules in multimodal models. Always specify full module paths.
- **L1 is the most useful baseline.** Without it, "L2 is good" means nothing. The story is L1's specific failures and how L2 fixes them, not "L2 has high accuracy."
- **The artifact is the data.** L1's "the the the" output told me more about the system than any metric. When you see a weird failure mode, look at it carefully. It's almost always pointing at a specific bottleneck.
- **Don't conflate model and decoder.** Bad output can come from either. The cheapest test is to hold one fixed and sweep the other. A twelve-minute decoder ablation eliminated half my retraining backlog.
- **Confidence-based reveal is a trap when masked tokens are correlated with high-frequency fillers.** It commits to easy positions first, which leaves the hard ones surrounded by uninformative context. Linear or cosine reveal breaks the trap.
- **Reconstruction and generation are different things.** Four times the compute kept reconstruction climbing but didn't produce fluent generation. For a drafter, that's fine. For a standalone generator, it isn't.
- **Compute budget on a single DGX Spark is plenty for this kind of sequential A/B.** L1 + L2 + decoder ablation in 3.5 hours. The 2000-step run added another 6. The bottleneck for these experiments isn't GPU time, it's having a clear hypothesis and a tight feedback loop.
---
## Status snapshot
| Phase | Status |
|-------|--------|
| L1 (causal denoising, 500 steps) | Complete |
| L2 (bidirectional, 500 steps) | Complete |
| Reconstruction evals (L1/L2) | Complete (L2 +50% relative across mask rates) |
| Generation evals (L1/L2) | Complete (L2 fixed article repetition, exposed filler collapse) |
| Decoder ablation (9 variants × 4 prompts) | Complete (cosine reveal cuts filler 76%) |
| L2 long (2000 steps, cosine reveal) | Complete (reconstruction +4.6-6.5 pts absolute, generation still not fluent) |
| Speculative decoding integration | Queued |
| Block-ratio ablation | Queued |
The two open follow-ups are: wire L2-2k as a drafter behind a 9B verifier and measure acceptance rate (the actual production application), and ablate the prefix/denoising block ratio (currently 50/50, picked arbitrarily). Both are afternoon experiments on the same hardware.
---
### Related Articles
- [[Cutting-Edge AI/gemini-diffusion-google-deepmind-analysis|Gemini Diffusion: What if Text Generators Worked Like Stable Diffusion for Words?]]
- [[Cutting-Edge AI/deepseek-v3-0324-technical-review|DeepSeek V3 Technical Review]]
- [[AI Development & Agents/turboquant-kv-quantization|TurboQuant KV Quantization on DGX Spark]]
---
<p style="text-align: center;"><strong>About the Author</strong>: Justin Johnson builds AI systems and writes about practical AI development.</p>
<p style="text-align: center;"><a href="https://justinhjohnson.com">justinhjohnson.com</a> | <a href="https://twitter.com/bioinfo">Twitter</a> | <a href="https://www.linkedin.com/in/justinhaywardjohnson/">LinkedIn</a> | <a href="https://rundatarun.io">Run Data Run</a> | <a href="https://subscribe.rundatarun.io">Subscribe</a></p>