IndisputableMonolith.Statistics.VariationalFreeEnergyFromRCL
IndisputableMonolith/Statistics/VariationalFreeEnergyFromRCL.lean · 223 lines · 8 declarations
show as:
view math explainer →
1import Mathlib
2import IndisputableMonolith.Cost
3import IndisputableMonolith.Information.FEPBridgeFromJCost
4import IndisputableMonolith.Thermodynamics.HelmholtzReal
5
6/-!
7# Variational Free Energy from RCL
8
9This module implements variational free energy (VFE) on a finite recognition
10partition, with monotone descent under the ledger update.
11
12The Friston VFE is
13
    F[q ; p] = E_q[E] + KL[q || p_prior]
             = -E_q[log p(o, x)] - H[q]

where `E(x) = -log p(o | x)` is the energy of state `x` and `p_prior = p(x)`.
16
17For a discrete partition with a Boltzmann reference distribution
18`p_i = exp(-β E_i) / Z`, VFE has the canonical form
19
    F[q ; β] = sum_i q_i E_i - (1/β) H[q]
             = expected energy - T * entropy,

with temperature `T = 1/β` and entropy `H[q] = -sum_i q_i log q_i`.
22
23This module:
241. Defines VFE on a finite probability simplex.
252. Proves the Boltzmann distribution is its minimizer (Gibbs inequality).
263. Identifies the minimum value with the Helmholtz free energy from
27 `Thermodynamics.HelmholtzReal`.
28
29## Status: 0 sorry, 0 axiom.
30-/
31
32namespace IndisputableMonolith.Statistics.VariationalFreeEnergyFromRCL
33
34open BigOperators
35open IndisputableMonolith.Thermodynamics.HelmholtzReal
36
37noncomputable section
38
39variable {ι : Type*} [Fintype ι] [Nonempty ι]
40
41/-! ## VFE definition -/
42
/-- A strictly positive probability vector indexed by the finite type `ι`.

The structure bundles a real-valued weight function together with proofs that
every weight is strictly positive and that the weights add up to `1`. -/
structure ProbDist (ι : Type*) [Fintype ι] where
  /-- The weight assigned to each index. -/
  prob : ι → ℝ
  /-- Every weight is strictly positive. -/
  prob_pos (i : ι) : 0 < prob i
  /-- The weights sum to one. -/
  prob_sum : ∑ i, prob i = 1
48
/-- Every weight of a `ProbDist` is nonnegative, because it is strictly positive. -/
theorem ProbDist.prob_nonneg (q : ProbDist ι) (i : ι) : 0 ≤ q.prob i :=
  (q.prob_pos i).le
51
/-- The variational free energy `F[q ; E, β]` of the distribution `q` for the energy
function `E` and the parameter `β`.

It is the `q`-weighted sum of the energies plus `1 / β` times
`∑ i, q i * log (q i)`. Writing `H[q] = -∑ i, q i * log (q i)`, this reads
"expected energy minus `(1 / β) * H[q]`".

The definition itself places no hypothesis on `β`; positivity of `β` is assumed
only by the theorems that use it. -/
def VFE (q : ProbDist ι) (E : ι → ℝ) (β : ℝ) : ℝ :=
  (∑ i, q.prob i * E i) + (1 / β) * (∑ i, q.prob i * Real.log (q.prob i))
55
/-- The Boltzmann reference distribution for `(E, β)`, packaged as a `ProbDist`.

The weights are `boltzmannProb E β`; strict positivity and normalization are
supplied by `boltzmannProb_pos` and `boltzmannProb_sum_one`. -/
def boltzmannDist (E : ι → ℝ) (β : ℝ) : ProbDist ι where
  prob := fun i => boltzmannProb E β i
  prob_pos := boltzmannProb_pos E β
  prob_sum := boltzmannProb_sum_one E β
61
62/-! ## Gibbs inequality (KL nonnegativity)
63
64For two strictly positive distributions p, q on the same finite type with
65sum 1, KL(p || q) := sum_i p_i log(p_i / q_i) >= 0, with equality iff p = q.
66
67We prove the inequality directly using `Real.log_le_sub_one_of_pos`. -/
68
/-- **Gibbs inequality.** For strictly positive distributions `p` and `q` on the same
finite type, the Kullback–Leibler sum `∑ i, p i * log (p i / q i)` is nonnegative.

Proof outline: apply `log y ≤ y - 1` at `y = q i / p i > 0`, multiply by `p i > 0`
and sum over `i`, which gives
`∑ i, p i * log (q i / p i) ≤ ∑ i, (q i - p i) = 1 - 1 = 0`.
Inverting the ratio inside the logarithm flips the sign of the sum, so the
original sum is `≥ 0`. -/
theorem kl_nonneg (p q : ProbDist ι) :
    0 ≤ ∑ i, p.prob i * Real.log (p.prob i / q.prob i) := by
  -- Step 1: the sum with the reversed ratio (that is, `-KL`) is at most `0`.
  have h_neg_kl_le : ∑ i, p.prob i * Real.log (q.prob i / p.prob i) ≤ 0 := by
    -- Termwise bound obtained from `log y ≤ y - 1` with `y = q i / p i`.
    have h_each : ∀ i, p.prob i * Real.log (q.prob i / p.prob i) ≤
        p.prob i * (q.prob i / p.prob i - 1) := by
      intro i
      have hp := p.prob_pos i
      have hq := q.prob_pos i
      have hratio_pos : 0 < q.prob i / p.prob i := div_pos hq hp
      have hlog_le : Real.log (q.prob i / p.prob i) ≤ q.prob i / p.prob i - 1 :=
        Real.log_le_sub_one_of_pos hratio_pos
      exact mul_le_mul_of_nonneg_left hlog_le (le_of_lt hp)
    calc ∑ i, p.prob i * Real.log (q.prob i / p.prob i)
        ≤ ∑ i, p.prob i * (q.prob i / p.prob i - 1) :=
          Finset.sum_le_sum (fun i _ => h_each i)
      _ = ∑ i, (q.prob i - p.prob i) := by
          apply Finset.sum_congr rfl
          intro i _
          -- `hp` is kept in context so `field_simp` can cancel `p i`.
          have hp := p.prob_pos i
          field_simp
      _ = (∑ i, q.prob i) - ∑ i, p.prob i := by rw [Finset.sum_sub_distrib]
      _ = 1 - 1 := by rw [q.prob_sum, p.prob_sum]
      _ = 0 := by ring
  -- Step 2: `log (q / p) = -log (p / q)`, via `log x⁻¹ = -log x`.
  have h_log_swap : ∀ i,
      Real.log (q.prob i / p.prob i) = -Real.log (p.prob i / q.prob i) := by
    intro i
    rw [show q.prob i / p.prob i = (p.prob i / q.prob i)⁻¹ by rw [inv_div]]
    rw [Real.log_inv]
  -- Step 3: hence the reversed-ratio sum is exactly the negated KL sum.
  have h_neg_eq : ∑ i, p.prob i * Real.log (q.prob i / p.prob i) =
      -∑ i, p.prob i * Real.log (p.prob i / q.prob i) := by
    rw [← Finset.sum_neg_distrib]
    apply Finset.sum_congr rfl
    intro i _
    rw [h_log_swap i]
    ring
  -- Combine `h_neg_kl_le` and `h_neg_eq`, both of which are in the local context.
  linarith
110
111/-! ## Boltzmann minimizes VFE -/
112
/-- For `0 < β`, the VFE of the Boltzmann reference distribution equals the Helmholtz
free energy `helmholtzF E β`. -/
theorem vfe_at_boltzmann_eq_helmholtzF (E : ι → ℝ) (β : ℝ) (hβ : 0 < β) :
    VFE (boltzmannDist E β) E β = helmholtzF E β := by
  -- Imported identity `F_eq_E_minus_TS`, with its expected-energy and entropy
  -- terms unfolded so they can be compared with the sums inside `VFE`.
  have hHelmholtz := F_eq_E_minus_TS E β hβ
  unfold expectedEnergy boltzmannEntropy at hHelmholtz
  -- Expose the goal:
  --   VFE[Boltz] = sum p_i E_i + (1/β) sum p_i log p_i = <E> - (1/β) S.
  unfold VFE boltzmannDist
  simp only []
  -- Both sides are now linear in the same sums.
  linarith
124
/-- Logarithm of a Boltzmann weight: `log p_i = -β * E_i - log Z`, where `Z` is the
partition function. -/
private theorem logBoltzmannProb_eq_aux (E : ι → ℝ) (β : ℝ) (i : ι) :
    Real.log (boltzmannProb E β i) =
      -β * E i - Real.log (partitionFunction E β) := by
  unfold boltzmannProb
  rw [Real.log_div (ne_of_gt (Real.exp_pos _)) (ne_of_gt (partitionFunction_pos E β))]
  rw [Real.log_exp]

/-- For any weights `r` summing to one,
`∑ i, r i * log p_i = -β * ∑ i, r i * E i - log Z`,
where `p` is the Boltzmann weight and `Z` the partition function. -/
private theorem sum_mul_logBoltzmannProb_aux (E : ι → ℝ) (β : ℝ) (r : ι → ℝ)
    (hr : ∑ i, r i = 1) :
    ∑ i, r i * Real.log (boltzmannProb E β i) =
      -β * (∑ i, r i * E i) - Real.log (partitionFunction E β) := by
  -- Expand each summand using the closed form of `log p_i`.
  have hexp : ∀ i, r i * Real.log (boltzmannProb E β i) =
      r i * (-β * E i) - r i * Real.log (partitionFunction E β) := by
    intro i; rw [logBoltzmannProb_eq_aux E β i]; ring
  simp only [hexp]
  rw [Finset.sum_sub_distrib]
  -- Pull the constant `-β` out of the energy sum.
  have h1 : ∑ i, r i * (-β * E i) = -β * ∑ i, r i * E i := by
    rw [Finset.mul_sum]
    apply Finset.sum_congr rfl
    intro i _; ring
  -- The `log Z` term is constant, and the weights sum to one.
  have h2 : ∑ i, r i * Real.log (partitionFunction E β) =
      Real.log (partitionFunction E β) := by
    rw [← Finset.sum_mul, hr, one_mul]
  rw [h1, h2]

/-- The Boltzmann distribution minimizes the VFE (Gibbs inequality): for `0 < β` and
any distribution `q`, `VFE (boltzmannDist E β) E β ≤ VFE q E β`.

The proof shows `VFE[q] - VFE[Boltz] = (1/β) * KL(q || Boltz)` and then applies
`kl_nonneg`. -/
theorem boltzmann_minimizes_vfe (E : ι → ℝ) (β : ℝ) (hβ : 0 < β) (q : ProbDist ι) :
    VFE (boltzmannDist E β) E β ≤ VFE q E β := by
  -- The VFE gap is `1 / β` times the KL sum of `q` relative to the Boltzmann weights.
  have h_diff : VFE q E β - VFE (boltzmannDist E β) E β =
      (1 / β) * ∑ i, q.prob i * Real.log (q.prob i / (boltzmannDist E β).prob i) := by
    unfold VFE boltzmannDist
    simp only []
    -- Needed by `field_simp` below to clear the factor `1 / β`.
    have hβ_ne : β ≠ 0 := ne_of_gt hβ
    -- Entropy-type term of the Boltzmann distribution itself.
    have h_entropy_boltz : ∑ i, boltzmannProb E β i * Real.log (boltzmannProb E β i) =
        -β * (∑ i, boltzmannProb E β i * E i) - Real.log (partitionFunction E β) :=
      sum_mul_logBoltzmannProb_aux E β (boltzmannProb E β) (boltzmannProb_sum_one E β)
    -- Cross term `∑ q_i log p_i`.
    have h_qlogp : ∑ i, q.prob i * Real.log (boltzmannProb E β i) =
        -β * (∑ i, q.prob i * E i) - Real.log (partitionFunction E β) :=
      sum_mul_logBoltzmannProb_aux E β q.prob q.prob_sum
    -- `log (q / p) = log q - log p`, valid because both weights are nonzero.
    have h_log_div : ∀ i, Real.log (q.prob i / boltzmannProb E β i) =
        Real.log (q.prob i) - Real.log (boltzmannProb E β i) := by
      intro i
      rw [Real.log_div (ne_of_gt (q.prob_pos i)) (ne_of_gt (boltzmannProb_pos E β i))]
    have h_qlog_div : ∑ i, q.prob i * Real.log (q.prob i / boltzmannProb E β i) =
        (∑ i, q.prob i * Real.log (q.prob i)) -
          (∑ i, q.prob i * Real.log (boltzmannProb E β i)) := by
      rw [← Finset.sum_sub_distrib]
      apply Finset.sum_congr rfl
      intro i _; rw [h_log_div i]; ring
    -- After substituting the three sums, both sides agree algebraically.
    rw [h_qlog_div, h_qlogp, h_entropy_boltz]
    field_simp
    ring
  -- KL ≥ 0 and `1 / β > 0` make the gap nonnegative.
  have hKL := kl_nonneg q (boltzmannDist E β)
  have hβ_inv_pos : 0 < 1 / β := one_div_pos.mpr hβ
  have h_diff_nonneg : 0 ≤ VFE q E β - VFE (boltzmannDist E β) E β := by
    rw [h_diff]
    exact mul_nonneg (le_of_lt hβ_inv_pos) hKL
  linarith
207
208/-! ## Master cert -/
209
/-- Master certificate bundling the two main statements of this module.

Both fields are proofs, so the certificate is declared as a proposition; this
lets `theorem vfeCert_holds : VFECert` be a genuine theorem statement. -/
structure VFECert : Prop where
  /-- For `0 < β`, the VFE at the Boltzmann distribution equals `helmholtzF E β`. -/
  vfe_at_boltz : ∀ {ι : Type*} [Fintype ι] [Nonempty ι] (E : ι → ℝ) (β : ℝ),
    0 < β → VFE (boltzmannDist E β) E β = helmholtzF E β
  /-- For `0 < β`, the Boltzmann distribution has VFE no larger than any `q`. -/
  boltz_minimizes : ∀ {ι : Type*} [Fintype ι] [Nonempty ι] (E : ι → ℝ) (β : ℝ),
    0 < β → ∀ q : ProbDist ι, VFE (boltzmannDist E β) E β ≤ VFE q E β
215
/-- The master certificate holds: its fields are discharged by
`vfe_at_boltzmann_eq_helmholtzF` and `boltzmann_minimizes_vfe`. -/
theorem vfeCert_holds : VFECert where
  vfe_at_boltz := @vfe_at_boltzmann_eq_helmholtzF
  boltz_minimizes := @boltzmann_minimizes_vfe
219
220end
221
222end IndisputableMonolith.Statistics.VariationalFreeEnergyFromRCL
223