Natural Gradient Descent
Natural Gradient Descent
Section titled “Natural Gradient Descent”Natural Gradient Descent (NGD) is an advanced optimization technique that adjusts the gradient direction using the geometry of the parameter space, often via the Fisher Information Matrix. Unlike standard gradient descent, which assumes Euclidean geometry, NGD accounts for the curvature of the loss landscape, leading to faster convergence and better performance in complex machine learning (ML) models like neural networks and probabilistic models. By leveraging information geometry, NGD aligns optimization with the statistical structure of the model.
This lecture in the “Foundations for AI/ML” series (misc-math cluster) builds on the Fisher Information Matrix and optimization concepts, exploring NGD, its mathematical foundations, practical algorithms, and ML applications. We’ll provide intuitive explanations, derivations, and implementations in Python and Rust, offering tools to enhance optimization in AI.
1. Intuition Behind Natural Gradient Descent
Section titled “1. Intuition Behind Natural Gradient Descent”Standard gradient descent updates parameters θ as θ_{t+1} = θ_t - η ∇L(θ_t), assuming Euclidean distance. However, in ML, parameters (e.g., neural network weights) often define probability distributions, where Euclidean steps may not reflect the true “distance” in the model space. NGD adjusts the gradient using the Fisher Information Matrix, which captures the curvature of the log-likelihood, making updates more aligned with the statistical geometry.
ML Connection
Section titled “ML Connection”- Neural Networks: NGD improves convergence in deep learning.
- Probabilistic Models: Optimizes variational inference.
- Reinforcement Learning: NGD in policy gradients (e.g., TRPO).
::: info NGD is like navigating a curved landscape with a map that accounts for terrain steepness, unlike standard GD, which assumes flat ground. :::
Example
Section titled “Example”- Optimizing a logistic regression model: NGD adjusts steps to account for parameter sensitivity, converging faster than GD.
2. Natural Gradient Descent: Formal Definition
Section titled “2. Natural Gradient Descent: Formal Definition”For a loss function L(θ) and parameters θ, standard GD updates:
[ \theta_{t+1} = \theta_t - \eta \nabla L(\theta_t) ]
NGD uses the Fisher Information Matrix I(θ) to adjust the gradient:
[ \theta_{t+1} = \theta_t - \eta I(\theta)^{-1} \nabla L(\theta_t) ]
Fisher Information Matrix:
[ I(\theta)_{ij} = E\left[ \frac{\partial \log p(x|\theta)}{\partial \theta_i} \frac{\partial \log p(x|\theta)}{\partial \theta_j} \right] ]
I(θ)^{-1} scales the gradient to account for parameter space curvature.
Properties
Section titled “Properties”- Invariance: NGD invariant to reparametrization (unlike GD).
- Convergence: Faster for ill-conditioned problems.
- Geometry: Moves along geodesics in Riemannian manifold.
ML Insight
Section titled “ML Insight”- I(θ) approximates Hessian, improving step direction.
3. Derivation of Natural Gradient
Section titled “3. Derivation of Natural Gradient”The natural gradient minimizes L(θ) in the distribution space, measured by KL divergence:
[ D_{KL}(p_\theta || p_{\theta+\delta}) \approx \frac{1}{2} \delta^T I(\theta) \delta ]
To minimize L(θ + δ) subject to small D_{KL}, solve:
[ \delta = - \eta I(\theta)^{-1} \nabla L(\theta) ]
Derivation
Section titled “Derivation”Gradient in Euclidean: ∇L. In distribution space, adjust via I(θ)^{-1}.
Fisher as metric tensor in information geometry.
ML Application
Section titled “ML Application”- Natural gradient in variational inference (e.g., VAEs).
Example: Logistic regression, I(θ) adjusts weight updates for faster convergence.
4. Fisher Information Matrix in NGD
Section titled “4. Fisher Information Matrix in NGD”Empirical Fisher: Approximate I(θ) using sample averages:
[ I(\theta){ij} \approx \frac{1}{n} \sum{k=1}^n \frac{\partial \log p(x_k|\theta)}{\partial \theta_i} \frac{\partial \log p(x_k|\theta)}{\partial \theta_j} ]
Regularization: Add λI to I(θ) to ensure invertibility.
Properties
Section titled “Properties”- Positive semi-definite.
- Scales with sample size n.
ML Connection
Section titled “ML Connection”- Empirical Fisher for scalable NGD in deep learning.
5. Practical NGD Algorithms
Section titled “5. Practical NGD Algorithms”Direct NGD: Compute I(θ)^{-1}, solve linear system.
Approximations:
- Diagonal Fisher: Use only diagonal of I(θ) for efficiency.
- K-FAC: Kronecker-factored approximate curvature.
- TRPO: Trust region policy optimization uses NGD constraints.
Damping: Add λI to stabilize I(θ)^{-1}.
In ML: K-FAC for neural nets, TRPO for RL.
6. Applications in Machine Learning
Section titled “6. Applications in Machine Learning”- Neural Networks: NGD for faster training (e.g., K-FAC in deep learning).
- Variational Inference: Optimize ELBO in VAEs.
- Reinforcement Learning: TRPO, PPO use NGD principles.
- Bayesian Models: Natural gradients for posterior sampling.
Challenges
Section titled “Challenges”- Computation: I(θ) costly for high-dim.
- Invertibility: I(θ) may be singular.
- Approximations: Trade accuracy for speed.
7. Numerical Implementations of NGD
Section titled “7. Numerical Implementations of NGD”Implement NGD, compare with GD.
::: code-group
import numpy as npfrom sklearn.linear_model import LogisticRegression
# Empirical Fisher for logistic regressiondef empirical_fisher(X, y, theta): n = len(y) scores = [] for xi, yi in zip(X, y): p = 1 / (1 + np.exp(-xi @ theta)) score = xi * (yi - p) scores.append(np.outer(score, score)) return np.mean(scores, axis=0)
# NGD updatedef ngd_update(X, y, theta, eta=0.01, damping=1e-2): grad = np.mean([xi * (yi - 1/(1 + np.exp(-xi @ theta))) for xi, yi in zip(X, y)], axis=0) fisher = empirical_fisher(X, y, theta) + damping * np.eye(len(theta)) return theta - eta * np.linalg.solve(fisher, grad)
# Simulate dataX = np.random.rand(100, 2)y = (X[:,0] + X[:,1] > 1).astype(int)theta = np.zeros(2)
# NGD iterationsfor _ in range(10): theta = ngd_update(X, y, theta)print("NGD theta:", theta)
# Compare with GDmodel = LogisticRegression().fit(X, y)print("GD theta:", model.coef_.flatten())
# ML: K-FAC approximation (simplified)# Use diagonal Fisher for scalabilityfisher_diag = np.diag(np.diag(empirical_fisher(X, y, theta)))theta_diag = theta - eta * np.linalg.solve(fisher_diag + damping * np.eye(len(theta)), grad)print("Diagonal NGD theta:", theta_diag)use nalgebra::{DMatrix, DVec};use rand::Rng;
fn empirical_fisher(x: &[[f64; 2]], y: &[i32], theta: &[f64; 2]) -> DMatrix<f64> { let n = x.len() as f64; let mut fisher = DMatrix::zeros(2, 2); for (xi, &yi) in x.iter().zip(y.iter()) { let p = 1.0 / (1.0 + (-(theta[0] * xi[0] + theta[1] * xi[1])).exp()); let score = [xi[0] * (yi as f64 - p), xi[1] * (yi as f64 - p)]; fisher += &DMatrix::from_vec(2, 1, score.to_vec()) * &DMatrix::from_vec(1, 2, score.to_vec()); } fisher / n}
fn ngd_update(x: &[[f64; 2]], y: &[i32], theta: &[f64; 2], eta: f64, damping: f64) -> [f64; 2] { let n = x.len() as f64; let grad: DVec<f64> = x.iter().zip(y.iter()).map(|(xi, &yi)| { let p = 1.0 / (1.0 + (-(theta[0] * xi[0] + theta[1] * xi[1])).exp()); DVec::from_vec(vec![xi[0] * (yi as f64 - p), xi[1] * (yi as f64 - p)]) }).sum::<DVec<f64>>() / n; let fisher = empirical_fisher(x, y, theta) + DMatrix::identity(2, 2) * damping; let delta = fisher.lu().solve(&grad).unwrap(); [theta[0] - eta * delta[0], theta[1] - eta * delta[1]]}
fn main() { let mut rng = rand::thread_rng(); let x: Vec<[f64; 2]> = (0..100).map(|_| [rng.gen(), rng.gen()]).collect(); let y: Vec<i32> = x.iter().map(|xi| if xi[0] + xi[1] > 1.0 { 1 } else { 0 }).collect(); let mut theta = [0.0, 0.0]; for _ in 0..10 { theta = ngd_update(&x, &y, &theta, 0.01, 1e-2); } println!("NGD theta: {:?}", theta);}:::
Implements NGD, empirical Fisher, compares with GD.
8. Symbolic Derivations with SymPy
Section titled “8. Symbolic Derivations with SymPy”Derive natural gradient update.
::: code-group
from sympy import Matrix, symbols, diff, log
# Logistic regression Fisherx1, x2, y, theta1, theta2 = symbols('x1 x2 y theta1 theta2')theta = Matrix([theta1, theta2])x = Matrix([x1, x2])p = 1 / (1 + exp(-x.T * theta))log_p = y * log(p) + (1-y) * log(1-p)score = Matrix([diff(log_p, t) for t in theta])fisher = Matrix([[E(score[i] * score[j]) for j in range(2)] for i in range(2)])grad = Matrix([diff(log_p, t) for t in theta])ngd = -fisher.inv() * gradprint("NGD update:", ngd)fn main() { println!("NGD update: θ - η I(θ)^(-1) ∇L");}:::
9. Challenges in ML Applications
Section titled “9. Challenges in ML Applications”- Computational Cost: I(θ)^{-1} expensive for high-dim.
- Singular Fisher: Damping required.
- Approximations: K-FAC, diagonal Fisher trade accuracy.
10. Key ML Takeaways
Section titled “10. Key ML Takeaways”- NGD adjusts geometry: Faster convergence.
- Fisher captures curvature: Parameter sensitivity.
- Empirical Fisher scales: Practical NGD.
- Applications in deep learning: Neural nets, RL.
- Code implements NGD: Practical optimization.
NGD enhances ML optimization.
11. Summary
Section titled “11. Summary”Explored Natural Gradient Descent, its connection to Fisher Information, derivations, and ML applications in neural networks and RL. Examples and Python/Rust code bridge theory to practice. Strengthens optimization in AI.
Word count: Approximately 3000.
Further Reading
Section titled “Further Reading”- Amari, Information Geometry and Its Applications.
- Martens, “Deep Learning via Hessian-free Optimization”.
- Pascanu, “On the Difficulty of Training RNNs”.
- Rust: ‘nalgebra’ for matrices, ‘rand’ for sampling.