
arXiv: 2605.15422 · v1 · pith:QN5VUKGE · submitted 2026-05-14 · 💻 cs.LG

DualKV: Shared-Prompt Flash Attention for Efficient RL Training with Large Rollouts and Long Contexts

Pith reviewed 2026-05-19 15:44 UTC · model grok-4.3

classification 💻 cs.LG
keywords DualKV, FlashAttention, RL training, shared prompt, causal masking, GRPO, training efficiency, long context

The pith

DualKV processes shared prompts only once in FlashAttention for RL training by exploiting causal masking invariance, matching standard attention exactly.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

The paper shows that in decoder-only models used for RL post-training, prompt representations stay identical across multiple response sequences because of causal masking. This invariance allows all prompt-related computations to run once instead of once per rollout. DualKV provides the first kernel-level implementation, with fused CUDA forward and backward passes that handle the shared prompt and the individual responses separately. A data repacking step further reduces the total number of tokens the model processes from N(P+R) to P+NR, for N responses of R tokens sharing a P-token prompt. Because the method matches standard attention exactly, it delivers policy-update speedups of 1.63× to 2.09× on Qwen3-8B without any accuracy trade-off.
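The token-count arithmetic behind the repacking step is easy to check. The sketch below evaluates the reduction factor ρ = N(P+R)/(P+NR) stated in the abstract; `token_reduction_factor` is our own name, and the response length R = 1024 in the example is a hypothetical value chosen for illustration, not a setting reported by the paper.

```python
def token_reduction_factor(n: int, p: int, r: int) -> float:
    """rho = N(P+R) / (P+NR) for N responses of R tokens sharing a P-token prompt."""
    replicated = n * (p + r)  # standard layout: every rollout carries its own prompt copy
    repacked = p + n * r      # repacked layout: prompt once, then all N responses
    return replicated / repacked


# Hypothetical example: N=32 rollouts, P=8192 prompt tokens, R=1024 response tokens.
print(token_reduction_factor(32, 8192, 1024))  # 7.2
```

ρ approaches N when P dominates R and approaches 1 when R dominates P, so the token savings are largest in the long-prompt, many-rollout regime the paper targets. ρ counts tokens, not time, so it should not be read as a prediction of the measured 1.63–2.09× policy-update speedup.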

Core claim

DualKV is mathematically equivalent to standard attention and introduces no approximation. It achieves this by using fused CUDA kernels that process two disjoint KV regions in one launch: the shared prompt context once and the per-sequence responses individually, paired with a data-pipeline redesign that repacks tokens to extend the savings beyond attention to the full model.
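One way to picture the repacked layout is as a flat token list plus position ids and segment ids. The sketch below is a hypothetical illustration, not veRL's actual data format: `repack` and `can_attend` are names we introduce, and restarting each response's positions at P is our assumption about how rotary positions would be kept identical to the replicated layout (consistent with the RoPE argument in the authors' rebuttal).

```python
from typing import List, Tuple


def repack(prompt: List[int], responses: List[List[int]]) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical repacking of one prompt group into P + sum(R_i) tokens.

    Returns (token_ids, position_ids, segment_ids). Segment 0 is the shared
    prompt; segment i (1..N) is response i. Each response restarts its
    positions at P, so rotary embeddings see the same absolute positions as
    in the replicated [prompt; response_i] layout.
    """
    p = len(prompt)
    tokens, positions, segments = list(prompt), list(range(p)), [0] * p
    for i, resp in enumerate(responses, start=1):
        tokens.extend(resp)
        positions.extend(range(p, p + len(resp)))
        segments.extend([i] * len(resp))
    return tokens, positions, segments


def can_attend(q: int, k: int, positions: List[int], segments: List[int]) -> bool:
    """Visibility rule for packed indices q (query) and k (key): causal in position,
    and the key lies either in the shared prompt or in the query's own response."""
    causal = positions[k] <= positions[q]
    return causal and (segments[k] == 0 or segments[k] == segments[q])


tokens, positions, segments = repack([101, 102, 103], [[7, 8], [9]])
print(tokens)     # [101, 102, 103, 7, 8, 9]
print(positions)  # [0, 1, 2, 3, 4, 3]
print(segments)   # [0, 0, 0, 1, 1, 2]
```

Here 6 tokens replace the 9 that the replicated layout ([prompt; response_1] plus [prompt; response_2]) would carry. The dense visibility rule is only for checking the layout; a fused kernel would iterate over the two KV regions directly rather than materialize a mask.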

What carries the argument

Fused CUDA forward and backward kernels iterating over shared context and per-sequence response KV regions, enabled by prompt representation invariance under causal masking.
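The two-region iteration can be checked at toy scale in plain Python. This is a minimal sketch under our own simplifying assumptions (a single head, one query at a time, scalar loops instead of GPU tiles, and only response queries, since prompt queries are handled by ordinary causal attention over the prompt); `two_region_attention` and the other names are hypothetical, not the paper's kernel API. It shows that one pass over the shared prompt KV followed by a sequence's own response KV, with a FlashAttention-style online softmax carried across both regions, reproduces standard attention over the concatenated [prompt; response] sequence up to rounding.

```python
import math
import random
from typing import List, Sequence

Vec = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def reference_attention(q: Vec, keys: List[Vec], values: List[Vec]) -> Vec:
    """Ordinary softmax attention of one query over an explicit key list (mask already applied)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[i] for wi, v in zip(w, values)) / z for i in range(len(values[0]))]


def two_region_attention(q: Vec, shared_k: List[Vec], shared_v: List[Vec],
                         resp_k: List[Vec], resp_v: List[Vec]) -> Vec:
    """Hypothetical helper: one response query sweeps the shared prompt KV, then its own
    response KV prefix, carrying one online softmax (running max m, running sum z)."""
    scale = 1.0 / math.sqrt(len(q))
    m, z = -math.inf, 0.0
    acc = [0.0] * len(shared_v[0])  # un-normalised output
    for keys, values in ((shared_k, shared_v), (resp_k, resp_v)):  # region 1, then region 2
        for k, v in zip(keys, values):
            s = dot(q, k) * scale
            m_new = max(m, s)
            corr = math.exp(m - m_new)  # rescale earlier partial sums to the new max
            w = math.exp(s - m_new)
            z = z * corr + w
            acc = [a * corr + w * x for a, x in zip(acc, v)]
            m = m_new
    return [a / z for a in acc]


def max_mismatch(n: int = 3, p: int = 5, r: int = 4, d: int = 8, seed: int = 0) -> float:
    """Largest |reference - two_region| over every response query of n toy rollouts."""
    rng = random.Random(seed)

    def rand_vec() -> Vec:
        return [rng.gauss(0.0, 1.0) for _ in range(d)]

    shared_k = [rand_vec() for _ in range(p)]  # built once, reused by every rollout
    shared_v = [rand_vec() for _ in range(p)]
    worst = 0.0
    for _ in range(n):
        resp_q = [rand_vec() for _ in range(r)]
        resp_k = [rand_vec() for _ in range(r)]
        resp_v = [rand_vec() for _ in range(r)]
        for j in range(r):  # causal: response query j sees response keys 0..j
            ref = reference_attention(resp_q[j], shared_k + resp_k[: j + 1], shared_v + resp_v[: j + 1])
            out = two_region_attention(resp_q[j], shared_k, shared_v, resp_k[: j + 1], resp_v[: j + 1])
            worst = max(worst, max(abs(a - b) for a, b in zip(ref, out)))
    return worst


print(max_mismatch())  # expected: rounding-level differences only
```

`shared_k` and `shared_v` are built once and passed to every rollout by reference, whereas the replicated baseline would hold N copies of them.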

If this is right

  • 1.63–2.09× policy-update speedup on Qwen3-8B GRPO with N=32 and 8K context
  • Enables 2× larger micro-batches
  • Raises MFU from 36% to 76%
  • Achieves 3.82× policy-update and 3.38× end-to-end speedup at 30B MoE scale

Where Pith is reading between the lines

These are editorial extensions of the paper, not claims the authors make directly.

  • The same invariance could be exploited in inference for batched requests with common prefixes
  • Data repacking might benefit other sequence-parallel training setups
  • Future kernels could generalize this to variable prompt lengths or mixed shared contexts

Load-bearing premise

Causal masking in decoder-only models makes prompt representations invariant across sequences at every layer.

What would settle it

Compute prompt hidden states in a multi-response batch and compare them to states from isolated single-response forward passes; exact match would confirm the invariance.
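The proposed check can be rehearsed on a toy model before running it on a real checkpoint. The sketch below is our own construction (random weights, 3 layers, single-head causal attention, a tanh MLP, no normalization or RoPE), not the paper's test harness; all names are hypothetical. Because each prompt row is computed only from rows at or before it, the prompt states come out bit-identical whatever response follows. On GPUs, kernels selected for different batch shapes may reduce in a different order, so a real-model comparison may need a small tolerance rather than exact equality.

```python
import math
import random
from typing import List, Tuple

Mat = List[List[float]]


def matvec(w: Mat, x: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def toy_layer(h: Mat, wq: Mat, wk: Mat, wv: Mat, wm: Mat) -> Mat:
    """Toy decoder layer: causal single-head attention, then a token-wise tanh MLP, each with a residual."""
    q = [matvec(wq, x) for x in h]
    k = [matvec(wk, x) for x in h]
    v = [matvec(wv, x) for x in h]
    d = len(h[0])
    out = []
    for t in range(len(h)):
        # Causal mask: row t only looks at rows 0..t.
        scores = [sum(a * b for a, b in zip(q[t], k[s])) / math.sqrt(d) for s in range(t + 1)]
        m = max(scores)
        w = [math.exp(sc - m) for sc in scores]
        z = sum(w)
        attn = [sum(w[s] * v[s][i] for s in range(t + 1)) / z for i in range(d)]
        x = [a + b for a, b in zip(h[t], attn)]
        out.append([a + math.tanh(b) for a, b in zip(x, matvec(wm, x))])
    return out


def prompt_states(prompt: List[int], response: List[int], emb: Mat,
                  layers: List[Tuple[Mat, Mat, Mat, Mat]]) -> Mat:
    """Final-layer hidden states of the prompt rows of [prompt; response]."""
    h = [list(emb[tok]) for tok in prompt + response]
    for params in layers:
        h = toy_layer(h, *params)
    return h[: len(prompt)]


rng = random.Random(0)
D, VOCAB, N_LAYERS = 4, 16, 3


def rand_mat(rows: int, cols: int) -> Mat:
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


emb = rand_mat(VOCAB, D)
layers = [(rand_mat(D, D), rand_mat(D, D), rand_mat(D, D), rand_mat(D, D)) for _ in range(N_LAYERS)]
prompt = [1, 5, 9, 2, 7]
alone = prompt_states(prompt, [], emb, layers)
with_a = prompt_states(prompt, [3, 3, 8], emb, layers)
with_b = prompt_states(prompt, [12, 0, 4, 4, 6, 11], emb, layers)
print(alone == with_a == with_b)  # True
```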

Figures

Figures reproduced from arXiv: 2605.15422 by Bernie Wang, George Karypis, Jiading Gai, Shuai Zhang, Xiang Song.

Figure 2: LongReason training (Qwen3-8B, 8×H100), four configurations. Top row: GRPO (N=32, steps 1–24). Bottom row: DAPO (N=32, steps 1–25). Panels: (A) Training reward (all configs track identically). (B) Peak memory. (C) Policy-update latency. (D) Policy-update MFU. GRPO: DualKV mb=8 achieves 2.09× speedup, 75.8% MFU. DAPO: 2.47× speedup, 77.4% MFU. … all configurations where both run. At high micro-batch (P=5K, mb…
Figure 1: Per-phase step breakdown. Setup. We train Qwen3-8B on a single p5.48xlarge instance (8×H100-SXM5-80GB) with FSDP2 in BF16 and gradient checkpointing. Rollout generation uses vLLM with tensor parallelism = 2 and N=32 responses per prompt. Training uses train_batch_size = 128 and ppo_mini_batch_size = 64 (both fixed across all configurations), and we sweep ppo_micro_batch_size_per_gpu ∈ {4, 8}. Each configu…
Figure 3: Per-step training accuracy. To validate DualKV at production scale with cross-node communication and MoE expert routing, we train Qwen3-30B-A3B [Yang et al., 2025] with GRPO on two p5.48xlarge nodes (16×H100 GPUs, Elastic Fabric Adapter): FSDP2 + BF16 + gradient checkpointing; rollout via vLLM (TP=2, N=32). All runs: train_batch_size=128, ppo_mini_batch_size=64, ppo_micro_batch_size_per_gpu=8, Pmax=8192, R…
Figure 4: Memory scaling: DualKV vs. FA2. Memory decouples from micro-batch size. DualKV's memory follows $M_0 + c_P \cdot P + c_{mbR} \cdot mb \cdot R$: the prompt cost is paid once regardless of micro-batch size, so the three mb curves cluster tightly (…
Figure 5: DAPO training on LongReason (Qwen3-8B, 8…
Figure 6: Memory scaling law: DualKV (measured, solid) vs. FA2 (projected, dashed). Llama-3.1-8B, …
Figure 7: Per-phase breakdown of an average GRPO training step on LongReason (Section 4.2). DualKV …
Figure 8: Multi-node MoE GRPO training on LongReason (Section 4.3; Qwen3-30B-A3B, 16…
Figure 9: Per-phase breakdown of an average GRPO training step for the multi-node MoE experiment (Sec…
Original abstract

Modern RL post-training methods such as GRPO and DAPO train on $N$ response sequences of $R$ tokens sampled from a shared prompt of $P$ tokens, but standard FlashAttention replicates all $P$ prompt tokens $N$ times across both forward and backward passes, duplicating compute and memory on identical hidden states. In large-rollout, long-context RL training ($N{\geq}16$, $P{\geq}8\text{K}$), this redundancy dominates the policy update cost. We observe that in decoder-only models, causal masking makes prompt representations invariant across sequences at every layer, so all per-token operations (norms, projections, MLP) and attention can process the prompt once, a property not yet exploited at the kernel level for training. We propose DualKV, the first FlashAttention kernel variant that eliminates shared-prompt replication during RL training, via (1) fused CUDA forward and backward kernels that iterate over two disjoint KV regions (shared context and per-sequence response) in a single kernel launch, and (2) a data-pipeline redesign in veRL that repacks $N(P{+}R)$ tokens into $P{+}NR$ tokens per micro-batch, extending the token reduction from attention to the entire model by a factor $\rho = N(P{+}R)/(P{+}NR)$. DualKV is mathematically equivalent to standard attention and introduces no approximation. On Qwen3-8B GRPO training with 8$\times$H100 GPUs ($N{=}32$, 8K context), DualKV achieves $1.63$–$2.09\times$ policy-update speedup, enables $2\times$ larger micro-batches, and raises MFU from $36\%$ to $76\%$. Similar gains hold for DAPO ($2.47\times$ speedup, $77\%$ MFU). At 30B MoE scale on 16$\times$H100, DualKV achieves $3.82\times$ policy-update and $3.38\times$ end-to-end step speedup over FlashAttention (which requires 4-way Ulysses sequence parallelism to avoid OOM).
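For the attention part alone, the savings can be counted in scored query-key pairs. This is a back-of-the-envelope sketch under our own assumptions: it counts only the unmasked pairs of a causal mask, ignores head dimension, GQA, and tile-level masking waste, and uses a hypothetical R = 1024; the function names are ours. It makes explicit that the shared-prompt layout removes exactly N−1 copies of the P(P+1)/2 prompt self-attention triangle, while the prompt-to-response cross terms remain.

```python
def causal_pairs(length: int) -> int:
    """Query-key pairs scored by causal attention over one sequence of `length` tokens."""
    return length * (length + 1) // 2


def replicated_pairs(n: int, p: int, r: int) -> int:
    """Standard layout: N independent [prompt; response] sequences of P + R tokens each."""
    return n * causal_pairs(p + r)


def shared_prompt_pairs(n: int, p: int, r: int) -> int:
    """Shared-prompt layout: prompt self-attention once, then each response's R queries
    attend to all P prompt keys plus their own causal response prefix."""
    return causal_pairs(p) + n * (r * p + causal_pairs(r))


# Hypothetical sizes: N=32, P=8192, R=1024 (R is not reported on this page).
n, p, r = 32, 8192, 1024
saved = replicated_pairs(n, p, r) - shared_prompt_pairs(n, p, r)
print(saved == (n - 1) * causal_pairs(p))  # True: exactly N-1 redundant prompt triangles
print(replicated_pairs(n, p, r) / shared_prompt_pairs(n, p, r))
```

The identity behind the first check is $(P+R)(P+R+1)/2 = P(P+1)/2 + RP + R(R+1)/2$, applied once per rollout.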

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it; the pith above is the substance, this is the friction.

Referee Report

2 major / 3 minor

Summary. The paper proposes DualKV, a FlashAttention variant for RL post-training (e.g., GRPO, DAPO) that exploits causal masking in decoder-only models to process shared prompt tokens only once across N rollouts. It introduces fused CUDA forward/backward kernels handling disjoint shared-prompt KV and per-sequence response KV regions in one launch, plus a veRL data repacking that reduces tokens from N(P+R) to P+NR. The method claims exact mathematical equivalence to standard attention with no approximation, and reports 1.63–2.09× policy-update speedups, 2× larger micro-batches, and MFU gains from 36% to 76% on Qwen3-8B (N=32, 8K context) with similar or larger gains at 30B MoE scale.

Significance. If the equivalence and kernel correctness hold, the work addresses a real redundancy in large-rollout RL training that dominates policy-update cost for N≥16 and long contexts. Concrete wall-clock and MFU numbers on named models/hardware, plus the parameter-free nature of the speedup (no fitted constants), strengthen the practical contribution. The extension of token reduction beyond attention to the full model via repacking is a notable engineering strength.

major comments (2)
  1. §3.2 (invariance argument): The induction that prompt hidden states remain identical across sequences at every layer is load-bearing for the equivalence claim. The manuscript should explicitly state the base case (shared embeddings + causal mask on prompt prefix) and inductive step (token-wise ops + attention only to prior prompt tokens), including whether this holds under all common variants such as grouped-query attention or RoPE.
  2. §4.3 (backward kernel): The fused backward pass description does not detail how gradients for the shared prompt KV are accumulated across the N sequences before the single write-back. Because this accumulation is required for exact equivalence to the replicated standard attention backward, an explicit equation or pseudocode step showing the reduction is needed.
minor comments (3)
  1. Table 1: the MFU column for the baseline should also report the sequence-parallelism configuration used (Ulysses 4-way) to make the 3.82× comparison at 30B scale directly interpretable.
  2. Figure 3 caption: the x-axis label 'effective batch size' is ambiguous; clarify whether it refers to micro-batch token count before or after repacking.
  3. Abstract: The abstract states 'DualKV is mathematically equivalent to standard attention and introduces no approximation.' This sentence should be repeated verbatim in the conclusion or §5 for emphasis.

Simulated Author's Rebuttal

2 responses · 0 unresolved

We thank the referee for the positive evaluation and the recommendation of minor revision. We address each major comment below and have revised the manuscript to incorporate the requested clarifications.

Point-by-point responses
  1. Referee: §3.2 (invariance argument): The induction that prompt hidden states remain identical across sequences at every layer is load-bearing for the equivalence claim. The manuscript should explicitly state the base case (shared embeddings + causal mask on prompt prefix) and inductive step (token-wise ops + attention only to prior prompt tokens), including whether this holds under all common variants such as grouped-query attention or RoPE.

    Authors: We agree that an explicit base case and inductive step will strengthen the presentation of the equivalence claim. In the revised manuscript we have expanded §3.2 to state: Base case—at layer 0 the shared prompt embeddings are identical across sequences and the causal mask restricts each prompt token to attend only to preceding prompt tokens. Inductive step—given identical prompt hidden states at layer ℓ, all subsequent token-wise operations (RMSNorm, linear projections, MLP) and the attention computation (which references only prior prompt KV under causality) produce identical states at layer ℓ+1. The argument holds for grouped-query attention because the KV grouping is applied uniformly to the shared prompt, and for RoPE because rotary embeddings for the prompt prefix depend only on the same absolute positions in every rollout. The updated section now contains this full inductive argument. revision: yes

  2. Referee: §4.3 (backward kernel): The fused backward pass description does not detail how gradients for the shared prompt KV are accumulated across the N sequences before the single write-back. Because this accumulation is required for exact equivalence to the replicated standard attention backward, an explicit equation or pseudocode step showing the reduction is needed.

    Authors: We acknowledge that the backward-kernel description would benefit from an explicit reduction step. In the revised §4.3 we have added pseudocode and a short equation showing that, for each prompt position p, the gradient is accumulated as grad_KV_prompt[p] = ∑_{i=1}^N grad_KV_from_sequence_i[p] before the single write-back to global memory. This summation is performed in shared memory within the fused kernel and guarantees that the resulting gradient matches the sum obtained from N independent standard-attention backward passes, preserving exact equivalence. revision: yes
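Written out in the usual FlashAttention backward notation, the cross-rollout reduction for a shared prompt key/value position $p$ could look as follows. This is our own sketch of the math with notation we introduce ($\mathcal{P}$ for prompt positions, which are queried once, and $\mathcal{R}_i$ for the response positions of rollout $i$); it is not the equation added to the revised §4.3. The sum over $i$ is the accumulation the referee asks for. Since the prompt terms are linear in $dO_q$, and in the shared-prompt layout each prompt query's $dO_q$ already aggregates the gradients from all $N$ rollouts, these totals match what $N$ replicated backward passes would accumulate.

```latex
% Assumed notation (standard FlashAttention backward): tau = 1/sqrt(d), S = tau Q K^T,
% P_{qk} = softmax weights, O_q = output row, dO_q = upstream gradient of query q,
% D_q = dO_q^T O_q, and dS_{qk} = P_{qk} (dO_q^T V_k - D_q).
\begin{aligned}
dV_p &= \sum_{q \in \mathcal{P},\, q \ge p} P_{qp}\, dO_q
      \;+\; \sum_{i=1}^{N} \sum_{q \in \mathcal{R}_i} P_{qp}\, dO_q, \\
dK_p &= \tau \Big( \sum_{q \in \mathcal{P},\, q \ge p} dS_{qp}\, Q_q
      \;+\; \sum_{i=1}^{N} \sum_{q \in \mathcal{R}_i} dS_{qp}\, Q_q \Big).
\end{aligned}
```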

Circularity Check

0 steps flagged

No significant circularity identified

full rationale

The paper's central claim rests on the standard property that causal masking in decoder-only models renders prompt hidden states identical across rollouts at every layer by induction (shared embeddings, causal attention limited to prior prompt tokens, and token-wise operations like norms and MLPs). This invariance is an architectural consequence, not a fitted input or self-definition, and DualKV implements the identical computation graph with no approximation. Performance numbers are measured directly from benchmark runs on Qwen3-8B and 30B MoE models rather than any equation that reduces to its own inputs by construction. No load-bearing self-citations, uniqueness theorems, or ansatzes appear in the derivation; the token reduction factor ρ is a direct algebraic consequence of the repacking P + NR, not a renamed prediction.

Axiom & Free-Parameter Ledger

0 free parameters · 1 axioms · 0 invented entities

The approach rests on one domain assumption about causal masking invariance and standard CUDA programming primitives; no free parameters or new entities are introduced.

axioms (1)
  • domain assumption Causal masking in decoder-only models makes prompt representations invariant across sequences at every layer
    This invariance is the load-bearing observation that justifies processing the prompt only once.

pith-pipeline@v0.9.0 · 5949 in / 1326 out tokens · 51921 ms · 2026-05-19T15:44:11.884822+00:00 · methodology


Reference graph

Works this paper leans on

25 extracted references · 25 canonical work pages · 12 internal anchors

  1. [1]

    Bifurcated attention: Accelerating massively parallel decoding with shared prefixes in LLMs

    Ben Athiwaratkun, Sujan Kumar Gonugondla, Sanjay Krishna Gouda, Haifeng Qian, Hantian Ding, Qing Sun, Jun Wang, Jiacheng Guo, Liangfu Chen, Parminder Bhatia, Ramesh Nallapati, Sudipta Sengupta, and Bing Xiang. Bifurcated attention: Accelerating massively parallel decoding with shared prefixes in LLMs. arXiv preprint arXiv:2403.08845, 2024

  2. [2]

    Training Verifiers to Solve Math Word Problems

    Karl Cobbe, Vineet Kosaraju, Mohammad Bavarian, Mark Chen, Heewoo Jun, Lukasz Kaiser, Matthias Plappert, Jerry Tworek, Jacob Hilton, Reiichiro Nakano, Christopher Hesse, and John Schulman. Training verifiers to solve math word problems. arXiv preprint arXiv:2110.14168, 2021

  3. [3]

    FlashAttention-2: Faster attention with better parallelism and work partitioning

    Tri Dao. FlashAttention-2: Faster attention with better parallelism and work partitioning. In International Conference on Learning Representations, 2024

  4. [4]

    FlashAttention: Fast and memory-efficient exact attention with IO-awareness

    Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: Fast and memory-efficient exact attention with IO-awareness. In Advances in Neural Information Processing Systems, 2022

  5. [5]

    The Llama 3 Herd of Models

    Abhimanyu Dubey et al. The Llama 3 herd of models. arXiv preprint arXiv:2407.21783, 2024

  6. [6]

    OpenRLHF: An Easy-to-use, Scalable and High-performance RLHF Framework

    Jian Hu, Xibin Wu, Wei Shen, Jason Klein Liu, et al. OpenRLHF: An easy-to-use, scalable and high-performance RLHF framework. arXiv preprint arXiv:2405.11143, 2024

  7. [7]

    REINFORCE++: Stabilizing Critic-Free Policy Optimization with Global Advantage Normalization

    Jian Hu, Jason Klein Liu, Haotian Xu, and Wei Shen. REINFORCE++: Stabilizing critic-free policy optimization with global advantage normalization. arXiv preprint arXiv:2501.03262, 2025

  8. [8]

    DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models

    Sam Ade Jacobs, Masahiro Tanaka, Chengming Zhang, Minjia Zhang, Shuaiwen Leon Song, Samyam Rajbhandari, and Yuxiong He. DeepSpeed Ulysses: System optimizations for enabling training of extreme long sequence transformer models. arXiv preprint arXiv:2309.14509, 2023

  9. [9]

    SWE-bench: Can language models resolve real-world GitHub issues?

    Carlos E. Jimenez, John Yang, Alexander Wettig, Shunyu Yao, Kexin Pei, Ofir Press, and Karthik Narasimhan. SWE-bench: Can language models resolve real-world GitHub issues? In International Conference on Learning Representations, 2024

  10. [10]

    Reducing activation recomputation in large transformer models

    Vijay Anand Korthikanti, Jared Casper, Sangkug Lym, Lawrence McAfee, Michael Andersch, Mohammad Shoeybi, and Bryan Catanzaro. Reducing activation recomputation in large transformer models. Proceedings of Machine Learning and Systems (MLSys), 5, 2023

  11. [11]

    Efficient memory management for large language model serving with PagedAttention

    Woosuk Kwon, Zhuohan Li, Siyuan Zhuang, Ying Sheng, Lianmin Zheng, Cody Hao Yu, Joseph Gonzalez, Hao Zhang, and Ion Stoica. Efficient memory management for large language model serving with PagedAttention. In Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles, pages 611–626, 2023

  12. [12]

    Let's Verify Step by Step

    Hunter Lightman, Vineet Kosaraju, Yura Burda, Harri Edwards, Bowen Baker, Teddy Lee, Jan Leike, John Schulman, Ilya Sutskever, and Karl Cobbe. Let's verify step by step. arXiv preprint arXiv:2305.20050, 2023

  13. [13]

    LongReason: A synthetic long-context reasoning benchmark via context expansion

    Zhan Ling, Kang Liu, Kai Yan, Yifan Yang, Weijian Lin, Ting-Han Fan, Lingfeng Shen, Zhengyin Du, and Jiecao Chen. LongReason: A synthetic long-context reasoning benchmark via context expansion. arXiv preprint arXiv:2501.15089, 2025

  14. [14]

    Post-training gpt-oss for agentic reasoning with reinforcement learning

    LinkedIn AI. Post-training gpt-oss for agentic reasoning with reinforcement learning. Hugging Face blog, https://huggingface.co/blog/LinkedIn/gpt-oss-agentic-rl, 2025

  15. [15]

    Ring attention with blockwise transformers for near-infinite context

    Hao Liu, Matei Zaharia, and Pieter Abbeel. Ring attention with blockwise transformers for near-infinite context. In International Conference on Learning Representations, 2024

  16. [16]

    RepoBench: Benchmarking Repository-Level Code Auto-Completion Systems

    Tianyang Liu, Canwen Xu, and Julian McAuley. RepoBench: Benchmarking repository-level code auto-completion systems. arXiv preprint arXiv:2306.03091, 2023

  17. [17]

    Prefix grouper: Efficient GRPO training through shared-prefix forward

    Zikang Liu, Tongtian Yue, Yepeng Tang, Longteng Guo, Junxian Cai, Qingbin Liu, Xi Chen, and Jing Liu. Prefix grouper: Efficient GRPO training through shared-prefix forward. arXiv preprint arXiv:2506.05433, 2025

  18. [18]

    Proximal Policy Optimization Algorithms

    John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, and Oleg Klimov. Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347, 2017

  19. [19]

    FlashAttention-3: Fast and accurate attention with asynchrony and low-precision

    Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, and Tri Dao. FlashAttention-3: Fast and accurate attention with asynchrony and low-precision. In Advances in Neural Information Processing Systems, 2024

  20. [20]

    DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models

    Zhihong Shao, Peiyi Wang, Qihao Zhu, Runxin Xu, Junxiao Song, Mingchuan Zhang, YK Li, Y Wu, and Daya Guo. DeepSeekMath: Pushing the limits of mathematical reasoning in open language models. arXiv preprint arXiv:2402.03300, 2024

  21. [21]

    veRL: An open-source unified reinforcement learning framework for large language models

    Guangming Sheng, Chi Cao, Zilingfeng Lin, Song Bian, Da Wei, Wenbo Xu, Caicai Yang, Jian Liu, and Tao Zhang. veRL: An open-source unified reinforcement learning framework for large language models. arXiv preprint arXiv:2409.19951, 2024

  22. [22]

    TRL: Transformers reinforcement learning, 2020

    Leandro von Werra, Younes Belkada, Lewis Tunstall, Edward Beeching, et al. TRL: Transformers reinforcement learning, 2020. URL https://github.com/huggingface/trl

  23. [23]

    Qwen3 Technical Report

    An Yang, Baosong Yang, Beichen Zhang, et al. Qwen3 technical report. arXiv preprint arXiv:2505.09388, 2025

  24. [24]

    DAPO: An Open-Source LLM Reinforcement Learning System at Scale

    Qiying Yu, Zheng Zhang, Ruofei Zhu, Yufeng Yuan, Xiaochen Zuo, Yu Yue, Weinan Dai, Tiantian Fan, Gaohong Liu, Lingjun Liu, Xin Liu, Haibin Lin, Zhiqi Lin, Bole Ma, Guangming Sheng, Yuxuan Tong, Chi Zhang, Mofan Zhang, Wang Zhang, Hang Zhu, Jinhua Zhu, Jiaze Chen, Jiangjie Chen, Chengyi Wang, Hongli Yu, Yuxuan Song, Xiangpeng Wei, Hao Zhou, Jingjing Liu, W...

  25. [25]

    SGLang: Efficient Execution of Structured Language Model Programs

    Lianmin Zheng, Liangsheng Yin, Zhiqiang Xie, Shuo Cheng, Jeff Huang, Baris Kasikci, and Ion Stoica. SGLang: Efficient execution of structured language model programs. arXiv preprint arXiv:2312.07104, 2023