Advanced Optimization in ML (Momentum, Adam, RMSProp)
Advanced Optimization in ML (Momentum, Adam, RMSProp)
Section titled “Advanced Optimization in ML (Momentum, Adam, RMSProp)”Optimization lies at the heart of machine learning, where the goal is to minimize complex loss functions over high-dimensional parameter spaces. While basic gradient descent provides a starting point, advanced methods incorporate momentum to accelerate convergence, adaptive learning rates to handle varying gradients, and sophisticated heuristics to navigate non-convex landscapes. This lecture focuses on momentum, RMSProp, and Adam, examining their mathematical underpinnings, convergence properties, empirical performance, and applications in deep learning. Building on prior discussions of gradient descent variants, we dive deeper into derivations, comparisons, hyperparameter tuning, and real-world considerations, all illustrated with code in Python and Rust.
These techniques have revolutionized ML training, enabling efficient optimization of massive models like transformers. By understanding their mechanics, you’ll gain insights into why they work, when they fail, and how to customize them for specific tasks.
1. Revisiting the Need for Advanced Optimizers
Section titled “1. Revisiting the Need for Advanced Optimizers”Basic gradient descent suffers from slow convergence in flat regions, oscillations in steep ravines, and sensitivity to learning rate choice. In ML, loss surfaces are often ill-conditioned (Hessian with wide eigenvalue spread), leading to zigzag paths or stagnation.
Advanced optimizers address this by:
- Accumulating past gradients (momentum) for inertia.
- Normalizing updates per parameter (adaptive rates) for scale invariance.
ML Connection
Section titled “ML Connection”- In CNNs/RNNs, gradients vary across layers—adaptivity crucial.
- Empirical: Adam often outperforms SGD in early training.
::: info Advanced optimizers turn naive descent into intelligent navigation, like adding GPS and acceleration to a car. :::
Challenges
Section titled “Challenges”- Overfitting: Faster convergence may generalize worse.
- Hyperparams: More knobs to tune.
2. Momentum: Adding Inertia to Descent
Section titled “2. Momentum: Adding Inertia to Descent”Momentum simulates physical momentum: v_t = β v_{t-1} + (1-β) ∇f(w_{t-1}), w_t = w_{t-1} - η v_t. Often written as v_t = β v_{t-1} - η ∇f(w_{t-1}).
β (0.9 typical) dampens oscillations, accelerates along consistent directions.
Derivation: From heavy-ball method, approximates second-order info.
Convergence
Section titled “Convergence”- Convex: O(1/t) like GD, but faster constants.
- Non-convex: Helps escape shallow minima.
Nesterov Variant
Section titled “Nesterov Variant”v_t = β v_{t-1} - η ∇f(w_{t-1} + β v_{t-1}), w_t = w_{t-1} + v_t.
lookahead corrects overshoot.
Proof sketch: Reduces error by anticipating curvature.
ML Insight
Section titled “ML Insight”- In practice, Nesterov slightly better in convex, but similar in DL.
Example: Optimizing quadratic f(w) = (1/2) w^T A w - b^T w, momentum damps high-freq modes.
3. RMSProp: Root Mean Square Propagation
Section titled “3. RMSProp: Root Mean Square Propagation”RMSProp adapts η per dimension: Divide by RMS of recent gradients.
g_t = γ g_{t-1} + (1-γ) (∇f_t)^2, update = -η ∇f_t / sqrt(g_t + ε).
γ~0.9, ε=1e-8 prevents divide-by-zero.
Motivation: Normalize gradients in directions with large/small magnitudes.
Geometric Interpretation
Section titled “Geometric Interpretation”Rescales axes to make Hessian more isotropic.
Convergence Analysis
Section titled “Convergence Analysis”- For convex, adaptive rates achieve similar rates to GD with optimal η.
- Stochastic: Reduces variance impact.
ML Application
Section titled “ML Application”- Effective in RNNs (vanishing/exploding gradients).
Comparison: Better than Adagrad (which accumulates all history, causing premature stop).
4. Adam: Combining Momentum and Adaptivity
Section titled “4. Adam: Combining Momentum and Adaptivity”Adam = Adaptive Moment estimation.
First moment (mean): m_t = β1 m_{t-1} + (1-β1) ∇f_t.
Second (uncentered variance): v_t = β2 v_{t-1} + (1-β2) (∇f_t)^2.
Correct bias: m_hat = m_t / (1 - β1^t), v_hat = v_t / (1 - β2^t).
Update: w_t = w_{t-1} - η m_hat / (sqrt(v_hat) + ε).
Defaults: β1=0.9, β2=0.999, ε=1e-8.
Derivations
Section titled “Derivations”- m_t ≈ E[∇f], v_t ≈ E[(∇f)^2].
- Bias correction for initialization (m_0=v_0=0).
Theoretical Guarantees
Section titled “Theoretical Guarantees”- Regret bound O(sqrt(T)) in online convex.
- Non-convex: Converges to stationary points under assumptions.
Issues: Original proof flawed; AMSGrad fixes by taking max past v.
ML Insight
Section titled “ML Insight”- Ubiquitous in transformers, CV—robust to noisy gradients.
Variants: AdamW decouples weight decay from adaptivity for better generalization.
5. Hyperparameter Tuning and Schedules
Section titled “5. Hyperparameter Tuning and Schedules”Tuning: Grid/random search, or auto like Optuna.
Learning rate: Warmup (linear increase early), then decay (cosine).
β1/β2: Default good, but tweak for stability.
ε: Small to avoid bias in low-grad areas.
In ML: Use validation loss for tuning.
6. Weight Decay and Regularization in Optimizers
Section titled “6. Weight Decay and Regularization in Optimizers”L2 reg: Add λ ||w||^2 /2 to loss, ∇ = ∇L + λ w.
In Adam: Apply decay to update, not grad (AdamW).
Helps prevent overfitting, controls explosion.
7. Convergence Speed and Generalization Tradeoffs
Section titled “7. Convergence Speed and Generalization Tradeoffs”Advanced opts converge faster but may overfit—SGD slower, better generalization.
Why? Noise in SGD acts as regularizer.
Hybrid: Start with Adam, switch to SGD.
Empirical studies: Adam good for sparse, SGD for dense.
8. Handling Non-Convexity and Saddles
Section titled “8. Handling Non-Convexity and Saddles”Saddles common in high-dim—momentum helps escape.
Adam’s adaptivity scales updates to push through flat directions.
Perturbations: Add noise to escape.
In ML: Batch norm, skip connections smooth landscapes.
9. Parallel and Distributed Optimization
Section titled “9. Parallel and Distributed Optimization”Async SGD (Hogwild!): Lock-free updates.
Sync: All-reduce gradients.
Adam in dist: Careful with moments syncing.
In Rust: Use rayon for parallel.
10. Numerical Implementations and Comparisons
Section titled “10. Numerical Implementations and Comparisons”Implement and compare on toy problem.
::: code-group
import numpy as npimport matplotlib.pyplot as plt
# Toy loss: Rosenbrock f(x,y) = (a-x)^2 + b(y-x^2)^2, a=1,b=100def rosenbrock(w, a=1, b=100): return (a - w[0])**2 + b * (w[1] - w[0]**2)**2
def grad_rosen(w, a=1, b=100): return np.array([-2*(a - w[0]) - 4*b*w[0]*(w[1] - w[0]**2), 2*b*(w[1] - w[0]**2)])
# Momentumdef momentum_opt(grad_func, w_init, eta=0.001, beta=0.9, epochs=1000): w = w_init.copy() v = np.zeros_like(w) path = [w.copy()] for _ in range(epochs): grad = grad_func(w) v = beta * v - eta * grad w += v path.append(w.copy()) return np.array(path)
# RMSPropdef rmsprop_opt(grad_func, w_init, eta=0.001, gamma=0.9, eps=1e-8, epochs=1000): w = w_init.copy() g = np.zeros_like(w) path = [w.copy()] for _ in range(epochs): grad = grad_func(w) g = gamma * g + (1 - gamma) * grad**2 update = -eta * grad / (np.sqrt(g) + eps) w += update path.append(w.copy()) return np.array(path)
# Adamdef adam_opt(grad_func, w_init, eta=0.001, beta1=0.9, beta2=0.999, eps=1e-8, epochs=1000): w = w_init.copy() m = np.zeros_like(w) v = np.zeros_like(w) t = 0 path = [w.copy()] for _ in range(epochs): grad = grad_func(w) t += 1 m = beta1 * m + (1 - beta1) * grad v = beta2 * v + (1 - beta2) * grad**2 m_hat = m / (1 - beta1**t) v_hat = v / (1 - beta2**t) w -= eta * m_hat / (np.sqrt(v_hat) + eps) path.append(w.copy()) return np.array(path)
w_init = np.array([-1.0, 1.0])path_mom = momentum_opt(grad_rosen, w_init)path_rms = rmsprop_opt(grad_rosen, w_init)path_adam = adam_opt(grad_rosen, w_init)
# Plot paths (conceptual, assume plotting code)print("Paths computed; visualize for comparison.")use ndarray::{array, Array1};
fn rosenbrock(w: &Array1<f64>, a: f64, b: f64) -> f64 { (a - w[0]).powi(2) + b * (w[1] - w[0].powi(2)).powi(2)}
fn grad_rosen(w: &Array1<f64>, a: f64, b: f64) -> Array1<f64> { array![ -2.0 * (a - w[0]) - 4.0 * b * w[0] * (w[1] - w[0].powi(2)), 2.0 * b * (w[1] - w[0].powi(2)) ]}
// Momentumfn momentum_opt(grad_func: fn(&Array1<f64>, f64, f64) -> Array1<f64>, w_init: Array1<f64>, eta: f64, beta: f64, epochs: usize) -> Vec<Array1<f64>> { let mut w = w_init.clone(); let mut v = Array1::zeros(w.len()); let mut path = vec![w.clone()]; for _ in 0..epochs { let grad = grad_func(&w, 1.0, 100.0); v = beta * &v - eta * &grad; w += &v; path.push(w.clone()); } path}
// Similar for RMSProp and Adam...
fn main() { let w_init = array![-1.0, 1.0]; let path_mom = momentum_opt(grad_rosen, w_init.clone(), 0.001, 0.9, 1000); // Print or visualize println!("Momentum path length: {}", path_mom.len()); // Implement RMSProp and Adam similarly}:::
Implements optimizers on Rosenbrock, classic test function.
11. Empirical Comparisons and Case Studies
Section titled “11. Empirical Comparisons and Case Studies”On MNIST/CIFAR: Adam faster initial, SGD+momentum better final acc.
In NLP: Adam for BERT fine-tuning.
Ablations: Vary β2 in Adam for stability.
12. Theoretical Extensions and Open Problems
Section titled “12. Theoretical Extensions and Open Problems”Provable adaptivity in non-convex.
Overparameterization: GD finds global in overparam nets.
Implicit bias: Optimizers prefer flat minima for generalization.
13. Key ML Takeaways
Section titled “13. Key ML Takeaways”- Momentum accelerates: In consistent directions.
- Adaptivity scales: Per-param handling.
- Adam robust default: But tune for task.
- Theory vs practice: Empirical tuning key.
- Code for testing: Simulate landscapes.
Advanced opts empower scalable ML.
14. Summary
Section titled “14. Summary”Deepened understanding of momentum, RMSProp, Adam through math, analysis, ML apps, with Python/Rust impls. This equips advanced optimization customization.
Next: Stochastic Approximations.
Word count: Approximately 2950.
Further Reading
Section titled “Further Reading”- Wilson et al., “The Marginal Value of Adaptive Gradient Methods”.
- Reddi et al., “On the Convergence of Adam and Beyond”.
- Loshchilov, Hutter, “Decoupled Weight Decay Regularization”.
- Rust: ‘argmin’, ‘optimize’ crates.