
Natural Gradient Descent

Natural Gradient Descent (NGD) is an advanced optimization technique that adjusts the gradient direction using the geometry of the parameter space, often via the Fisher Information Matrix. Unlike standard gradient descent, which assumes Euclidean geometry, NGD accounts for the curvature of the loss landscape, leading to faster convergence and better performance in complex machine learning (ML) models like neural networks and probabilistic models. By leveraging information geometry, NGD aligns optimization with the statistical structure of the model.

This lecture in the "Foundations for AI/ML" series (misc-math cluster) builds on the Fisher Information Matrix and optimization concepts, exploring NGD, its mathematical foundations, practical algorithms, and ML applications. We'll provide intuitive explanations, derivations, and implementations in Python and Rust, offering tools to enhance optimization in AI.


1. Intuition Behind Natural Gradient Descent

Standard gradient descent updates parameters θ as θ_{t+1} = θ_t - η ∇L(θ_t), assuming Euclidean distance. However, in ML, parameters (e.g., neural network weights) often define probability distributions, where Euclidean steps may not reflect the true "distance" in the model space. NGD adjusts the gradient using the Fisher Information Matrix, which captures the curvature of the log-likelihood, making updates more aligned with the statistical geometry.

ML Connection

  • Neural Networks: NGD improves convergence in deep learning.
  • Probabilistic Models: Optimizes variational inference.
  • Reinforcement Learning: NGD in policy gradients (e.g., TRPO).

INFO

NGD is like navigating a curved landscape with a map that accounts for terrain steepness, unlike standard GD, which assumes flat ground.

Example

  • Optimizing a logistic regression model: NGD adjusts steps to account for parameter sensitivity, converging faster than GD.

2. Natural Gradient Descent: Formal Definition

For a loss function L(θ) and parameters θ, standard GD updates:

[ \theta_{t+1} = \theta_t - \eta \nabla L(\theta_t) ]

NGD uses the Fisher Information Matrix I(θ) to adjust the gradient:

[ \theta_{t+1} = \theta_t - \eta I(\theta)^{-1} \nabla L(\theta_t) ]

Fisher Information Matrix:

[ I(\theta)_{ij} = E\left[ \frac{\partial \log p(x|\theta)}{\partial \theta_i} \frac{\partial \log p(x|\theta)}{\partial \theta_j} \right] ]

I(θ)^{-1} scales the gradient to account for parameter space curvature.
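
As a minimal sketch (the names `grad` and `fisher` below are placeholders standing in for ∇L(θ) and an already-damped estimate of I(θ)), the two update rules differ only in how the gradient is preconditioned:

python
import numpy as np

# Minimal sketch of one update step. `grad` stands in for ∇L(θ) and `fisher`
# for an (already damped) estimate of I(θ); both are placeholder values here.
def gd_step(theta, grad, eta=0.1):
    return theta - eta * grad                           # Euclidean steepest descent

def ngd_step(theta, grad, fisher, eta=0.1):
    return theta - eta * np.linalg.solve(fisher, grad)  # precondition by I(θ)^{-1}

theta = np.array([1.0, -2.0])
grad = np.array([0.5, 0.1])
fisher = np.array([[2.0, 0.3], [0.3, 0.5]])
print("GD step: ", gd_step(theta, grad))
print("NGD step:", ngd_step(theta, grad, fisher))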

Properties

  • Invariance: NGD is invariant to smooth reparametrization of the model, unlike standard GD (see the numerical check after this list).
  • Convergence: typically faster on ill-conditioned problems.
  • Geometry: follows the steepest-descent direction on the Riemannian manifold defined by the Fisher metric.
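
The invariance property can be checked numerically. In the toy sketch below, the Fisher matrix, the gradient, and the reparametrization Jacobian J are arbitrary made-up values; the gradient transforms as Jᵀ∇L and the Fisher as Jᵀ I(θ) J, and mapping the natural-gradient step from the new coordinates back to θ recovers the original step, while the plain gradient step does not.

python
import numpy as np

# Toy check of reparametrization invariance (arbitrary made-up Fisher and Jacobian).
rng = np.random.default_rng(0)
I_theta = rng.standard_normal((3, 3))
I_theta = I_theta @ I_theta.T + np.eye(3)          # SPD Fisher in theta-coordinates
grad_theta = rng.standard_normal(3)                # loss gradient in theta-coordinates
J = rng.standard_normal((3, 3)) + 3 * np.eye(3)    # Jacobian d(theta)/d(phi), invertible

grad_phi = J.T @ grad_theta                        # gradient transforms covariantly
I_phi = J.T @ I_theta @ J                          # Fisher transforms as a metric

nat_theta = np.linalg.solve(I_theta, grad_theta)          # natural gradient in theta
nat_phi_in_theta = J @ np.linalg.solve(I_phi, grad_phi)   # phi-step mapped back to theta

print(np.allclose(nat_theta, nat_phi_in_theta))    # True: same step either way
print(np.allclose(grad_theta, J @ grad_phi))       # False: plain gradient is not invariant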

ML Insight

  • I(θ) approximates the Hessian of the loss (it equals the expected Hessian of the negative log-likelihood), improving the step direction.

3. Derivation of Natural Gradient

The natural gradient minimizes L(θ) in the distribution space, measured by KL divergence:

[ D_{KL}(p_\theta || p_{\theta+\delta}) \approx \frac{1}{2} \delta^T I(\theta) \delta ]

Minimizing the linearized loss L(θ) + ∇L(θ)^T δ subject to a small KL budget yields the step:

[ \delta = - \eta I(\theta)^{-1} \nabla L(\theta) ]

Derivation

In Euclidean space the steepest-descent direction is −∇L(θ). When step length is instead measured by the KL divergence, whose local quadratic form is given by I(θ), forming the Lagrangian ∇L(θ)^T δ + (λ/2) δ^T I(θ) δ and setting its gradient with respect to δ to zero gives δ ∝ −I(θ)^{-1} ∇L(θ).

The Fisher Information Matrix thus plays the role of the metric tensor in information geometry.
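
As a quick numerical check (an assumed one-parameter Bernoulli model with logit θ, for which I(θ) = p(1−p)), the KL divergence between nearby distributions matches the Fisher quadratic form:

python
import numpy as np

# Check D_KL(p_theta || p_{theta+delta}) ≈ 0.5 * I(theta) * delta^2
# for a Bernoulli model parameterized by its logit theta, where I(theta) = p(1-p).
def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def bernoulli_kl(theta, delta):
    p, q = sigmoid(theta), sigmoid(theta + delta)
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

theta, delta = 0.3, 1e-2
p = sigmoid(theta)
fisher = p * (1 - p)                       # Fisher information of the logit
print(bernoulli_kl(theta, delta))          # ≈ 1.22e-05
print(0.5 * fisher * delta**2)             # ≈ 1.22e-05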

ML Application

  • Natural gradient in variational inference (e.g., VAEs).

Example: in logistic regression, I(θ) rescales the weight updates according to parameter sensitivity, yielding faster convergence.


4. Fisher Information Matrix in NGD

Empirical Fisher: Approximate I(θ) using sample averages:

[ I(\theta)_{ij} \approx \frac{1}{n} \sum_{k=1}^{n} \frac{\partial \log p(x_k|\theta)}{\partial \theta_i} \frac{\partial \log p(x_k|\theta)}{\partial \theta_j} ]

Regularization: Add λI to I(θ) to ensure invertibility.
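
A minimal sketch of both ideas, assuming the per-sample scores s_k = ∂ log p(x_k|θ)/∂θ have already been computed and stacked row-wise in an array S (placeholder random values below):

python
import numpy as np

# Damped empirical Fisher from an (n, d) array S of per-sample score vectors:
# (1/n) * sum_k s_k s_k^T, plus lambda * I to guarantee invertibility.
def damped_empirical_fisher(S, damping=1e-3):
    n, d = S.shape
    return S.T @ S / n + damping * np.eye(d)

S = np.random.default_rng(1).standard_normal((500, 4))  # placeholder scores
F = damped_empirical_fisher(S)
print(np.all(np.linalg.eigvalsh(F) > 0))  # positive definite after damping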

Properties

  • Positive semi-definite.
  • The total Fisher information of n i.i.d. samples scales linearly with n; the empirical average above estimates the per-sample Fisher.

ML Connection

  • Empirical Fisher for scalable NGD in deep learning.

5. Practical NGD Algorithms

Direct NGD: form I(θ) and solve the linear system I(θ)d = ∇L(θ) for the preconditioned step, rather than inverting I(θ) explicitly.

Approximations:

  • Diagonal Fisher: Use only diagonal of I(θ) for efficiency.
  • K-FAC: Kronecker-factored approximate curvature (see the sketch at the end of this section).
  • TRPO: Trust region policy optimization uses NGD constraints.

Damping: Add λI to stabilize I(θ)^{-1}.

In ML: K-FAC for neural nets, TRPO for RL.
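
To see why the Kronecker factorization helps, the sketch below (a toy layer with made-up factor matrices A and G) checks the identity K-FAC relies on: if F ≈ A ⊗ G with symmetric positive-definite factors, then F^{-1} vec(∇W) = vec(G^{-1} ∇W A^{-1}), so only the small per-layer factors ever need to be inverted.

python
import numpy as np

# Toy check of the Kronecker identity behind K-FAC (assumed small layer sizes):
# with SPD factors A (input side) and G (output side) and column-major vec,
# solving with the full A ⊗ G equals G^{-1} @ Grad @ A^{-1}.
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3)); A = A @ A.T + 3 * np.eye(3)   # input-activation factor
G = rng.standard_normal((2, 2)); G = G @ G.T + 2 * np.eye(2)   # output-gradient factor
Grad = rng.standard_normal((2, 3))                             # layer gradient (out x in)

full = np.linalg.solve(np.kron(A, G), Grad.flatten(order='F'))          # full Fisher solve
factored = (np.linalg.inv(G) @ Grad @ np.linalg.inv(A)).flatten(order='F')
print(np.allclose(full, factored))  # True

Inverting the small input-side and output-side factors costs far less than factorizing the full Fisher block over all of the layer's weights.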


6. Applications in Machine Learning

  1. Neural Networks: NGD for faster training (e.g., K-FAC in deep learning).
  2. Variational Inference: Optimize ELBO in VAEs.
  3. Reinforcement Learning: TRPO, PPO use NGD principles.
  4. Bayesian Models: Natural gradients for posterior sampling.

Challenges

  • Computation: forming and inverting I(θ) is costly in high dimensions.
  • Invertibility: I(θ) may be singular.
  • Approximations: Trade accuracy for speed.

7. Numerical Implementations of NGD

We implement NGD with an empirical Fisher for logistic regression and compare it with a standard solver.

python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Empirical Fisher for logistic regression: mean outer product of per-sample scores
def empirical_fisher(X, y, theta):
    scores = []
    for xi, yi in zip(X, y):
        p = 1 / (1 + np.exp(-xi @ theta))
        score = xi * (yi - p)  # d/dtheta log p(yi | xi, theta)
        scores.append(np.outer(score, score))
    return np.mean(scores, axis=0)

# Gradient of the negative log-likelihood (the loss being minimized)
def nll_grad(X, y, theta):
    return np.mean([xi * (1 / (1 + np.exp(-xi @ theta)) - yi) for xi, yi in zip(X, y)], axis=0)

# NGD update: precondition the loss gradient by the damped empirical Fisher
def ngd_update(X, y, theta, eta=0.1, damping=1e-2):
    grad = nll_grad(X, y, theta)
    fisher = empirical_fisher(X, y, theta) + damping * np.eye(len(theta))
    return theta - eta * np.linalg.solve(fisher, grad)

# Simulate data
X = np.random.rand(100, 2)
y = (X[:, 0] + X[:, 1] > 1).astype(int)
theta = np.zeros(2)

# NGD iterations
for _ in range(100):
    theta = ngd_update(X, y, theta)
print("NGD theta:", theta)

# Compare with sklearn's (regularized, L-BFGS-based) fit
model = LogisticRegression().fit(X, y)
print("sklearn theta:", model.coef_.flatten())

# Diagonal-Fisher approximation (simplified stand-in for K-FAC-style scalability)
eta, damping = 0.1, 1e-2
grad = nll_grad(X, y, theta)
fisher_diag = np.diag(np.diag(empirical_fisher(X, y, theta)))
theta_diag = theta - eta * np.linalg.solve(fisher_diag + damping * np.eye(len(theta)), grad)
print("Diagonal NGD theta:", theta_diag)
rust
use nalgebra::{DMatrix, DVector};
use rand::Rng;

// Empirical Fisher: average of outer products of per-sample scores
fn empirical_fisher(x: &[[f64; 2]], y: &[i32], theta: &[f64; 2]) -> DMatrix<f64> {
    let n = x.len() as f64;
    let mut fisher = DMatrix::zeros(2, 2);
    for (xi, &yi) in x.iter().zip(y.iter()) {
        let p = 1.0 / (1.0 + (-(theta[0] * xi[0] + theta[1] * xi[1])).exp());
        let score = DVector::from_vec(vec![xi[0] * (yi as f64 - p), xi[1] * (yi as f64 - p)]);
        fisher += &score * score.transpose();
    }
    fisher / n
}

// One NGD step: precondition the negative log-likelihood gradient by the damped Fisher
fn ngd_update(x: &[[f64; 2]], y: &[i32], theta: &[f64; 2], eta: f64, damping: f64) -> [f64; 2] {
    let n = x.len() as f64;
    let mut grad = DVector::zeros(2);
    for (xi, &yi) in x.iter().zip(y.iter()) {
        let p = 1.0 / (1.0 + (-(theta[0] * xi[0] + theta[1] * xi[1])).exp());
        grad += DVector::from_vec(vec![xi[0] * (p - yi as f64), xi[1] * (p - yi as f64)]);
    }
    grad /= n;
    let fisher = empirical_fisher(x, y, theta) + DMatrix::identity(2, 2) * damping;
    let delta = fisher.lu().solve(&grad).expect("damped Fisher is invertible");
    [theta[0] - eta * delta[0], theta[1] - eta * delta[1]]
}

fn main() {
    let mut rng = rand::thread_rng();
    let x: Vec<[f64; 2]> = (0..100).map(|_| [rng.gen(), rng.gen()]).collect();
    let y: Vec<i32> = x.iter().map(|xi| if xi[0] + xi[1] > 1.0 { 1 } else { 0 }).collect();
    let mut theta = [0.0, 0.0];
    for _ in 0..100 {
        theta = ngd_update(&x, &y, &theta, 0.1, 1e-2);
    }
    println!("NGD theta: {:?}", theta);
}

Both implementations build a damped empirical Fisher and apply the NGD update; the Python version also compares against a standard solver and a diagonal-Fisher approximation.


8. Symbolic Derivations with SymPy

We derive the natural-gradient direction symbolically for logistic regression.

python
from sympy import Matrix, symbols, diff, exp, log, simplify, eye

# Logistic regression: per-sample score and Fisher information
x1, x2, y, theta1, theta2, lam = symbols('x1 x2 y theta1 theta2 lambda')
theta = Matrix([theta1, theta2])
x = Matrix([x1, x2])
p = 1 / (1 + exp(-(x.T * theta)[0, 0]))
log_p = y * log(p) + (1 - y) * log(1 - p)
score = Matrix([diff(log_p, t) for t in theta])  # gradient of the log-likelihood
# E_y[score * score.T] with y ~ Bernoulli(p) simplifies to p(1-p) * x * x.T
fisher = simplify(p * (1 - p) * x * x.T)
# The per-sample Fisher is rank 1, so damp it before inverting
ngd = -(fisher + lam * eye(2)).inv() * score
print("NGD direction:", simplify(ngd))
rust
fn main() {
    println!("NGD update: θ - η I(θ)^(-1) ∇L");
}

9. Challenges in ML Applications

  • Computational Cost: forming and inverting I(θ) is expensive in high dimensions.
  • Singular Fisher: damping is required when I(θ) is (near-)singular.
  • Approximations: K-FAC and diagonal Fisher trade accuracy for speed.

10. Key ML Takeaways

  • NGD adjusts geometry: Faster convergence.
  • Fisher captures curvature: Parameter sensitivity.
  • Empirical Fisher scales: Practical NGD.
  • Applications in deep learning: Neural nets, RL.
  • Code implements NGD: Practical optimization.

NGD enhances ML optimization.


11. Summary

This lecture explored Natural Gradient Descent, its connection to the Fisher Information Matrix, its derivation, and its applications in neural networks and reinforcement learning. Worked examples and Python/Rust code bridge theory to practice, strengthening the optimization toolkit for AI.



Further Reading

  • Amari, Information Geometry and Its Applications.
  • Martens, "Deep Learning via Hessian-free Optimization".
  • Pascanu, "On the Difficulty of Training RNNs".
  • Rust: 'nalgebra' for matrices, 'rand' for sampling.