IndisputableMonolith.Cost.Ndim.Hessian
IndisputableMonolith/Cost/Ndim/Hessian.lean · 120 lines · 14 declarations
show as:
view math explainer →
1import IndisputableMonolith.Cost.Ndim.Core
2
3/-!
4# Hessian formulas for the `n`-dimensional reciprocal cost
5
6This module exposes the public replacement for the missing private
7`IndisputableMonolith.Cost.Ndim.Hessian` file referenced by
8`Metric.lean`.
9
10The key point is that in log-coordinates the `n`-dimensional cost
11depends only on the single weighted aggregate `dot α t`, so its Hessian
12is rank one and factors through the outer product `α ⊗ α`.
13-/
14
15namespace IndisputableMonolith
16namespace Cost
17namespace Ndim
18
19open scoped BigOperators
20
/-- The `i`-th log-coordinate gradient entry for `JlogN`:
the weight `α i` times `Real.sinh` of the aggregate `dot α t`. -/
noncomputable def gradientEntry {n : ℕ} (α t : Vec n) (i : Fin n) : ℝ :=
  α i * Real.sinh (dot α t)
24
/-- The `(i, j)` log-coordinate Hessian entry for `JlogN`:
the product `α i * α j` times `Real.cosh` of the aggregate `dot α t`. -/
noncomputable def hessianEntry {n : ℕ} (α t : Vec n) (i j : Fin n) : ℝ :=
  α i * α j * Real.cosh (dot α t)
28
/-- Equilibrium Hessian model: the outer product `α ⊗ α`, whose `(i, j)`
entry is `α i * α j`. -/
def hessianMatrix {n : ℕ} (α : Vec n) : Fin n → Fin n → ℝ :=
  fun i j ↦ α i * α j
32
/-- The Hessian at an arbitrary log-state `t`, packaged as a coordinate
matrix whose entries are given by `hessianEntry`. -/
noncomputable def hessianAt {n : ℕ} (α t : Vec n) : Fin n → Fin n → ℝ :=
  fun i j ↦ hessianEntry α t i j
36
/-- Apply a tensor given in coordinates to a vector:
the `i`-th output is `∑ j, H i j * v j`. -/
noncomputable def applyTensor {n : ℕ} (H : Fin n → Fin n → ℝ) (v : Vec n) : Vec n :=
  fun i ↦ ∑ j : Fin n, H i j * v j
41
/-- The log-coordinate Hessian at `t` acting on a vector `v`, i.e.
`applyTensor` specialised to `hessianAt α t`. -/
noncomputable def applyHessian {n : ℕ} (α t v : Vec n) : Vec n :=
  applyTensor (hessianAt α t) v
45
/-- The quadratic form of the Hessian: `v` paired (via `dot`) with the
Hessian applied to `v`. -/
noncomputable def quadraticHessian {n : ℕ} (α t v : Vec n) : ℝ :=
  dot v (applyHessian α t v)
49
/-- At the zero log-state the Hessian entry reduces to `α i * α j`,
because the aggregate vanishes and `Real.cosh 0 = 1`. -/
@[simp] theorem hessianEntry_zero {n : ℕ} (α : Vec n) (i j : Fin n) :
    hessianEntry α (fun _ => 0) i j = α i * α j := by
  -- Unfolding leaves `cosh` of a sum of zeros, which `simp` collapses to `1`.
  simp [hessianEntry, dot]
54
/-- At the zero log-state the Hessian matrix is exactly the equilibrium
outer-product model `hessianMatrix α`. -/
@[simp] theorem hessianAt_zero {n : ℕ} (α : Vec n) :
    hessianAt α (fun _ => 0) = hessianMatrix α := by
  funext i j
  -- Both sides unfold to the statement of `hessianEntry_zero`.
  exact hessianEntry_zero α i j
59
/-- The Hessian at any log-state equals `Real.cosh (dot α t)` times the
equilibrium outer-product model `hessianMatrix α`. -/
theorem hessianAt_factor {n : ℕ} (α t : Vec n) :
    hessianAt α t = fun i j => Real.cosh (dot α t) * hessianMatrix α i j := by
  funext i j
  -- Entrywise this is just a reordering of three real factors.
  show α i * α j * Real.cosh (dot α t) = Real.cosh (dot α t) * (α i * α j)
  ring
66
/-- The Hessian maps every vector to a multiple of `α`: the `i`-th output is
`Real.cosh (dot α t) * α i * dot α v`. -/
theorem applyHessian_eq_direction {n : ℕ} (α t v : Vec n) :
    applyHessian α t v = fun i => Real.cosh (dot α t) * α i * dot α v := by
  funext i
  -- Expose the coordinate sum on the left and the sum behind `dot α v` on the right.
  show ∑ j : Fin n, α i * α j * Real.cosh (dot α t) * v j
      = Real.cosh (dot α t) * α i * ∑ j : Fin n, α j * v j
  -- Push the `j`-independent factor inside the right-hand sum, then compare termwise.
  rw [Finset.mul_sum]
  refine Finset.sum_congr rfl fun j _ => ?_
  ring
82
/-- Any vector `v` with `dot α v = 0` is sent to zero by the Hessian. -/
theorem applyHessian_of_dot_zero {n : ℕ} (α t v : Vec n)
    (hv : dot α v = 0) :
    applyHessian α t v = 0 := by
  funext i
  -- Every output coordinate carries the factor `dot α v`, which vanishes by `hv`.
  rw [applyHessian_eq_direction, hv]
  simp
89
/-- The Hessian quadratic form sees `v` only through `dot α v`:
it equals `Real.cosh (dot α t) * (dot α v) ^ 2`. -/
theorem quadraticHessian_eq {n : ℕ} (α t v : Vec n) :
    quadraticHessian α t v = Real.cosh (dot α t) * (dot α v) ^ 2 := by
  -- Swapping the factors in each summand recovers `dot α v`.
  have hflip : ∑ i : Fin n, v i * α i = dot α v := by
    unfold dot
    exact Finset.sum_congr rfl fun i _ => by ring
  -- Pull the `i`-independent factor out of the quadratic-form sum.
  have hsum : quadraticHessian α t v
      = Real.cosh (dot α t) * dot α v * ∑ i : Fin n, v i * α i := by
    unfold quadraticHessian
    rw [applyHessian_eq_direction]
    show ∑ i : Fin n, v i * (Real.cosh (dot α t) * α i * dot α v) = _
    rw [Finset.mul_sum]
    refine Finset.sum_congr rfl fun i _ => ?_
    ring
  rw [hsum, hflip]
  ring
111
/-- The Hessian quadratic form is nonnegative: it is a positive `Real.cosh`
value times a square. -/
theorem quadraticHessian_nonneg {n : ℕ} (α t v : Vec n) :
    0 ≤ quadraticHessian α t v := by
  rw [quadraticHessian_eq]
  exact mul_nonneg (Real.cosh_pos _).le (sq_nonneg _)
116
117end Ndim
118end Cost
119end IndisputableMonolith
120