pith. sign in

arxiv: 2605.26713 · v1 · pith:GDQ2JXMVnew · submitted 2026-05-26 · 📊 stat.ML · cs.LG

Transformers Can Learn Posterior Predictive Distributions In-Context

Pith reviewed 2026-06-29 16:02 UTC · model grok-4.3

classification 📊 stat.ML cs.LG
keywords transformersin-context learningposterior predictive distributionGaussian processesBayesian inferenceattention depthnormalizationprior-data fitted networks
0
0 comments X

The pith

Transformers can implement gradient descent to approximate posterior predictive distributions of Gaussian processes.

A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.

The paper establishes that transformers can learn to produce the full posterior predictive distribution for Gaussian process regression solely from in-context examples. It does this by constructing an explicit mechanism in which attention layers carry out gradient descent updates on the posterior mean and variance, after which nonlinear mappings convert the results into binned probability values. This construction supplies a theoretical account of how prior-data fitted networks achieve Bayesian-style predictions without hand-coded probabilistic machinery. The work further shows that error in the approximation is controlled by attention depth and bin resolution, and that normalization plus sufficient depth are required for the model to extrapolate beyond the range of pretraining examples.

Core claim

By explicit construction, a transformer can realize a gradient descent algorithm that targets the posterior predictive mean and variance for Gaussian process regression. Nonlinear mappings then turn these quantities into binned probabilities that approximate the full posterior predictive distribution. The analysis provides error bounds in terms of the number of attention layers and the fineness of the probability bins, and identifies normalization as essential for extrapolation outside the pretraining range.

What carries the argument

Gradient descent steps on the Gaussian process posterior mean and variance implemented by transformer attention, combined with nonlinear mappings to produce binned PPD probabilities.

If this is right

  • The approximation error of the PPD decreases with greater attention depth and finer bin resolution.
  • Normalization is required for the transformer to extrapolate its PPD estimates beyond the sample sizes seen in pretraining.
  • Attention depth directly governs how accurately the model can perform in-context Bayesian prediction.
  • Transformers of this form can supply full distributional outputs rather than point predictions alone.

Where Pith is reading between the lines

These are editorial extensions of the paper, not claims the author makes directly.

  • Analogous constructions could embed gradient updates for other Bayesian models inside attention layers.
  • Attention depth may serve as a general control on accuracy across a broader set of in-context learning problems.
  • Bin resolution can be chosen in practice to match the precision needed for downstream probability estimates.
  • The approach may encounter limits when data-generating processes depart from Gaussian process assumptions.

Load-bearing premise

The transformer architecture with normalization is sufficient to realize the exact gradient descent steps on the Gaussian-process posterior mean and variance without additional fitting.

What would settle it

A trained transformer on Gaussian process regression tasks produces binned probabilities that deviate from the true posterior predictive distribution even when attention depth is increased, or extrapolation fails once normalization is removed.

Figures

Figures reproduced from arXiv: 2605.26713 by Changwoo J. Lee, Gyeonghun Kang, Xiang Cheng.

Figure 1
Figure 1. Figure 1: Comparison of PPDs produced by a transformer-based PFN and an empirical-Bayes Gaussian process (GP) on the Sacra￾mento home price dataset (Kuhn, 2008). The axes represent spatial coordinates (longitude and latitude), and the color scale represents price per square foot (V). The top row shows PFN posterior predic￾tive summaries, and the bottom row shows the corresponding GP outputs. Columns correspond to th… view at source ↗
Figure 2
Figure 2. Figure 2: Verifying Theorem 4.1. log E[TV(p(a,b] , qϑ)] on the evaluation set as a function of depth L ∈ {2, 4, 8, 16, 32} and bin count C ∈ {16, 32, 64, 128, 256}, for Bayesian linear regression (BLR) (d = 5, n ∈ [128, 512]) and GP regression with RBF kernel (d = 2, n ∈ [64, 128]). Left: heatmaps comparing theory and learnable parameterizations. Right: slices at fixed L = 32 (varying C) and fixed C = 256 (varying L… view at source ↗
Figure 3
Figure 3. Figure 3: Richardson step size varies with n without normal￾ization. The upper bound of admissible step size of Richard￾son iteration solving (11), computed as (12), plotted against n ∈ {100 × i : i ∈ [10]} for d = 16 and σ 2 = 0.2. Re￾sults are averaged over 100 independent trials. where the expectation over n is taken with respect to a dis￾tribution on {nmin, . . . , nmax} used in pretraining, such as discrete uni… view at source ↗
Figure 4
Figure 4. Figure 4: Normalized attention enables sample size general￾ization (RBF, d = 8). log E[TV(p(a,b] , qϑ)] versus evaluation sample size n ′ ∈ {40, 50, . . . , 200} for learnable unnormal￾ized TFϑ,L (left) and normalized TFpr ϑ,L (right), pretrained on n ∈ [64, 128]. Normalization yields stable performance across n ′ . the spectrum of the iteration matrix uniformly over n (Cu￾tajar et al., 2016). Specifically, left-mul… view at source ↗
Figure 5
Figure 5. Figure 5: Depth and pretraining range improve generalization quality (RBF, d = 16). Prediction MSE, 90% interval coverage, and 90% interval width versus evaluation sample size for learnable normalized models with C = 256, depths L ∈ {4, 8, 16, 32}, and pretraining ranges n ∈ [64, nmax] with nmax ∈ {128, 256, 512}. Red dashed curves denote the corresponding true PPD interval width/nominal coverage. 6. Numerical Studi… view at source ↗
Figure 6
Figure 6. Figure 6: (d = 4) Prediction MSE, 90% interval coverage, and 90% interval width versus evaluation sample size for learnable normalized models with C = 256, depths L ∈ {4, 8, 16, 32}, and pretraining ranges n ∈ [64, nmax] with nmax ∈ {128, 256, 512}. Red dashed curves denote the corresponding true PPD interval width/nominal coverage. mean and the 5% and 95% predictive quantiles. As a baseline, we fit an empirical-Bay… view at source ↗
Figure 7
Figure 7. Figure 7: (d = 4) log E[TV(p(a,b] , qϑ)] and moment MSEs E(µ − mϑ,1) 2 and E(τ + µ 2 − mϑ,2) 2 versus evaluation sample size n ′ for learnable normalized models with C = 256, depths L ∈ {4, 8, 16, 32}, and pretraining ranges n ∈ [64, nmax] with nmax ∈ {128, 256, 512}. 2.0 1.5 1.0 L o g M S E y nmax = 128 L=4 L=8 L=16 L=32 2.0 1.5 1.0 nmax = 256 2.0 1.5 1.0 nmax = 512 0.80 0.85 0.90 90% Coverage True PPD 0.80 0.85 0.… view at source ↗
Figure 8
Figure 8. Figure 8: (d = 8) Prediction MSE, 90% interval coverage, and 90% interval width versus evaluation sample size for learnable normalized models with C = 256, depths L ∈ {4, 8, 16, 32}, and pretraining ranges n ∈ [64, nmax] with nmax ∈ {128, 256, 512}. Red dashed curves denote the corresponding true PPD interval width/nominal coverage. 28 [PITH_FULL_IMAGE:figures/full_fig_p028_8.png] view at source ↗
Figure 9
Figure 9. Figure 9: (d = 8) log E[TV(p(a,b] , qϑ)] and moment MSEs E(µ − mϑ,1) 2 and E(τ + µ 2 − mϑ,2) 2 versus evaluation sample size n ′ for learnable normalized models with C = 256, depths L ∈ {4, 8, 16, 32}, and pretraining ranges n ∈ [64, nmax] with nmax ∈ {128, 256, 512}. 3.0 2.5 2.0 1.5 L o g T V nmax = 128 L=4 L=8 L=16 L=32 3.0 2.5 2.0 1.5 nmax = 256 3.0 2.5 2.0 1.5 nmax = 512 5 4 3 2 L o g M S E 5 4 3 2 5 4 3 2 200 4… view at source ↗
Figure 10
Figure 10. Figure 10: (d = 16) log E[TV(p(a,b] , qϑ)] and moment MSEs E(µ − mϑ,1) 2 and E(τ + µ 2 − mϑ,2) 2 versus evaluation sample size n ′ for learnable normalized models with C = 256, depths L ∈ {4, 8, 16, 32}, and pretraining ranges n ∈ [64, nmax] with nmax ∈ {128, 256, 512}. Errors decrease with larger L and nmax and grow with n ′ , mirroring the trends in prediction and interval metrics in [PITH_FULL_IMAGE:figures/full… view at source ↗
Figure 11
Figure 11. Figure 11: (d = 4, covariate shift) Prediction MSE, 90% interval coverage, and 90% interval width versus evaluation sample size for learnable normalized models with C = 256, depths L ∈ {4, 8, 16, 32}, and pretraining ranges n ∈ [64, nmax] with nmax ∈ {128, 256, 512}. Red dashed curves denote the corresponding true PPD interval width/nominal coverage. 3.00 2.75 2.50 2.25 2.00 1.75 L o g T V nmax = 128 L=4 L=8 L=16 L=… view at source ↗
Figure 12
Figure 12. Figure 12: (d = 4, covariate shift) log E[TV(p(a,b] , qϑ)] and moment MSEs E(µ − mϑ,1) 2 and E(τ + µ 2 − mϑ,2) 2 versus evaluation sample size n ′ for learnable normalized models with C = 256, depths L ∈ {4, 8, 16, 32}, and pretraining ranges n ∈ [64, nmax] with nmax ∈ {128, 256, 512}. 30 [PITH_FULL_IMAGE:figures/full_fig_p030_12.png] view at source ↗
Figure 13
Figure 13. Figure 13: (d = 8, covariate shift) Prediction MSE, 90% interval coverage, and 90% interval width versus evaluation sample size for learnable normalized models with C = 256, depths L ∈ {4, 8, 16, 32}, and pretraining ranges n ∈ [64, nmax] with nmax ∈ {128, 256, 512}. Red dashed curves denote the corresponding true PPD interval width/nominal coverage. 2.5 2.0 1.5 1.0 L o g T V nmax = 128 L=4 L=8 L=16 L=32 2.5 2.0 1.5… view at source ↗
Figure 14
Figure 14. Figure 14: (d = 8, covariate shift) log E[TV(p(a,b] , qϑ)] and moment MSEs E(µ − mϑ,1) 2 and E(τ + µ 2 − mϑ,2) 2 versus evaluation sample size n ′ for learnable normalized models with C = 256, depths L ∈ {4, 8, 16, 32}, and pretraining ranges n ∈ [64, nmax] with nmax ∈ {128, 256, 512}. 31 [PITH_FULL_IMAGE:figures/full_fig_p031_14.png] view at source ↗
Figure 15
Figure 15. Figure 15: Illustration of similarities between PPDs produced by a transformer (PFN) and a Gaussian process (GP) on the Walker Lake dataset (Isaaks & Srivastava, 1989). The x, y axes represent standardized spatial coordinates, while the color scale shows estimated mineral (V) concentrations in ppm. The top row displays PPD outputs from a transformer, while the bottom row shows PPD based on a GP. Columns correspond t… view at source ↗
read the original abstract

Prior-data fitted networks (PFNs) have recently emerged as a powerful approach for Bayesian prediction tasks, approximating the posterior predictive distribution (PPD) through in-context learning. Despite their strong empirical performance and ability to go beyond point predictions, theoretical understandings of the algorithmic capability of transformers to learn distributions in context are still lacking. Focusing on Gaussian process regression problems, we show by construction that transformers can implement a gradient descent algorithm targeting the posterior predictive mean and variance, followed by nonlinear mappings that yield binned probabilities of PPD. We study the error bounds of the approximated PPD in terms of attention depth and bin resolution. Based on these results, we further demonstrate the key role of normalization and the choice of attention depth in enabling the extrapolation abilities of transformers beyond the pretraining sample size range. We conduct simulations that corroborate our findings, providing insight into the expressivity of PFNs targeting PPDs and how architectural choices may influence generalization capabilities.

Editorial analysis

A structured set of objections, weighed in public.

Desk editor's note, referee report, simulated authors' rebuttal, and a circularity audit. Tearing a paper down is the easy half of reading it; the pith above is the substance, this is the friction.

Referee Report

0 major / 3 minor

Summary. The paper claims to show by construction that transformers can implement a finite-step gradient descent algorithm targeting the posterior predictive mean and variance for Gaussian process regression, followed by fixed nonlinear mappings to produce binned probabilities of the PPD. Error bounds are stated in terms of attention depth and bin resolution. The work further identifies the key role of normalization (and attention depth) in enabling extrapolation beyond the pretraining sample-size range, with simulations reported to corroborate both the construction and the extrapolation behavior.

Significance. If the explicit construction holds, the result supplies a concrete theoretical account of how transformers can realize Bayesian posterior predictive inference in-context, directly explaining the empirical success of prior-data fitted networks on distribution-valued tasks. The error analysis, the isolation of normalization as the mechanism for out-of-range generalization, and the provision of simulation evidence constitute clear strengths for a manuscript in this area.

minor comments (3)
  1. [Construction section] The abstract and introduction refer to 'the specific transformer architecture (including normalization)' but do not list the precise layer-norm placement or scaling constants used in the construction; this should be stated explicitly in the section presenting the weight construction.
  2. [Simulations] In the simulation section, the reported figures compare in-context performance inside versus outside the pretraining length range, but the exact bin resolution and number of GD steps used for each curve are not tabulated; adding a small table would improve reproducibility.
  3. [Error analysis] The error-bound statements are given in terms of depth and bin resolution, yet the dependence on the GP kernel hyperparameters is left implicit; a short remark clarifying whether the bounds are uniform over a compact set of kernels would be helpful.

Simulated Author's Rebuttal

0 responses · 0 unresolved

We thank the referee for the positive assessment of our work and the recommendation of minor revision. The provided summary accurately captures the paper's construction of transformers implementing finite-step gradient descent on posterior predictive mean and variance for Gaussian process regression, followed by binning, along with the error analysis and the role of normalization in extrapolation.

Circularity Check

0 steps flagged

No significant circularity: constructive existence result with independent verification steps

full rationale

The paper's central claim is an explicit construction showing that a specific transformer architecture (with normalization) can realize finite-step gradient descent on the GP posterior mean and variance, followed by fixed nonlinear maps to binned PPD probabilities. Error bounds are derived in terms of depth and bin resolution; simulations corroborate both the construction and the role of normalization. No load-bearing step reduces by definition to a fitted parameter, no self-citation chain is invoked to justify the core mapping, and the derivation does not rename or smuggle in prior results from the same authors. The result is therefore self-contained against external benchmarks and receives the default non-circularity finding.

Axiom & Free-Parameter Ledger

0 free parameters · 0 axioms · 0 invented entities

Abstract only; no explicit free parameters, axioms, or invented entities are stated.

pith-pipeline@v0.9.1-grok · 5687 in / 1062 out tokens · 54492 ms · 2026-06-29T16:02:04.681071+00:00 · methodology

discussion (0)

Sign in with ORCID, Apple, or X to comment. Anyone can read and Pith papers without signing in.

Reference graph

Works this paper leans on

13 extracted references · 3 canonical work pages · 2 internal anchors

  1. [1]

    Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J

    ISBN 9780199535255. Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-V oss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D., Wu, J., Winter, C., Hesse, C., Chen, M., Sigler, E., Litwin, M., Gray, S., Chess, B., Clark, J., Berner, C., McCandl...

  2. [2]

    A survey on in-context learning

    Dong, Q., Li, L., Dai, D., Zheng, C., Ma, J., Li, R., Xia, H., Xu, J., Wu, Z., Chang, B., Sun, X., Li, L., and Sui, Z. A survey on in-context learning. In Al-Onaizan, Y ., Bansal, M., and Chen, Y .-N. (eds.),Proceedings of the 2024 Conference on Empirical Methods in Natural Lan- guage Processing, pp. 1107–1128, Miami, Florida, USA, November

  3. [3]

    TabPFN-2.5: Advancing the State of the Art in Tabular Foundation Models

    10 Transformers Can Learn Posterior Predictive Distributions In-Context Grinsztajn, L., Fl¨oge, K., Key, O., Birkel, F., Jund, P., Roof, B., J¨ager, B., Safaric, D., Alessi, S., Hayler, A., Manium, M., Yu, R., Jablonski, F., Hoo, S. B., Garg, A., Robertson, J., B ¨uhler, M., Moroshan, V ., Purucker, L., Cornu, C., Wehrhahn, L. C., Bonetto, A., Sch ¨olkopf...

  4. [4]

    TabPFN-3: Technical Report

    Grinsztajn, L., Fl¨oge, K., Key, O., Birkel, F., Jund, P., Roof, B., Manium, M., Bin, S., Hoo, B ¨uhler, M., Garg, A., Safaric, D., Robertson, J., J¨ager, B., Alessi, S., Hayler, A., Moroshan, V ., Purucker, L., Singer, P., Arazi, A., Siems, J., Metzen, J. H., Grab, G., Erickson, N., Guo, S., Kalfon, E., Bing, S., Salinas, D., Cornu, C., Wehrhahn, L. C., ...

  5. [5]

    P., and Sesh Kumar, K

    Moriconi, R., Deisenroth, M. P., and Sesh Kumar, K. High-dimensional Bayesian optimization using low- dimensional feature spaces.Machine Learning, 109(9): 1925–1943,

  6. [6]

    Tabpfn: One model to rule them all?arXiv preprint arXiv:2505.20003, 2025

    Zhang, Q., Tan, Y . S., Tian, Q., and Li, P. TabPFN: One model to rule them all?arXiv preprint arXiv:2505.20003,

  7. [7]

    Hence, the above quantities can be expressed in terms ofλ j(D−1/2AD−1/2),j∈ {1, n}as well

    Note that D−1A and D−1/2AD−1/2 are similar matrices, so they have the same characteristic polynomials, hence the same eigenvalues. Hence, the above quantities can be expressed in terms ofλ j(D−1/2AD−1/2),j∈ {1, n}as well. B. Auxiliary Lemmas and Theorems B.1. Approximation based on Discretization Lemma B.1.Let p(x) be a LLip-Lipschitz continuous density s...

  8. [8]

    -st (hence smallest) sample eigenvalues is upper bounded through trace bounds and Markov inequality, following the technique in Burt et al. (2019). Let κ(x, x′) = exp −0.5∥x−x ′∥2 be the RBF kernel. By Mercer’s theorem, the spectral decomposition is κ(x, x′) =P j≥1 µjϕj(x)ϕj(x′)where{ϕ j}are orthonormal basis and{µ j}are eigenvalues. The largest eigenvalu...

  9. [9]

    Fix a compact set K ⊂R×(0,∞) on which the mapping ψ:K →R 2 is defined as ψ: (µ, τ)7→ µ τ ,− 1 2τ

    the target PPD belongs to a one-dimensional exponential familyf θ(y), 3.f θ(y)is Lipschitz continuous on(a, b]and the grid is equidistant. Fix a compact set K ⊂R×(0,∞) on which the mapping ψ:K →R 2 is defined as ψ: (µ, τ)7→ µ τ ,− 1 2τ . Such K is plausible given that y∈(a, b] and τ(Z (0))∈[σ 2, σ2 +κ(x, x)] for any Z(0). By the universal approximation th...

  10. [10]

    Task FigureLR base N BLR (theory), trainTF ϑ,L Fig

    All the metrics in the figures are evaluated on4096evaluation samples. Task FigureLR base N BLR (theory), trainTF ϑ,L Fig. 22×10 −4 104 BLR (learnable), trainTF ϑ,L Fig. 22×10 −4 2×10 4 RBF (theory), trainTF ϑ,L Fig. 21×10 −3 5×10 4 RBF (learnable), trainTF ϑ,L Fig. 22×10 −4 105 RBF (learnable), trainTF ϑ,L Fig. 42×10 −4 105 RBF (learnable), trainTF pr ϑ,...

  11. [11]

    Note that this case study is intended as an illustrative validation of the proposed mechanism rather than as a comprehensive benchmark

    and the Walker Lake dataset (Isaaks & Srivastava, 1989). Note that this case study is intended as an illustrative validation of the proposed mechanism rather than as a comprehensive benchmark. Sacramento.The dataset is from the R package caret version 7.0.1 (Kuhn, 2008), from which we selected data corresponding to the cities of Sacramento and Elk Grove. ...

  12. [12]

    Following Rasmussen & Williams (2005, Chapter 5), we first fit an empirical-Bayes anisotropic RBF GP to the Sacramento data

    This preprocessing is used consistently for empirical-Bayes fitting, PFN pretraining, and evaluation. Following Rasmussen & Williams (2005, Chapter 5), we first fit an empirical-Bayes anisotropic RBF GP to the Sacramento data. Specifically, we use the ARD kernel κ(x, x′) =α 2 exp −1 2 2X k=1 (xk −x ′ k)2 ℓ2 k ! , together with Gaussian observation noise. ...

  13. [13]

    Consistent with Theorem 5.3, for fixed L and nmax, the error of the finite-iteration solver increases as the evaluation sample size grows. 29 Transformers Can Learn Posterior Predictive Distributions In-Context 3.1 3.0 2.9 2.8 2.7 2.6 Log MSE y nmax = 128 L=4 L=8 L=16 L=32 3.1 3.0 2.9 2.8 2.7 2.6 nmax = 256 3.1 3.0 2.9 2.8 2.7 2.6 nmax = 512 0.80 0.85 0.9...