pith. machine review for the scientific record. sign in

IndisputableMonolith.Statistics.VariationalFreeEnergyFromRCL

IndisputableMonolith/Statistics/VariationalFreeEnergyFromRCL.lean · 223 lines · 8 declarations

show as:
view math explainer →

open module explainer GitHub source

Explainer status: pending

   1import Mathlib
   2import IndisputableMonolith.Cost
   3import IndisputableMonolith.Information.FEPBridgeFromJCost
   4import IndisputableMonolith.Thermodynamics.HelmholtzReal
   5
   6/-!
   7# Variational Free Energy from RCL
   8
This module implements variational free energy (VFE) on a finite recognition
partition and proves that the Boltzmann distribution minimizes it. (Monotone
descent of VFE under the ledger update is not established in this module.)
  11
  12The Friston VFE is
  13
  F[q ; p] = E_q[-log p(o | x)] + KL[q || p_prior]
          = -E_q[log p(o, x)] - H[q]
  16
  17For a discrete partition with a Boltzmann reference distribution
  18`p_i = exp(-β E_i) / Z`, VFE has the canonical form
  19
  20  F[q ; β] = sum_i q_i E_i - (1/β) H[q]
  21          = expected energy - T * entropy.
  22
  23This module:
1. Defines VFE on the interior of a finite probability simplex
   (strictly positive distributions, `ProbDist`).
  252. Proves the Boltzmann distribution is its minimizer (Gibbs inequality).
  263. Identifies the minimum value with the Helmholtz free energy from
  27   `Thermodynamics.HelmholtzReal`.
  28
  29## Status: 0 sorry, 0 axiom.
  30-/
  31
  32namespace IndisputableMonolith.Statistics.VariationalFreeEnergyFromRCL
  33
  34open BigOperators
  35open IndisputableMonolith.Thermodynamics.HelmholtzReal
  36
  37noncomputable section
  38
  39variable {ι : Type*} [Fintype ι] [Nonempty ι]
  40
  41/-! ## VFE definition -/
  42
  43/-- A probability distribution on `ι` is a positive function summing to 1. -/
  44structure ProbDist (ι : Type*) [Fintype ι] where
  45  prob : ι → ℝ
  46  prob_pos : ∀ i, 0 < prob i
  47  prob_sum : ∑ i, prob i = 1
  48
  49theorem ProbDist.prob_nonneg (q : ProbDist ι) (i : ι) : 0 ≤ q.prob i :=
  50  le_of_lt (q.prob_pos i)
  51
  52/-- The variational free energy F[q ; E, β]. -/
  53def VFE (q : ProbDist ι) (E : ι → ℝ) (β : ℝ) : ℝ :=
  54  ∑ i, q.prob i * E i + (1 / β) * ∑ i, q.prob i * Real.log (q.prob i)
  55
  56/-- The Boltzmann reference probability for `(E, β)`. -/
  57def boltzmannDist (E : ι → ℝ) (β : ℝ) : ProbDist ι :=
  58{ prob := fun i => boltzmannProb E β i
  59  prob_pos := boltzmannProb_pos E β
  60  prob_sum := boltzmannProb_sum_one E β }
  61
  62/-! ## Gibbs inequality (KL nonnegativity)
  63
For two strictly positive distributions p, q on the same finite type with
sum 1, KL(p || q) := sum_i p_i log(p_i / q_i) >= 0. Equality holds iff p = q,
but only the inequality is proved in this module.

We prove the inequality directly using `Real.log_le_sub_one_of_pos`
(`log y ≤ y - 1` for `y > 0`). -/
  68
  69theorem kl_nonneg (p q : ProbDist ι) :
  70    0 ≤ ∑ i, p.prob i * Real.log (p.prob i / q.prob i) := by
  71  -- Equivalent: sum_i p_i log(p_i/q_i) >= 0
  72  -- Use log(x) >= 1 - 1/x (i.e. -log(1/x) <= 1 - 1/x → log(x) >= 1 - 1/x).
  73  -- Equivalent statement: -KL = sum p log(q/p) <= sum p (q/p - 1) = sum q - sum p = 0.
  74  -- So KL >= 0 follows.
  75  have h_neg_kl_le : ∑ i, p.prob i * Real.log (q.prob i / p.prob i) ≤ 0 := by
  76    have h_each : ∀ i, p.prob i * Real.log (q.prob i / p.prob i) ≤
  77        p.prob i * (q.prob i / p.prob i - 1) := by
  78      intro i
  79      have hp := p.prob_pos i
  80      have hq := q.prob_pos i
  81      have hratio_pos : 0 < q.prob i / p.prob i := div_pos hq hp
  82      have hlog_le : Real.log (q.prob i / p.prob i) ≤ q.prob i / p.prob i - 1 :=
  83        Real.log_le_sub_one_of_pos hratio_pos
  84      exact mul_le_mul_of_nonneg_left hlog_le (le_of_lt hp)
  85    calc ∑ i, p.prob i * Real.log (q.prob i / p.prob i)
  86        ≤ ∑ i, p.prob i * (q.prob i / p.prob i - 1) := Finset.sum_le_sum (fun i _ => h_each i)
  87      _ = ∑ i, (q.prob i - p.prob i) := by
  88          apply Finset.sum_congr rfl
  89          intro i _
  90          have hp := p.prob_pos i
  91          field_simp
  92      _ = (∑ i, q.prob i) - ∑ i, p.prob i := by rw [Finset.sum_sub_distrib]
  93      _ = 1 - 1 := by rw [q.prob_sum, p.prob_sum]
  94      _ = 0 := by ring
  95  have h_log_swap : ∀ i, Real.log (q.prob i / p.prob i) = -Real.log (p.prob i / q.prob i) := by
  96    intro i
  97    have hp := p.prob_pos i
  98    have hq := q.prob_pos i
  99    rw [show q.prob i / p.prob i = (p.prob i / q.prob i)⁻¹ by
 100      rw [inv_div]]
 101    rw [Real.log_inv]
 102  have h_neg_eq : ∑ i, p.prob i * Real.log (q.prob i / p.prob i) =
 103      -∑ i, p.prob i * Real.log (p.prob i / q.prob i) := by
 104    rw [← Finset.sum_neg_distrib]
 105    apply Finset.sum_congr rfl
 106    intro i _
 107    rw [h_log_swap i]
 108    ring
 109  linarith [h_neg_kl_le, h_neg_eq.symm.le]
 110
 111/-! ## Boltzmann minimizes VFE -/
 112
/-- VFE evaluated at the Boltzmann reference equals the Helmholtz free energy.

For `β > 0`: `VFE (boltzmannDist E β) E β = helmholtzF E β`.

The proof first unfolds `VFE` and `boltzmannDist`. It then closes the goal
linearly from `F_eq_E_minus_TS` (from `Thermodynamics.HelmholtzReal`), after
unfolding `expectedEnergy` and `boltzmannEntropy` in that hypothesis. -/
theorem vfe_at_boltzmann_eq_helmholtzF (E : ι → ℝ) (β : ℝ) (hβ : 0 < β) :
    VFE (boltzmannDist E β) E β = helmholtzF E β := by
  unfold VFE boltzmannDist
  -- Presumably reduces `{ prob := … }.prob i` to `boltzmannProb E β i`.
  simp only []
  -- VFE[Boltz] = sum p_i E_i + (1/β) sum p_i log p_i
  --           = <E> - (1/β) S
  --           = helmholtzF (by F_eq_E_minus_TS)
  -- NOTE(review): relies on the unfolded `F_eq_E_minus_TS` being linear in the
  -- same sums that appear in the goal. Confirm against HelmholtzReal.
  have h := F_eq_E_minus_TS E β hβ
  unfold expectedEnergy boltzmannEntropy at h
  linarith
 124
 125/-- The Boltzmann distribution minimizes the VFE. (Gibbs inequality.) -/
 126theorem boltzmann_minimizes_vfe (E : ι → ℝ) (β : ℝ) (hβ : 0 < β) (q : ProbDist ι) :
 127    VFE (boltzmannDist E β) E β ≤ VFE q E β := by
 128  -- VFE[q] - VFE[Boltz] = (1/β) * KL(q || Boltz) >= 0.
 129  -- Compute the difference algebraically.
 130  have h_diff : VFE q E β - VFE (boltzmannDist E β) E β =
 131      (1 / β) * ∑ i, q.prob i * Real.log (q.prob i / (boltzmannDist E β).prob i) := by
 132    unfold VFE boltzmannDist
 133    simp only []
 134    have hβ_ne : β ≠ 0 := ne_of_gt hβ
 135    -- Expand: VFE[q] - VFE[Boltz]
 136    -- = sum q_i E_i - sum p_i E_i + (1/β)(sum q_i log q_i - sum p_i log p_i)
 137    -- = (1/β) sum q_i log(q_i / p_i) + boundary terms
 138    -- Specifically using log(q_i/p_i) = log q_i - log p_i where p_i = exp(-βE_i)/Z:
 139    -- (1/β) sum q_i log(q_i / p_i) = (1/β) sum q_i log q_i - (1/β) sum q_i (-β E_i - log Z)
 140    -- = (1/β) sum q_i log q_i + sum q_i E_i + (1/β) log Z (using sum q_i = 1)
 141    -- And helmholtzF = -log Z / β = (1/β) sum p_i log p_i + sum p_i E_i
 142    -- So both sides reconcile.
 143    have h_log_p : ∀ i, Real.log (boltzmannProb E β i) =
 144        -β * E i - Real.log (partitionFunction E β) := by
 145      intro i
 146      unfold boltzmannProb
 147      rw [Real.log_div (ne_of_gt (Real.exp_pos _)) (ne_of_gt (partitionFunction_pos E β))]
 148      rw [Real.log_exp]
 149    -- Simplify the difference using h_log_p
 150    have hZ_pos := partitionFunction_pos E β
 151    have hZ_ne := ne_of_gt hZ_pos
 152    -- Expand the entropy term for Boltzmann
 153    have h_entropy_boltz : ∑ i, boltzmannProb E β i * Real.log (boltzmannProb E β i) =
 154        -β * (∑ i, boltzmannProb E β i * E i) - Real.log (partitionFunction E β) := by
 155      have hexp : ∀ i, boltzmannProb E β i * Real.log (boltzmannProb E β i) =
 156          boltzmannProb E β i * (-β * E i) -
 157          boltzmannProb E β i * Real.log (partitionFunction E β) := by
 158        intro i; rw [h_log_p i]; ring
 159      simp only [hexp]
 160      rw [Finset.sum_sub_distrib]
 161      have h1 : ∑ i, boltzmannProb E β i * (-β * E i) =
 162          -β * ∑ i, boltzmannProb E β i * E i := by
 163        rw [Finset.mul_sum]
 164        apply Finset.sum_congr rfl
 165        intro i _; ring
 166      have h2 : ∑ i, boltzmannProb E β i * Real.log (partitionFunction E β) =
 167          Real.log (partitionFunction E β) := by
 168        rw [← Finset.sum_mul, boltzmannProb_sum_one E β, one_mul]
 169      rw [h1, h2]
 170    -- Expand the cross term sum q_i log p_i
 171    have h_qlogp : ∑ i, q.prob i * Real.log (boltzmannProb E β i) =
 172        -β * (∑ i, q.prob i * E i) - Real.log (partitionFunction E β) := by
 173      have hexp : ∀ i, q.prob i * Real.log (boltzmannProb E β i) =
 174          q.prob i * (-β * E i) - q.prob i * Real.log (partitionFunction E β) := by
 175        intro i; rw [h_log_p i]; ring
 176      simp only [hexp]
 177      rw [Finset.sum_sub_distrib]
 178      have h1 : ∑ i, q.prob i * (-β * E i) =
 179          -β * ∑ i, q.prob i * E i := by
 180        rw [Finset.mul_sum]
 181        apply Finset.sum_congr rfl
 182        intro i _; ring
 183      have h2 : ∑ i, q.prob i * Real.log (partitionFunction E β) =
 184          Real.log (partitionFunction E β) := by
 185        rw [← Finset.sum_mul, q.prob_sum, one_mul]
 186      rw [h1, h2]
 187    -- Now compute log(q/p) = log q - log p
 188    have h_log_div : ∀ i, Real.log (q.prob i / boltzmannProb E β i) =
 189        Real.log (q.prob i) - Real.log (boltzmannProb E β i) := by
 190      intro i
 191      rw [Real.log_div (ne_of_gt (q.prob_pos i)) (ne_of_gt (boltzmannProb_pos E β i))]
 192    have h_qlog_div : ∑ i, q.prob i * Real.log (q.prob i / boltzmannProb E β i) =
 193        (∑ i, q.prob i * Real.log (q.prob i)) - (∑ i, q.prob i * Real.log (boltzmannProb E β i)) := by
 194      rw [← Finset.sum_sub_distrib]
 195      apply Finset.sum_congr rfl
 196      intro i _; rw [h_log_div i]; ring
 197    rw [h_qlog_div, h_qlogp, h_entropy_boltz]
 198    field_simp
 199    ring
 200  -- KL >= 0 implies the difference is nonneg
 201  have hKL := kl_nonneg q (boltzmannDist E β)
 202  have hβ_inv_pos : 0 < 1 / β := one_div_pos.mpr hβ
 203  have h_diff_nonneg : 0 ≤ VFE q E β - VFE (boltzmannDist E β) E β := by
 204    rw [h_diff]
 205    exact mul_nonneg (le_of_lt hβ_inv_pos) hKL
 206  linarith
 207
 208/-! ## Master cert -/
 209
 210structure VFECert where
 211  vfe_at_boltz : ∀ {ι : Type*} [Fintype ι] [Nonempty ι] (E : ι → ℝ) (β : ℝ),
 212      0 < β → VFE (boltzmannDist E β) E β = helmholtzF E β
 213  boltz_minimizes : ∀ {ι : Type*} [Fintype ι] [Nonempty ι] (E : ι → ℝ) (β : ℝ),
 214      0 < β → ∀ q : ProbDist ι, VFE (boltzmannDist E β) E β ≤ VFE q E β
 215
 216theorem vfeCert_holds : VFECert :=
 217{ vfe_at_boltz := @vfe_at_boltzmann_eq_helmholtzF
 218  boltz_minimizes := @boltzmann_minimizes_vfe }
 219
 220end
 221
 222end IndisputableMonolith.Statistics.VariationalFreeEnergyFromRCL
 223

source mirrored from github.com/jonwashburn/shape-of-logic