pith. machine review for the scientific record. sign in

IndisputableMonolith.Thermodynamics.MaxEntFromCost

IndisputableMonolith/Thermodynamics/MaxEntFromCost.lean · 210 lines · 5 declarations

show as:
view math explainer →

open module explainer GitHub source

Explainer status: pending

   1import Mathlib
   2import IndisputableMonolith.Cost
   3import IndisputableMonolith.Thermodynamics.RecognitionThermodynamics
   4
   5/-!
   6# Maximum Entropy from Cost Minimization
   7
   8This module proves that the Gibbs distribution minimizes the recognition free energy, and hence
   9has maximal entropy among distributions with the same expected cost; the minimizer is unique.
  10-/
  11
  12namespace IndisputableMonolith
  13namespace Thermodynamics
  14
  15open Real Cost RecognitionSystem
  16
    -- Ambient state space: an implicit finite, nonempty type `Ω` (finite sums over `Ω` are used
    -- throughout, and nonemptiness is what makes the partition sum strictly positive below).
  17variable {Ω : Type*} [Fintype Ω] [Nonempty Ω]
    -- `sys` supplies the temperature `sys.TR` together with its positivity proof `sys.TR_pos`;
    -- `X` is the real-valued observable whose cost `Jcost (X ω)` is weighted by the distributions.
  18variable (sys : RecognitionSystem) (X : Ω → ℝ)
  19
  20/-- **THEOREM: The Free Energy - KL Divergence Identity**
  21    F_R(q) - F_R(Gibbs) = TR * D_KL(q || Gibbs)
  22
  23    **Proof**: write J_ω for Jcost (X ω). F_R(q) = ⟨J⟩_q - TR*S(q) = ∑ q_ω J_ω + TR ∑ q_ω log q_ω
  24    For Gibbs: p_ω = exp(-J_ω/TR)/Z, so log p_ω = -J_ω/TR - log Z
  25    D_KL(q||p) = ∑ q_ω log(q_ω/p_ω) = ∑ q_ω log q_ω + ∑ q_ω (J_ω/TR + log Z)
  26               = -S(q) + ⟨J⟩_q/TR + log Z
  27    TR * D_KL = -TR*S(q) + ⟨J⟩_q + TR*log Z = F_R(q) - (-TR*log Z) = F_R(q) - F_R(Gibbs) -/
  28theorem free_energy_kl_identity (q : ProbabilityDistribution Ω) :
  29    recognition_free_energy sys q.p X - recognition_free_energy sys (gibbs_measure sys X) X =
  30    sys.TR * kl_divergence q.p (gibbs_measure sys X) := by
  31  -- Use the fact that F_R(Gibbs) = -TR * log(Z) from free_energy_identity
  32  have h_gibbs_FR := free_energy_identity sys X
  33  rw [h_gibbs_FR]
  34  unfold recognition_free_energy expected_cost recognition_entropy free_energy_from_Z
  35
  36  -- F_R(q) = ∑ q J - TR * (-∑ q log q) = ∑ q J + TR ∑ q log q
  37  -- D_KL(q||p) = ∑ q log(q/p) where p = gibbs
  38  -- TR * D_KL = TR ∑ q log q - TR ∑ q log p
  39  -- log p_ω = log(exp(-J_ω/TR)/Z) = -J_ω/TR - log Z
  40  -- TR ∑ q log p = TR ∑ q (-J_ω/TR - log Z) = -∑ q J - TR log Z ∑ q = -∑ q J - TR log Z
  41  -- TR * D_KL = TR ∑ q log q + ∑ q J + TR log Z
  42  -- F_R(q) - F_R(Gibbs) = (∑ q J + TR ∑ q log q) - (-TR log Z) = ∑ q J + TR ∑ q log q + TR log Z
  43  -- These match! QED.
  44
      -- Expose the KL sum and the Gibbs weights exp(-J_ω/TR)/Z on the right-hand side.
  45  unfold kl_divergence gibbs_measure partition_function
  46  simp only [gibbs_weight]
      -- Z names the partition sum; it is a sum of exponentials over a nonempty index set, so Z > 0.
  47  set Z := ∑ ω, exp (-Jcost (X ω) / sys.TR) with hZ
  48  have hZ_pos : 0 < Z := by
  49    rw [hZ]; apply Finset.sum_pos (fun ω _ => exp_pos _) Finset.univ_nonempty
  50  have hq_sum := q.sum_one
      -- S abbreviates ∑ q log q, with zero-probability terms contributing 0 (so the entropy is -S).
  51  set S := ∑ ω : Ω, (if q.p ω > 0 then q.p ω * log (q.p ω) else 0)
      -- Clear the double negations on the left-hand side.
  52  have lhs_simp :
  53    (∑ ω, q.p ω * Jcost (X ω) - sys.TR * -S) - -sys.TR * log Z =
  54    ∑ ω, q.p ω * Jcost (X ω) + sys.TR * S + sys.TR * log Z := by ring
  55  rw [lhs_simp]
      -- Since ∑ q = 1, rewrite TR * log Z as ∑ (TR * log Z) * q_ω so it can join the other sums.
  56  rw [show sys.TR * log Z = ∑ ω : Ω, sys.TR * log Z * q.p ω from
  57    by rw [← Finset.mul_sum, hq_sum, mul_one]]
      -- Restate the goal explicitly so the subsequent rewrites see the expected syntactic shape.
  58  show ∑ ω, q.p ω * Jcost (X ω) + sys.TR * S + ∑ ω, sys.TR * log Z * q.p ω =
  59    sys.TR * ∑ ω, (if q.p ω > 0 ∧ exp (-Jcost (X ω) / sys.TR) / Z > 0
  60      then q.p ω * log (q.p ω / (exp (-Jcost (X ω) / sys.TR) / Z)) else 0)
      -- Push TR inside the S sum, then merge everything into one sum per side.
  61  rw [show sys.TR * S = ∑ ω : Ω, sys.TR * (if q.p ω > 0 then q.p ω * log (q.p ω) else 0) from
  62    by rw [Finset.mul_sum]]
  63  rw [← Finset.sum_add_distrib, ← Finset.sum_add_distrib, Finset.mul_sum]
      -- Both sides are now single sums over ω; compare them term by term.
  64  apply Finset.sum_congr rfl
  65  intro ω _
  66  have h_gibbs_pos : 0 < exp (-Jcost (X ω) / sys.TR) / Z := div_pos (exp_pos _) hZ_pos
  67  by_cases hq_pos : 0 < q.p ω
      -- Case q_ω > 0: both `if` guards hold; expand log(q/(e^{-J/TR}/Z)) and finish by algebra.
  68  · simp only [show q.p ω > 0 from hq_pos, show exp (-Jcost (X ω) / sys.TR) / Z > 0 from h_gibbs_pos,
  69               and_self, ite_true]
  70    rw [log_div (ne_of_gt hq_pos) (ne_of_gt h_gibbs_pos),
  71        log_div (exp_pos _).ne' hZ_pos.ne', log_exp]
  72    field_simp [sys.TR_pos.ne']
  73    ring
      -- Case q_ω = 0: every term carries a factor q_ω (or a failed guard), so both sides vanish.
  74  · push_neg at hq_pos
  75    have hq_zero : q.p ω = 0 := le_antisymm hq_pos (q.nonneg ω)
  76    simp [hq_zero]
  77
  78/-- **THEOREM: Free Energy Minimization**
  79    The Gibbs distribution minimizes the Recognition Free Energy.

        **Proof**: by `free_energy_kl_identity`, F_R(p) - F_R(Gibbs) = TR * D_KL(p || Gibbs).
        The right-hand side is non-negative because TR > 0 and D_KL ≥ 0. -/
  80theorem gibbs_minimizes_free_energy_basic (p : ProbabilityDistribution Ω) :
  81    recognition_free_energy sys (gibbs_measure sys X) X ≤ recognition_free_energy sys p.p X := by
      -- h : F_R(p) - F_R(Gibbs) = TR * D_KL(p || Gibbs)
  82  have h := free_energy_kl_identity sys X p
      -- hkl : 0 ≤ D_KL(p || Gibbs), from p ≥ 0, Gibbs > 0, and both summing to 1.
  83  have hkl := kl_divergence_nonneg p.p (gibbs_measure sys X)
  84    p.nonneg
  85    (fun ω => gibbs_measure_pos sys X ω)
  86    p.sum_one
  87    (gibbs_measure_sum_one sys X)
      -- Rearrange the identity to F_R(Gibbs) = F_R(p) - TR * D_KL, then drop the non-negative term.
  88  calc recognition_free_energy sys (gibbs_measure sys X) X
  89      = recognition_free_energy sys p.p X - sys.TR * kl_divergence p.p (gibbs_measure sys X) := by
  90        rw [← h]; ring
  91    _ ≤ recognition_free_energy sys p.p X := by
  92        have hTR := sys.TR_pos
            -- nlinarith supplies the product fact 0 ≤ TR * D_KL from hTR and hkl.
  93        nlinarith
  94
  95/-- **THEOREM: MaxEnt Subject to Cost**
  96    The Gibbs distribution has maximum entropy among all distributions with the same
  97    expected cost.

        Hypothesis `h_cost` says `p` has the same expected cost as the Gibbs distribution.
        **Proof**: free-energy minimality gives ⟨J⟩_Gibbs - TR*S(Gibbs) ≤ ⟨J⟩_p - TR*S(p);
        the cost terms agree by `h_cost`, and dividing by TR > 0 yields S(p) ≤ S(Gibbs). -/
  98theorem max_ent_subject_to_cost (p : ProbabilityDistribution Ω)
  99    (h_cost : expected_cost p.p X = expected_cost (gibbs_measure sys X) X) :
 100    recognition_entropy p.p ≤ recognition_entropy (gibbs_measure sys X) := by
 101  have h_min := gibbs_minimizes_free_energy_basic sys X p
      -- Expose F_R as expected cost minus TR times entropy, then replace p's cost by the Gibbs cost.
 102  unfold recognition_free_energy at h_min
 103  rw [h_cost] at h_min
 104  have hTR := sys.TR_pos
 105  -- expected_cost - TR * entropy_gibbs ≤ expected_cost - TR * entropy_p
 106  -- -TR * entropy_gibbs ≤ -TR * entropy_p
 107  -- TR * entropy_p ≤ TR * entropy_gibbs
 108  -- entropy_p ≤ entropy_gibbs
      -- The last step cancels the positive factor TR, which is nonlinear, hence nlinarith.
 109  nlinarith
 110
 111/-- **THEOREM: KL Divergence Zero Characterization**
 112    D_KL(q || p) = 0 iff q = p, for q ≥ 0, p > 0, both summing to 1.
 113
 114    **Proof**: each term has non-negative excess q log(q/p) - (q - p) ≥ 0, by log x ≤ x - 1. The excesses
 115    sum to D_KL - 0, so D_KL = 0 forces each excess to vanish; log x = x - 1 holds only at x = 1. -/
 116theorem kl_divergence_zero_iff_eq {Ω : Type*} [Fintype Ω]
 117    (q p : Ω → ℝ) (hq_nn : ∀ ω, 0 ≤ q ω) (hp_pos : ∀ ω, 0 < p ω)
 118    (hq_sum : ∑ ω, q ω = 1) (hp_sum : ∑ ω, p ω = 1) :
 119    kl_divergence q p = 0 ↔ ∀ ω, q ω = p ω := by
 120  constructor
      -- Direction D_KL = 0 → q = p.
 121  · intro h_kl_zero ω
 122    unfold kl_divergence at h_kl_zero
        -- excess ω' = (KL summand at ω') - (q ω' - p ω'); shown non-negative below.
 123    let excess := fun ω' =>
 124      (if q ω' > 0 ∧ p ω' > 0 then q ω' * log (q ω' / p ω') else 0) - (q ω' - p ω')
 125    have h_excess_nn : ∀ ω' ∈ Finset.univ, 0 ≤ excess ω' := by
 126      intro ω' _
 127      simp only [excess]
 128      by_cases hq : 0 < q ω'
          -- q ω' > 0: apply log x ≤ x - 1 at x = p/q and multiply through by q.
 129      · have hp' := hp_pos ω'
 130        simp only [hq, hp', and_self, ite_true]
 131        have hr := div_pos hp' hq
 132        have h_le := log_le_sub_one_of_pos hr
 133        rw [log_div hp'.ne' hq.ne'] at h_le
 134        have h_mult := mul_le_mul_of_nonneg_left h_le hq.le
            -- h_expand is not referenced by name; linarith picks it up from the local context.
 135        have h_expand : q ω' * (p ω' / q ω' - 1) = p ω' - q ω' := by
 136          field_simp [hq.ne']
 137        rw [log_div hq.ne' hp'.ne']
 138        linarith
          -- q ω' = 0: the summand is 0 and the excess reduces to p ω' ≥ 0.
 139      · push_neg at hq
 140        have hq' := hq_nn ω'
 141        have hq_zero : q ω' = 0 := le_antisymm hq hq'
 142        simp [hq_zero, (hp_pos ω').le]
        -- Both q and p sum to 1, so the correction terms cancel in total.
 143    have h_qp_sum : ∑ ω', (q ω' - p ω') = 0 := by
 144      rw [Finset.sum_sub_distrib, hq_sum, hp_sum, sub_self]
        -- Hence the excesses sum to D_KL - 0 = 0.
 145    have h_excess_sum : ∑ ω', excess ω' = 0 := by
 146      have h_unfold : ∑ ω', excess ω' =
 147        ∑ ω', ((if q ω' > 0 ∧ p ω' > 0 then q ω' * log (q ω' / p ω') else 0) - (q ω' - p ω')) :=
 148        Finset.sum_congr rfl (fun _ _ => rfl)
 149      rw [h_unfold, Finset.sum_sub_distrib, h_kl_zero, h_qp_sum, sub_self]
        -- A sum of non-negative terms is zero only if every term is zero.
 150    have h_each_zero : ∀ ω' ∈ Finset.univ, excess ω' = 0 :=
 151      (Finset.sum_eq_zero_iff_of_nonneg h_excess_nn).mp h_excess_sum
 152    have h_this := h_each_zero ω (Finset.mem_univ ω)
 153    simp only [excess] at h_this
 154    by_cases hq : 0 < q ω
 155    · have hp' := hp_pos ω
 156      simp only [hq, hp', and_self, ite_true] at h_this
 157      -- h_this : q ω * log (q ω / p ω) - (q ω - p ω) = 0
 158      -- Derive: log(q/p) = 1 - p/q
 159      have h_eq : q ω * log (q ω / p ω) = q ω - p ω := by linarith
 160      have h_log_eq : log (q ω / p ω) = 1 - p ω / q ω := by
 161        have step1 : log (q ω / p ω) = (q ω - p ω) / q ω := by
 162          rw [eq_div_iff hq.ne']; linarith [h_eq]
 163        rw [step1, sub_div, div_self hq.ne']
 164      -- Derive: log(p/q) = p/q - 1
 165      have hpq_pos : 0 < p ω / q ω := div_pos hp' hq
 166      have h_log_pq : log (p ω / q ω) = p ω / q ω - 1 := by
 167        have : log (p ω / q ω) = -log (q ω / p ω) := by
 168          rw [log_div hp'.ne' hq.ne', log_div hq.ne' hp'.ne']; ring
 169        rw [this, h_log_eq]; ring
 170      -- If q ≠ p then p/q ≠ 1, so strict inequality log(p/q) < p/q - 1 contradicts equality
 171      by_contra h_ne_eq
 172      have h_pq_ne_one : p ω / q ω ≠ 1 := by
 173        intro h; exact h_ne_eq ((div_eq_one_iff_eq hq.ne').mp h).symm
 174      have h_log_pq_ne : log (p ω / q ω) ≠ 0 := by
 175        rw [h_log_pq, sub_ne_zero]; exact h_pq_ne_one
          -- Strict bound y + 1 < exp y at y = log(p/q) ≠ 0, i.e. log(p/q) + 1 < p/q.
 176      have h_strict := Real.add_one_lt_exp h_log_pq_ne
 177      rw [exp_log hpq_pos] at h_strict
 178      linarith [h_log_pq]
        -- q ω = 0: the vanishing excess would force p ω = 0, contradicting p ω > 0.
 179    · push_neg at hq
 180      have hq' := hq_nn ω
 181      have hq_zero : q ω = 0 := le_antisymm hq hq'
 182      simp only [hq_zero, show ¬(0 < (0:ℝ)) from not_lt.mpr le_rfl, false_and, ite_false] at h_this
 183      linarith [hp_pos ω]
 184  · -- q = p → D_KL = 0: direct computation
 185    intro h_eq
 186    unfold kl_divergence
 187    apply Finset.sum_eq_zero
 188    intro ω _
        -- With q ω = p ω the ratio is 1 and log 1 = 0, so both `if` branches evaluate to 0.
 189    simp only [h_eq ω, div_self (ne_of_gt (hp_pos ω)), log_one, mul_zero]
 190    split_ifs <;> rfl
 191
 192/-- The Gibbs distribution is the unique minimizer of free energy.

        If a distribution `q` attains the same recognition free energy as the Gibbs distribution
        (`h_eq`), then `q` agrees with the Gibbs distribution at every point.
        **Proof**: `free_energy_kl_identity` turns `h_eq` into TR * D_KL(q || Gibbs) = 0; since TR > 0
        this gives D_KL = 0, and `kl_divergence_zero_iff_eq` converts that into pointwise equality. -/
 193theorem gibbs_unique_minimizer (q : ProbabilityDistribution Ω)
 194    (h_eq : recognition_free_energy sys q.p X = recognition_free_energy sys (gibbs_measure sys X) X) :
 195    ∀ ω, q.p ω = gibbs_measure sys X ω := by
 196  have h := free_energy_kl_identity sys X q
      -- After rewriting with h_eq the left side is F - F, so h : 0 = TR * D_KL(q || Gibbs).
 197  rw [h_eq, sub_self] at h
 198  have hTR := sys.TR_pos
 199  have hkl_zero : kl_divergence q.p (gibbs_measure sys X) = 0 := by
 200    rw [eq_comm] at h
        -- A product is zero iff a factor is zero; the TR = 0 branch contradicts hTR.
 201    have := mul_eq_zero.mp h
 202    cases this with
 203    | inl hTR0 => linarith
 204    | inr hkl0 => exact hkl0
      -- Zero KL divergence against a strictly positive, normalized reference forces equality.
 205  apply (kl_divergence_zero_iff_eq q.p (gibbs_measure sys X) q.nonneg (fun ω => gibbs_measure_pos sys X ω) q.sum_one (gibbs_measure_sum_one sys X)).mp
 206  exact hkl_zero
 207
 208end Thermodynamics
 209end IndisputableMonolith
 210

source mirrored from github.com/jonwashburn/shape-of-logic