
First ML Lab

This section introduces your first machine learning (ML) task: linear regression using the linfa library in Rust. You’ll train a model to predict a continuous output and, along the way, learn the basics of supervised learning. No prior ML experience is required.

What is Linear Regression?

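Linear regression predicts a continuous value (e.g., a house price) from input features (e.g., size and location). It fits a line (or, with several features, a hyperplane) to the data by minimizing the sum of squared differences between predicted and observed values, a method known as least squares.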

Mathematical Basis: Given features x and output y, the model predicts:

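$$\hat{y} = w_0 + w_1 x_1 + w_2 x_2 + \cdots + w_n x_n$$

where w₀ is the intercept and w₁, …, wₙ are weights learned by minimizing the squared error summed over all training examples i:

$$\text{Loss} = \sum_{i} \left(y_i - \hat{y}_i\right)^2$$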
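To make the two formulas concrete, here is a minimal, dependency-free Rust sketch. The helper names `predict` and `sum_squared_errors` are illustrative only and are not part of linfa. It scores a hand-picked candidate line ŷ = 2x (intercept 0, which is not the least-squares fit) on the lab's data; least squares looks for the weights that make this loss as small as possible.

```rust
/// Evaluates the linear model y_hat = w0 + w1*x1 + ... + wn*xn for one input row.
/// (Illustrative helper, not a linfa API.)
fn predict(intercept: f64, weights: &[f64], features: &[f64]) -> f64 {
    assert_eq!(weights.len(), features.len(), "need one weight per feature");
    intercept + weights.iter().zip(features).map(|(w, x)| w * x).sum::<f64>()
}

/// Sum of squared errors between observed targets and model predictions.
/// (Illustrative helper, not a linfa API.)
fn sum_squared_errors(targets: &[f64], predictions: &[f64]) -> f64 {
    assert_eq!(targets.len(), predictions.len(), "lengths must match");
    targets
        .iter()
        .zip(predictions)
        .map(|(y, y_hat)| (y - y_hat).powi(2))
        .sum()
}

fn main() {
    // The lab's synthetic data: one feature per example.
    let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
    let ys = [2.1, 4.2, 6.1, 8.3, 10.0];

    // Candidate line y_hat = 0.0 + 2.0 * x, chosen by hand for illustration.
    let preds: Vec<f64> = xs.iter().map(|&x| predict(0.0, &[2.0], &[x])).collect();

    // Squared errors: 0.01 + 0.04 + 0.01 + 0.09 + 0.00
    let loss = sum_squared_errors(&ys, &preds);
    println!("Loss for y_hat = 2x: {:.2}", loss);
}
```

This candidate line gives a loss of about 0.15; comparing losses for different weights is a brute-force picture of what the fitting step optimizes.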

This lab uses linfa to compute these weights.
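linfa does the fitting in the lab below, but for a single feature the least-squares weights also have a well-known closed form: w₁ = Σ(xᵢ − x̄)(yᵢ − ȳ) / Σ(xᵢ − x̄)² and w₀ = ȳ − w₁x̄. The sketch below applies it to the lab's data; the function name `fit_simple_ols` is hypothetical. Its results should agree with linfa's up to floating-point rounding, although linfa solves the general multi-feature problem rather than this one-feature shortcut.

```rust
/// Fits y_hat = w0 + w1 * x by ordinary least squares for a single feature,
/// using the closed form w1 = Sxy / Sxx and w0 = mean(y) - w1 * mean(x).
/// Returns (w0, w1). (Hypothetical helper for illustration.)
fn fit_simple_ols(xs: &[f64], ys: &[f64]) -> (f64, f64) {
    assert_eq!(xs.len(), ys.len(), "xs and ys must have the same length");
    assert!(xs.len() >= 2, "need at least two points");
    let n = xs.len() as f64;
    let mean_x = xs.iter().sum::<f64>() / n;
    let mean_y = ys.iter().sum::<f64>() / n;

    // Centered cross-product (Sxy) and centered sum of squares (Sxx).
    let s_xy: f64 = xs
        .iter()
        .zip(ys)
        .map(|(x, y)| (x - mean_x) * (y - mean_y))
        .sum();
    let s_xx: f64 = xs.iter().map(|x| (x - mean_x).powi(2)).sum();
    assert!(s_xx > 0.0, "x values must not all be equal");

    let w1 = s_xy / s_xx;
    let w0 = mean_y - w1 * mean_x;
    (w0, w1)
}

fn main() {
    // Same synthetic data as the linfa lab.
    let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
    let ys = [2.1, 4.2, 6.1, 8.3, 10.0];

    let (w0, w1) = fit_simple_ols(&xs, &ys);
    println!("w0 = {:.2}, w1 = {:.2}", w0, w1);
    println!("Prediction for x=6: {:.2}", w0 + w1 * 6.0);
}
```

Run on the lab's data, this prints w0 = 0.17, w1 = 1.99 and a prediction of 12.11 for x = 6.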

Lab: Linear Regression with linfa

You’ll train a linear regression model on a synthetic dataset (e.g., predicting values from a single feature) and test its predictions.

  1. Edit src/main.rs in your rust_ml_tutorial project:

    rust
    use linfa::prelude::*;
    use linfa_linear::LinearRegression;
    use ndarray::array;
    
    fn main() {
        // Synthetic dataset: one feature per row (x) and one target per row (y)
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.1, 4.2, 6.1, 8.3, 10.0];
    
        // Bundle records and targets into a linfa Dataset
        let dataset = Dataset::new(x, y);
    
        // Train an ordinary least-squares model (an intercept is fitted by default)
        let model = LinearRegression::default()
            .fit(&dataset)
            .expect("linear regression fit failed");
    
        // Predict on a new, unseen input (one row with one feature)
        let new_x = array![[6.0]];
        let prediction = model.predict(&new_x);
        println!("Prediction for x=6: {}", prediction[0]);
    
        // Print the learned parameters; `{}` prints only the array values,
        // whereas `{:?}` would also print ndarray's shape and stride details
        let intercept = model.intercept();
        let weights = model.params();
        println!("Intercept: {}, Weights: {}", intercept, weights);
    }
  2. Ensure Dependencies:

    • Verify Cargo.toml includes:
      toml
      [dependencies]
      linfa = "0.7.1"
      linfa-linear = "0.7.0"
      ndarray = "0.15.0"
    • Run cargo build to fetch dependencies.
  3. Run the Program:

    bash
    cargo run

    Expected Output:

    Prediction for x=6: ~12.11
    Intercept: ~0.17, Weights: [~1.99]

    The model predicts ŷ ≈ 12.11 for x = 6, using the fitted line ŷ ≈ 0.17 + 1.99x.
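These values can be checked by hand with the one-feature least-squares formulas, using the training means x̄ = 3 and ȳ = 30.7 / 5 = 6.14:

$$
w_1 = \frac{\sum_i (x_i - \bar{x})(y_i - \bar{y})}{\sum_i (x_i - \bar{x})^2}
    = \frac{(-2)(-4.04) + (-1)(-1.94) + (0)(-0.04) + (1)(2.16) + (2)(3.86)}{4 + 1 + 0 + 1 + 4}
    = \frac{19.9}{10} = 1.99
$$

$$
w_0 = \bar{y} - w_1 \bar{x} = 6.14 - 1.99 \times 3 = 0.17,
\qquad
\hat{y}(6) = w_0 + w_1 \cdot 6 = 0.17 + 11.94 = 12.11
$$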

Understanding the Results

  • Dataset: The synthetic data roughly follows a linear relationship (y ≈ 2x) with small deviations.
  • Model: linfa’s LinearRegression learns the weight (w₁ ≈ 1.99) and the intercept (w₀ ≈ 0.17).
  • Prediction: The model applies the learned line to an input it did not see during training, predicting ŷ ≈ 12.11 for x = 6.

This lab introduces supervised learning, setting the stage for the Core Machine Learning module.

Learning from Official Resources

For deeper Rust skills, explore:

  • The Rust Programming Language (The Book): Free guide at doc.rust-lang.org/book.
  • Programming Rust: Book by Blandy, Orendorff, and Tindall, an in-depth treatment of systems programming in Rust.

Next Steps

Proceed to Mathematical Foundations to study the mathematics behind ML, or revisit Rust Basics.

Further Reading

  • An Introduction to Statistical Learning by James et al. (Chapter 3)
  • Andrew Ng’s Machine Learning Specialization (Course 1, Week 1)
  • linfa Documentation: github.com/rust-ml/linfa