
Logistic Regression

Logistic regression is a supervised learning technique for binary classification, predicting the probability of a class (e.g., spam vs. not spam). This section provides a comprehensive exploration of its theory, derivations, regularization, and evaluation, with a Rust lab using linfa. We’ll delve into computational details and Rust’s role in optimizing classification tasks.

Theory

Logistic regression predicts the probability $P(y=1 \mid \mathbf{x})$ that an input $\mathbf{x} = [x_1, x_2, \ldots, x_n]$ belongs to class 1, using the logistic (sigmoid) function:

$$P(y=1 \mid \mathbf{x}) = \sigma(\mathbf{w}^T \mathbf{x} + w_0) = \frac{1}{1 + e^{-(\mathbf{w}^T \mathbf{x} + w_0)}}$$

where $w_0$ is the intercept, $\mathbf{w} = [w_1, \ldots, w_n]$ are the weights, and $\sigma(z) = \frac{1}{1 + e^{-z}}$ maps any real value to $(0, 1)$. For class 0, $P(y=0 \mid \mathbf{x}) = 1 - P(y=1 \mid \mathbf{x})$.

The model outputs a probability, and a threshold (e.g., 0.5) determines the class (see the sketch after this list):

  • If $P(y=1 \mid \mathbf{x}) \geq 0.5$, predict class 1.
  • Otherwise, predict class 0.
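To make the sigmoid-plus-threshold pipeline concrete, here is a minimal plain-Rust sketch (no linfa) that scores a single example; the weight and intercept values are made up for illustration, not fitted parameters:

```rust
/// Sigmoid: sigma(z) = 1 / (1 + e^(-z)).
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

fn main() {
    // Hypothetical parameters (illustrative only, not learned from data).
    let w = [0.8, 0.3];
    let w0 = -5.2;
    // A single input with two features.
    let x = [6.0, 1.0];

    // Linear score z = w^T x + w0, then probability P(y=1|x) = sigma(z).
    let z: f64 = w.iter().zip(x.iter()).map(|(wi, xi)| wi * xi).sum::<f64>() + w0;
    let p = sigmoid(z);

    // Apply the 0.5 threshold to pick a class.
    let class = if p >= 0.5 { 1 } else { 0 };
    println!("P(y=1|x) = {p:.3}, predicted class = {class}");
}
```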

Derivation: Maximum Likelihood Estimation

To find optimal weights, logistic regression maximizes the likelihood of the data. For $m$ training examples $\{(\mathbf{x}_i, y_i)\}$, where $y_i \in \{0, 1\}$, the likelihood is:

$$L(\mathbf{w}, w_0) = \prod_{i=1}^{m} P(y_i \mid \mathbf{x}_i) = \prod_{i=1}^{m} \left[\sigma(\mathbf{w}^T \mathbf{x}_i + w_0)\right]^{y_i} \left[1 - \sigma(\mathbf{w}^T \mathbf{x}_i + w_0)\right]^{1 - y_i}$$

Taking the log (log-likelihood) simplifies optimization:

$$\ell(\mathbf{w}, w_0) = \sum_{i=1}^{m} \left[ y_i \log \sigma(\mathbf{w}^T \mathbf{x}_i + w_0) + (1 - y_i) \log\left(1 - \sigma(\mathbf{w}^T \mathbf{x}_i + w_0)\right) \right]$$

Instead of maximizing $\ell$, we minimize the negative log-likelihood (log-loss):

$$J(\theta) = -\frac{1}{m} \ell(\mathbf{w}, w_0)$$

where $\theta = [w_0, w_1, \ldots, w_n]$. The gradient of $J$ is:

$$\nabla_\theta J(\theta) = \frac{1}{m} \sum_{i=1}^{m} \left( \sigma(\mathbf{w}^T \mathbf{x}_i + w_0) - y_i \right) \mathbf{x}_i$$

where $\mathbf{x}_i = [1, x_{i1}, \ldots, x_{in}]$ is the feature vector augmented with a leading 1 for the intercept. Gradient descent then updates the parameters with learning rate $\eta$:

$$\theta \leftarrow \theta - \eta \nabla_\theta J(\theta)$$
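The update rule above is straightforward to implement directly. The following sketch runs batch gradient descent on a tiny hard-coded dataset using ndarray; it is a didactic illustration of the math, not the optimizer linfa uses internally:

```rust
use ndarray::{array, Array1, Array2};

/// Sigmoid, applied element-wise below via mapv.
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

fn main() {
    // Tiny dataset: each row is [1, x1, x2]; the leading 1 absorbs the intercept.
    let x: Array2<f64> = array![
        [1.0, 1.0, 2.0],
        [1.0, 2.0, 1.0],
        [1.0, 8.0, 3.0],
        [1.0, 9.0, 4.0],
    ];
    let y: Array1<f64> = array![0.0, 0.0, 1.0, 1.0];
    let m = y.len() as f64;

    let mut theta: Array1<f64> = Array1::zeros(3); // theta = [w0, w1, w2]
    let eta = 0.1; // learning rate

    for _ in 0..1000 {
        // Predicted probabilities: sigma(X theta).
        let p = x.dot(&theta).mapv(sigmoid);
        // Gradient of the log-loss: (1/m) * X^T (p - y).
        let grad = x.t().dot(&(&p - &y)) / m;
        // Gradient descent step: theta <- theta - eta * grad.
        theta = &theta - &(grad * eta);
    }
    println!("Learned parameters [w0, w1, w2]: {theta}");
}
```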

Under the Hood: With a squared-error loss, the sigmoid’s non-linearity would make the objective non-convex; the log-loss, by contrast, is convex in the parameters, so gradient-based optimization converges to a global minimum. Rust’s linfa performs this optimization with efficient matrix operations, leveraging ndarray (optionally backed by BLAS) for speed while preserving memory safety during the iterative updates.

Regularization

To prevent overfitting, regularization penalizes large weights:

  • Ridge (L2): Adds $\lambda \sum_{j=1}^{n} w_j^2$ to $J(\theta)$, shrinking the weights:

    $$J(\theta) = -\frac{1}{m} \ell(\mathbf{w}, w_0) + \lambda \sum_{j=1}^{n} w_j^2$$

    The gradient gains a $2\lambda w_j$ term for each weight; the intercept $w_0$ is typically not penalized (see the sketch after this list).

  • Lasso (L1): Adds $\lambda \sum_{j=1}^{n} |w_j|$, promoting sparsity. Because the absolute value is not differentiable at zero, it is optimized with subgradient or coordinate-descent methods.
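As noted in the Ridge bullet, the only change to the training loop is an extra $2\lambda w_j$ term in the gradient of every weight. A minimal sketch of that step, with the intercept left unpenalized (the helper name and values are illustrative):

```rust
use ndarray::Array1;

/// One ridge-regularized gradient-descent step for logistic regression.
/// `grad_data` is the unregularized gradient (1/m) * X^T (p - y) over
/// theta = [w0, w1, ..., wn]; `lambda` is the L2 strength, `eta` the learning rate.
fn ridge_step(theta: &Array1<f64>, grad_data: &Array1<f64>, lambda: f64, eta: f64) -> Array1<f64> {
    let mut grad = grad_data.clone();
    // Add 2 * lambda * w_j for every weight, but skip the intercept at index 0.
    for j in 1..grad.len() {
        grad[j] += 2.0 * lambda * theta[j];
    }
    theta - &(grad * eta)
}

fn main() {
    let theta = Array1::from(vec![0.5, 1.2, -0.7]);     // [w0, w1, w2]
    let grad_data = Array1::from(vec![0.1, 0.3, -0.2]); // illustrative gradient values
    let updated = ridge_step(&theta, &grad_data, 0.1, 0.05);
    println!("Updated parameters: {updated}");
}
```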

Under the Hood: Ridge regularization improves numerical stability by adding a positive constant to the diagonal of the Hessian, keeping it well-conditioned, while Lasso performs feature selection by driving some weights exactly to zero. linfa implements the L2 penalty efficiently, and Rust’s type safety keeps the iterative updates robust; comparable Python pipelines can add interpreter and data-copying overhead on large datasets.
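To see what “regularizing the Hessian” means here, recall the standard second derivative of the log-loss (with $X$ the design matrix including the intercept column and $\sigma_i$ the predicted probability for example $i$; the intercept is usually left unpenalized, which this simplified form glosses over):

$$\nabla^2_\theta J(\theta) = \frac{1}{m} X^T S X, \qquad S = \mathrm{diag}\big(\sigma_i (1 - \sigma_i)\big)$$

With the ridge penalty the Hessian becomes $\frac{1}{m} X^T S X + 2\lambda I$, whose smallest eigenvalue is bounded below by $2\lambda$, so the problem stays well-conditioned even when features are nearly collinear.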

Evaluation

Model performance is evaluated with:

  • Log-Loss: Measures probability fit, as above.
  • Accuracy: Proportion of correct predictions, $\frac{\text{correct}}{m}$.
  • Precision, Recall, F1-Score: Better suited to imbalanced datasets (see the sketch after this list):
    • Precision: $\frac{\text{True Positives}}{\text{True Positives} + \text{False Positives}}$
    • Recall: $\frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}}$
    • F1-Score: $\frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$
  • ROC-AUC: Area under the receiver operating characteristic curve, measuring discrimination.
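The classification metrics above are easy to compute by hand from predicted and true labels. The sketch below does so in plain Rust on illustrative labels (a real project might use a metrics utility instead):

```rust
/// Precision, recall, and F1 computed directly from predicted and true labels.
fn prf1(predicted: &[u8], actual: &[u8]) -> (f64, f64, f64) {
    let (mut tp, mut fp, mut fn_) = (0.0, 0.0, 0.0);
    for (&p, &t) in predicted.iter().zip(actual.iter()) {
        match (p, t) {
            (1, 1) => tp += 1.0, // true positive
            (1, 0) => fp += 1.0, // false positive
            (0, 1) => fn_ += 1.0, // false negative
            _ => {}
        }
    }
    let precision = tp / (tp + fp);
    let recall = tp / (tp + fn_);
    let f1 = 2.0 * precision * recall / (precision + recall);
    (precision, recall, f1)
}

fn main() {
    // Illustrative labels, not output from the lab below.
    let predicted: [u8; 8] = [1, 0, 1, 1, 0, 1, 0, 0];
    let actual: [u8; 8]    = [1, 0, 0, 1, 0, 1, 1, 0];
    let (p, r, f1) = prf1(&predicted, &actual);
    println!("Precision: {p:.2}, Recall: {r:.2}, F1: {f1:.2}");
}
```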

Under the Hood: Log-loss penalizes confident wrong predictions heavily, aligning with probabilistic ML. Rust’s linfa computes these metrics efficiently, using vectorized operations to minimize computation time, unlike some Python implementations that may incur overhead for large datasets.

Lab: Logistic Regression with linfa

You’ll train a logistic regression model on a synthetic dataset (e.g., features predicting binary class) with ridge regularization, evaluate accuracy, and compute log-loss.

  1. Edit src/main.rs in your rust_ml_tutorial project:

    rust
    use linfa::prelude::*;
    // Logistic regression lives in the linfa-logistic crate (not linfa-linear)
    use linfa_logistic::LogisticRegression;
    use ndarray::{array, Array1, Array2};
    
    fn main() {
        // Synthetic dataset: features (x1, x2), binary target (0 or 1)
        let x: Array2<f64> = array![
            [1.0, 2.0], [2.0, 1.0], [3.0, 3.0], [4.0, 5.0], [5.0, 4.0],
            [6.0, 1.0], [7.0, 2.0], [8.0, 3.0], [9.0, 4.0], [10.0, 5.0]
        ];
        // Class labels must be a discrete type (usize here), not floats
        let y: Array1<usize> = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
    
        // Create dataset
        let dataset = Dataset::new(x.clone(), y.clone());
    
        // Train logistic regression with ridge regularization;
        // `alpha` is the L2 penalty strength in linfa-logistic
        let model = LogisticRegression::default()
            .alpha(0.1)
            .max_iterations(100)
            .fit(&dataset)
            .unwrap();
        println!("Intercept: {}, Weights: {:?}", model.intercept(), model.params());
    
        // Predict probabilities P(y=1|x)
        let probs = model.predict_probabilities(&x);
        println!("Predicted Probabilities: {:?}", probs);
    
        // Compute accuracy with a 0.5 threshold
        let predictions = probs.mapv(|p| if p >= 0.5 { 1usize } else { 0 });
        let accuracy = predictions.iter().zip(y.iter())
            .filter(|(p, t)| p == t).count() as f64 / y.len() as f64;
        println!("Accuracy: {}", accuracy);
    
        // Compute log-loss (negative mean log-likelihood)
        let log_loss = -y.iter().zip(probs.iter()).map(|(&t, &p)| {
            let t = t as f64;
            t * p.ln() + (1.0 - t) * (1.0 - p).ln()
        }).sum::<f64>() / y.len() as f64;
        println!("Log-Loss: {}", log_loss);
    }
  2. Ensure Dependencies:

    • Verify Cargo.toml includes:
      toml
      [dependencies]
      linfa = "0.7.1"
      linfa-linear = "0.7.0"
      linfa-logistic = "0.7.0"  # provides LogisticRegression
      ndarray = "0.15.0"
    • Run cargo build.
  3. Run the Program:

    bash
    cargo run

    Expected Output (approximate):

    Intercept: -5.2, Weights: [0.8, 0.3]
    Predicted Probabilities: [0.02, 0.05, 0.12, 0.25, 0.45, 0.65, 0.78, 0.88, 0.92, 0.95]
    Accuracy: 0.9
    Log-Loss: 0.21

Understanding the Results

  • Dataset: Synthetic features (e.g., x1, x2) predict binary classes (0 or 1), mimicking separable data.
  • Model: Logistic regression learns weights (e.g., $w_1 \approx 0.8$, $w_2 \approx 0.3$) and an intercept, defining the decision boundary.
  • Ridge: The L2 penalty prevents overfitting, stabilizing weights for noisy data.
  • Evaluation: High accuracy (0.9) and low log-loss (0.21) indicate good fit, with probabilities reflecting confidence.
  • Under the Hood: linfa fits logistic regression with a quasi-Newton optimizer (L-BFGS via the argmin crate) rather than hand-rolled gradient descent, leveraging Rust’s ndarray for efficient matrix-vector operations. The sigmoid and log-loss are evaluated in a numerically stable form, and Rust’s ownership model ensures memory safety during the iterative updates, which matters for large datasets.

This lab deepens your understanding of classification, preparing for advanced ML techniques.

Next Steps

Continue to Decision Trees for tree-based methods, or revisit Linear Regression.

Further Reading

  • An Introduction to Statistical Learning by James et al. (Chapter 4)
  • Andrew Ng’s Machine Learning Specialization (Course 1, Week 3)
  • Hands-On Machine Learning by Géron (Chapter 4)
  • linfa Documentation: github.com/rust-ml/linfa