Information Geometry in Variational Inference
Information geometry studies the geometry of probability distributions, treating them as points on a Riemannian manifold with metrics like the Fisher Information Matrix. In variational inference (VI), it provides a framework to optimize approximate posterior distributions by minimizing KL divergence, enhancing efficiency in probabilistic modeling. In machine learning (ML), information geometry underpins techniques like natural gradient descent and variational autoencoders (VAEs), enabling robust optimization of complex models.
This lecture in the "Foundations for AI/ML" series (misc-math cluster) builds on Natural Gradient Descent and Fisher Information Matrix, exploring information geometry, its mathematical foundations, its role in VI, and ML applications. We'll provide intuitive explanations, derivations, and practical implementations in Python and Rust, offering tools to leverage geometric insights in AI.
1. Intuition Behind Information Geometry
Information geometry views probability distributions as points on a manifold, where distances are measured by divergences like KL divergence. The Fisher Information Matrix acts as a Riemannian metric, capturing the curvature of the distribution space. In VI, this geometry guides optimization of approximate posteriors to match true posteriors, balancing accuracy and computational efficiency.
ML Connection
- Variational Inference: Minimize KL divergence for approximate posteriors in VAEs.
- Natural Gradient Descent: Uses Fisher metric for faster optimization.
- Bayesian ML: Geometric insights for posterior inference.
INFO
Information geometry is like navigating a curved map of probabilities, guiding ML models to find the best path to accurate distributions.
Example
- In a VAE, information geometry helps optimize the variational distribution q(z) to approximate the true posterior p(z|x).
2. Foundations of Information Geometry
Statistical Manifold: Set of probability distributions p(x|θ) parameterized by θ, forming a Riemannian manifold with metric I(θ) (Fisher Information Matrix):
[ I(\theta)_{ij} = E\left[ \frac{\partial \log p(x|\theta)}{\partial \theta_i} \frac{\partial \log p(x|\theta)}{\partial \theta_j} \right] ]
KL Divergence: Measures "distance" between distributions:
[ D_{KL}(p||q) = \int p(x|\theta) \log \frac{p(x|\theta)}{q(x|\phi)} dx ]
Riemannian Metric: I(θ) defines local geometry, used in natural gradients.
Properties
- Positive Semi-Definite: I(θ) ≥ 0, reflects curvature.
- KL Approximation: for a small perturbation δ, D_{KL}(p_\theta || p_{\theta+\delta}) ≈ (1/2) δ^T I(θ) δ (a numerical check appears at the end of this section).
ML Insight
- The Fisher metric rescales optimization steps according to how sensitive the distribution is to each parameter.
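To make the quadratic KL approximation concrete, here is a minimal numerical check (an illustration added for this lecture, not tied to any later code) for a Bernoulli(θ) distribution, whose Fisher information is I(θ) = 1/(θ(1−θ)):
import numpy as np
# Bernoulli(theta): Fisher information is 1 / (theta * (1 - theta))
def kl_bernoulli(p, q):
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))
theta, delta = 0.3, 0.01
fisher = 1.0 / (theta * (1 - theta))
exact = kl_bernoulli(theta, theta + delta)
quadratic = 0.5 * fisher * delta**2
print(exact, quadratic)  # the two values agree closely for small delta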
3. Variational Inference: Geometric Perspective
VI approximates intractable posterior p(θ|x) with q(θ|φ) by minimizing:
[ D_{KL}(q(\theta|\phi) || p(\theta|x)) ]
ELBO (Evidence Lower Bound):
[ \text{ELBO} = E_q[\log p(x,\theta)] - E_q[\log q(\theta|\phi)] ]
Maximizing the ELBO minimizes the KL divergence, since log p(x) = ELBO + D_{KL}(q(θ|φ) || p(θ|x)) and log p(x) does not depend on φ (a toy numerical example appears at the end of this section).
Geometric Interpretation
- q(θ|φ) moves on manifold to minimize distance to p(θ|x).
- Fisher metric guides steps via natural gradient.
ML Application
- VAEs optimize ELBO for latent variable models.
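As a small illustration of the ELBO, here is a sketch using a toy conjugate model chosen for this example (Gaussian likelihood with known variance, Gaussian prior, Gaussian variational family); a Monte Carlo estimate shows that a q closer to the exact posterior attains a higher ELBO:
import numpy as np
rng = np.random.default_rng(0)
# Toy conjugate model: x_i ~ N(theta, 1), prior theta ~ N(0, 1),
# variational family q(theta) = N(mu, sigma^2).
x = rng.normal(2.0, 1.0, size=50)
def elbo(mu, sigma, n_samples=2000):
    theta = rng.normal(mu, sigma, size=n_samples)                # samples from q
    log_prior = -0.5 * theta**2                                  # log N(theta; 0, 1), up to a constant
    log_lik = -0.5 * ((x[:, None] - theta)**2).sum(axis=0)       # sum_i log N(x_i; theta, 1), up to constants
    log_q = -0.5 * ((theta - mu) / sigma)**2 - np.log(sigma)     # log q(theta), up to a constant
    # Constants are dropped consistently, so comparing ELBO values is still valid.
    return np.mean(log_prior + log_lik - log_q)
# The exact posterior here is N(sum(x)/(n+1), 1/(n+1)); the ELBO is highest near it.
n = len(x)
print("q at exact posterior:", elbo(x.sum() / (n + 1), (1.0 / (n + 1))**0.5))
print("q = N(0, 1):", elbo(0.0, 1.0))  # a poorer approximation gives a lower ELBO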
4. Natural Gradient in VI
Natural gradient adjusts standard gradient ∇L by I(φ)^{-1}:
[ \tilde{\nabla} L = I(\phi)^{-1} \nabla L ]
In VI, the natural gradient optimizes the ELBO in the space of distributions rather than in raw parameter space.
Derivation
For a small change δ in the variational parameters, the KL divergence between q_φ and q_{φ+δ} is approximately the quadratic form (1/2) δ^T I(φ) δ. Maximizing the ELBO subject to a bound on this quadratic form yields the natural gradient update, as sketched below.
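A brief sketch of this argument (the standard trust-region derivation, stated here for completeness): choose the step δ that maximizes the linearized objective subject to a KL budget,
[ \max_{\delta} \; \nabla L^T \delta \quad \text{subject to} \quad \frac{1}{2} \delta^T I(\phi) \delta \le \epsilon ]
Setting the derivative of the Lagrangian ∇L^T δ − (λ/2) δ^T I(φ) δ to zero gives
[ \nabla L = \lambda \, I(\phi) \delta \quad \Rightarrow \quad \delta = \frac{1}{\lambda} I(\phi)^{-1} \nabla L ]
i.e. the natural gradient direction, with 1/λ absorbed into the step size.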
5. Fisher Information in VI
Empirical Fisher: Approximate I(φ) using sample averages:
[ I(\phi)_{ij} \approx \frac{1}{n} \sum_{k=1}^{n} \frac{\partial \log q(x_k|\phi)}{\partial \phi_i} \frac{\partial \log q(x_k|\phi)}{\partial \phi_j} ]
Damping: Add λI to the estimate so it stays invertible and well conditioned.
In ML: These approximations keep natural-gradient VI tractable for large models (a small sketch follows).
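A minimal sketch of the damped empirical Fisher, assuming only a generic array of per-sample score vectors (the variable names here are illustrative):
import numpy as np
def damped_empirical_fisher(scores, damping=1e-2):
    # scores: (n_samples, n_params) array of per-sample gradients of log q
    n_samples, n_params = scores.shape
    fisher = scores.T @ scores / n_samples
    return fisher + damping * np.eye(n_params)
scores = np.random.randn(5, 10)  # fewer samples than parameters: raw estimate is rank deficient
print(np.linalg.cond(damped_empirical_fisher(scores, damping=0.0)))  # effectively singular: enormous condition number
print(np.linalg.cond(damped_empirical_fisher(scores)))               # damping restores a usable condition number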
6. Applications in Machine Learning
- Variational Autoencoders: Optimize q(z|φ) to approximate p(z|x).
- Bayesian Neural Networks: Natural gradients for posterior inference.
- Reinforcement Learning: TRPO uses Fisher metric for policy updates.
- Model Selection: Information geometry in BIC, AIC.
Challenges
- Computation: Forming and inverting I(φ) is costly in high dimensions.
- Approximations: K-FAC and diagonal Fisher reduce the cost but trade away accuracy (see the sketch below).
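As a rough illustration of that trade-off, the sketch below (with synthetic, correlated score vectors invented for this example) compares a full natural-gradient step against the cheaper diagonal-Fisher step:
import numpy as np
rng = np.random.default_rng(1)
mix = rng.normal(size=(4, 4))                 # mixing matrix to induce correlated scores
scores = rng.normal(size=(500, 4)) @ mix
fisher = scores.T @ scores / len(scores)
grad = rng.normal(size=4)
full_step = np.linalg.solve(fisher + 1e-2 * np.eye(4), grad)  # O(d^3) solve with the full matrix
diag_step = grad / (np.diag(fisher) + 1e-2)                   # O(d), ignores parameter correlations
print("full:", full_step)
print("diag:", diag_step)  # cheaper, but can point in a noticeably different direction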
7. Numerical Implementations in VI
Below are simplified Python and Rust implementations of natural-gradient updates with an empirical Fisher matrix, using a logistic model as a stand-in for the variational objective.
import numpy as np
from sklearn.linear_model import LogisticRegression
# Empirical Fisher: average outer product of per-sample score vectors
def empirical_fisher(X, y, phi):
    scores = []
    for xi, yi in zip(X, y):
        p = 1 / (1 + np.exp(-xi @ phi))
        score = xi * (yi - p)          # score of the Bernoulli log-likelihood
        scores.append(np.outer(score, score))
    return np.mean(scores, axis=0)
# One natural-gradient ascent step on the simplified objective
def vi_natural_gradient(X, y, phi, eta=0.1, damping=1e-2):
    # Gradient of the average log-likelihood (stands in for the ELBO gradient here)
    grad = np.mean([xi * (yi - 1 / (1 + np.exp(-xi @ phi))) for xi, yi in zip(X, y)], axis=0)
    fisher = empirical_fisher(X, y, phi) + damping * np.eye(len(phi))
    return phi + eta * np.linalg.solve(fisher, grad)   # ascend along the natural gradient
# Simulate data
X = np.random.rand(100, 2)
y = (X[:, 0] + X[:, 1] > 1).astype(int)
phi = np.zeros(2)
# Natural-gradient iterations
for _ in range(100):
    phi = vi_natural_gradient(X, y, phi)
print("Natural gradient phi:", phi)
# Rough comparison: sklearn's L2-regularized maximum-likelihood fit (includes an intercept, so values differ)
model = LogisticRegression().fit(X, y)
print("sklearn LogisticRegression coef:", model.coef_.flatten())
use nalgebra::{DMatrix, DVector};
use rand::Rng;
// Empirical Fisher: average outer product of per-sample score vectors
fn empirical_fisher(x: &[[f64; 2]], y: &[i32], phi: &[f64; 2]) -> DMatrix<f64> {
    let n = x.len() as f64;
    let mut fisher = DMatrix::zeros(2, 2);
    for (xi, &yi) in x.iter().zip(y.iter()) {
        let p = 1.0 / (1.0 + (-(phi[0] * xi[0] + phi[1] * xi[1])).exp());
        let score = DVector::from_vec(vec![xi[0] * (yi as f64 - p), xi[1] * (yi as f64 - p)]);
        fisher += &score * score.transpose();
    }
    fisher / n
}
// One natural-gradient ascent step on the simplified log-likelihood objective
fn vi_natural_gradient(x: &[[f64; 2]], y: &[i32], phi: &[f64; 2], eta: f64, damping: f64) -> [f64; 2] {
    let n = x.len() as f64;
    let mut grad = DVector::zeros(2);
    for (xi, &yi) in x.iter().zip(y.iter()) {
        let p = 1.0 / (1.0 + (-(phi[0] * xi[0] + phi[1] * xi[1])).exp());
        grad += DVector::from_vec(vec![xi[0] * (yi as f64 - p), xi[1] * (yi as f64 - p)]);
    }
    grad /= n;
    let fisher = empirical_fisher(x, y, phi) + DMatrix::identity(2, 2) * damping;
    let delta = fisher.lu().solve(&grad).expect("damped Fisher should be invertible");
    // Ascend along the natural-gradient direction
    [phi[0] + eta * delta[0], phi[1] + eta * delta[1]]
}
fn main() {
    let mut rng = rand::thread_rng();
    let x: Vec<[f64; 2]> = (0..100).map(|_| [rng.gen(), rng.gen()]).collect();
    let y: Vec<i32> = x.iter().map(|xi| if xi[0] + xi[1] > 1.0 { 1 } else { 0 }).collect();
    let mut phi = [0.0, 0.0];
    for _ in 0..100 {
        phi = vi_natural_gradient(&x, &y, &phi, 0.1, 1e-2);
    }
    println!("VI Natural Gradient phi: {:?}", phi);
}
Both versions run a simplified natural-gradient loop: compute the log-likelihood gradient, form the damped empirical Fisher matrix, and solve for the update.
8. Symbolic Derivations with SymPy
Derive the Fisher matrix for a logistic (Bernoulli) model and the Bernoulli KL divergence symbolically.
from sympy import Matrix, symbols, log, diff, exp, simplify
# Fisher matrix for a logistic (Bernoulli) model
x1, x2, y, phi1, phi2 = symbols('x1 x2 y phi1 phi2')
phi = Matrix([phi1, phi2])
x = Matrix([x1, x2])
p = 1 / (1 + exp(-(x.T * phi)[0]))
log_q = y * log(p) + (1 - y) * log(1 - p)
score = Matrix([diff(log_q, param) for param in phi])
# Expectation over y ~ Bernoulli(p): weight the y=1 and y=0 cases by p and 1-p
fisher = (p * (score * score.T).subs(y, 1)
          + (1 - p) * (score * score.T).subs(y, 0)).applyfunc(simplify)
print("Fisher Matrix:", fisher)  # each entry is mathematically p*(1-p)*x_i*x_j
# KL divergence between two Bernoulli distributions
q, p_true = symbols('q p_true')
kl = q * log(q / p_true) + (1 - q) * log((1 - q) / (1 - p_true))
print("KL divergence:", kl)
fn main() {
println!("Fisher Matrix: E[score_i * score_j]");
println!("KL divergence: ∫ q log(q/p) dx");
}
9. Challenges in ML Applications
- Computational Cost: Fisher matrix inversion scales poorly.
- Approximations: K-FAC, diagonal Fisher lose precision.
- Non-Stationary Distributions: When the target distribution shifts during training, the geometry must be re-estimated, complicating VI dynamics.
10. Key ML Takeaways
- Info geometry guides VI: Optimize distributions.
- Fisher metric adjusts steps: Natural gradient.
- KL divergence measures distance: For VI.
- Applications in VAEs, RL: Probabilistic ML.
- Code implements geometry: Practical VI.
Information geometry enhances ML inference.
11. Summary
This lecture explored information geometry, the Fisher Information Matrix, and KL divergence, and their role in variational inference for ML. Worked examples and Python/Rust code bridge theory to practice and strengthen probabilistic modeling in AI.
Further Reading
- Amari, Information Geometry and Its Applications.
- Murphy, Probabilistic Machine Learning (Ch. 10).
- Blei, Kucukelbir, and McAuliffe, Variational Inference: A Review for Statisticians.
- Rust: 'nalgebra' for matrices, 'rand' for sampling.