Fine-Tuning Masked Diffusion for Provable Self-Correction
A natural desideratum for generative models is self-correction--detecting and revising low-quality tokens at inference. While Masked Diffusion Models (MDMs) have emerged as a promising approach for generative modeling in discrete spaces, their capacity for self-correction remains poorly understood. Prior attempts to incorporate self-correction into MDMs either require overhauling MDM architectures/training or rely on imprecise proxies for token quality, limiting their applicability. Motivated by this, we introduce PRISM--Plug-in Remasking for Inference-time Self-correction of Masked Diffusions--a lightweight, model-agnostic approach that applies to any pretrained MDM. Theoretically, PRISM defines a self-correction loss that provably learns per-token quality scores, without RL or a verifier. These quality scores are computed in the same forward pass with MDM and used to detect low-quality tokens. Empirically, PRISM advances MDM inference across domains and scales: Sudoku; unconditional text (170M); and code with LLaDA (8B).
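The detection step described in the abstract, using per-token quality scores to find low-quality tokens and revise them, can be illustrated with a minimal sketch. It assumes the scores are already available as a list of floats (higher meaning better) and simply remasks the k lowest-scoring unmasked positions. The function name, the `MASK_ID` value, and the fixed-k selection rule are hypothetical choices made for illustration; the abstract does not specify PRISM's loss, score head, or remasking schedule.

```python
from typing import List, Sequence

MASK_ID = -1  # hypothetical mask token id, chosen only for this illustration


def remask_lowest_quality(
    tokens: Sequence[int],
    quality: Sequence[float],
    k: int,
    mask_id: int = MASK_ID,
) -> List[int]:
    """Return a copy of `tokens` with the k lowest-quality unmasked tokens remasked.

    `quality` is assumed to hold one score per position (higher = better).
    Already-masked positions are never selected.
    """
    if len(tokens) != len(quality):
        raise ValueError("tokens and quality must have the same length")

    # Only positions that currently hold a real token are candidates.
    candidates = [i for i, t in enumerate(tokens) if t != mask_id]
    # Sort ascending by score; the sort is stable, so ties keep position order.
    candidates.sort(key=lambda i: quality[i])
    to_remask = set(candidates[: max(k, 0)])

    return [mask_id if i in to_remask else t for i, t in enumerate(tokens)]


tokens = [5, 9, MASK_ID, 3, 7]
quality = [0.9, 0.2, 0.0, 0.8, 0.1]
result = remask_lowest_quality(tokens, quality, k=2)
print(result)  # [5, -1, -1, 3, -1]
```

In a full sampler the remasked positions would then be refilled by the MDM on a later denoising step; that loop is omitted here because it depends on model details the abstract does not give.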
Forward citations
Cited by 4 Pith papers
- Inference-Time Scaling in Diffusion Models through Iterative Partial Refinement
  IPR improves valid solution rates on MNIST Sudoku from 55.8% to 75.0% by iteratively refining partial regions in sequential diffusion models without external verifiers or reward models.
- Remask, Don't Replace: Token-to-Mask Refinement in Diffusion Large Language Models
  Token-to-Mask remasking improves self-correction in diffusion LLMs by resetting erroneous commitments to masks rather than overwriting them, yielding +13.33 points on AIME 2025 and +8.56 on CMATH.
- Locally Coherent Parallel Decoding in Diffusion Language Models
  CoDiLA adds a compact auxiliary AR model on diffusion latents to enforce local sequential validity during parallel token sampling in discrete diffusion language models.
- Discrete Stochastic Localization for Non-autoregressive Generation
  Discrete Stochastic Localization lets a single trained network support an entire family of per-token SNR paths for discrete sequence generation, with masked diffusion as a special case, and improves MAUVE scores when ...