
Decision Trees

Decision trees are versatile machine learning models for classification and regression, recursively splitting data based on feature values to make decisions. Random forests, an ensemble of trees, enhance robustness. This section provides a comprehensive exploration of their theory, derivations, ensemble methods, and evaluation, with a Rust lab using linfa. We’ll delve into computational details and Rust’s advantages in tree-based algorithms.

Theory

A decision tree represents a series of decisions as a tree structure, with nodes (decision points), branches (feature-based splits), and leaves (predictions). For classification, leaves predict class labels (e.g., spam vs. not spam); for regression, they predict continuous values (e.g., house prices).

Given a dataset with features $\mathbf{x} = [x_1, x_2, \ldots, x_n]$ and target $y$, the tree splits the feature space into regions based on thresholds (e.g., $x_1 \le t$). Each split is chosen to minimize a loss function, such as impurity for classification or variance for regression.
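The node/branch/leaf structure maps naturally onto a recursive Rust enum. The sketch below is a minimal, std-only illustration. The `Node` type and the hand-built tree are hypothetical and are not linfa's internal representation. Each internal node tests $x_j \le t$ and routes the sample left or right until it reaches a leaf.

```rust
/// A binary decision tree: internal nodes test `x[feature] <= threshold`,
/// leaves store the predicted class label.
enum Node {
    Leaf { class: usize },
    Split {
        feature: usize,
        threshold: f64,
        left: Box<Node>,  // followed when x[feature] <= threshold
        right: Box<Node>, // followed when x[feature] > threshold
    },
}

impl Node {
    /// Walk from this node down to a leaf and return the leaf's label.
    fn predict(&self, x: &[f64]) -> usize {
        match self {
            Node::Leaf { class } => *class,
            Node::Split { feature, threshold, left, right } => {
                if x[*feature] <= *threshold {
                    left.predict(x)
                } else {
                    right.predict(x)
                }
            }
        }
    }
}

fn main() {
    // Hand-built tree: split on x1 (index 0) <= 5.5, then on x2 (index 1) <= 2.5 in the right branch
    let tree = Node::Split {
        feature: 0,
        threshold: 5.5,
        left: Box::new(Node::Leaf { class: 0 }),
        right: Box::new(Node::Split {
            feature: 1,
            threshold: 2.5,
            left: Box::new(Node::Leaf { class: 1 }),
            right: Box::new(Node::Leaf { class: 2 }),
        }),
    };
    println!("{}", tree.predict(&[3.0, 4.0])); // 0
    println!("{}", tree.predict(&[7.0, 2.0])); // 1
    println!("{}", tree.predict(&[7.0, 4.0])); // 2
}
```

Prediction cost is proportional to the depth of the tree, because each sample follows exactly one root-to-leaf path.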

Derivation: Splitting Criteria

For classification, splits reduce impurity, measured by:

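  • Gini Index: The probability that a randomly chosen sample from the node is misclassified when it is labeled at random according to the node's class distribution:

    $$\text{Gini} = \sum_{k=1}^{K} p_k (1 - p_k) = 1 - \sum_{k=1}^{K} p_k^2$$

    where $p_k$ is the proportion of class $k$ in the node, and $K$ is the number of classes.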

  • Entropy: Uncertainty of the class distribution:

    $$\text{Entropy} = -\sum_{k=1}^{K} p_k \log_2 p_k$$

    The information gain for a split is:

    $$IG = \text{Entropy}_{\text{parent}} - \sum_{j=1}^{J} \frac{N_j}{N}\,\text{Entropy}_j$$

    where $J$ is the number of child nodes, $N_j$ is the number of samples in child $j$, and $N$ is the parent's sample count.
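To make these formulas concrete, here is a minimal std-only Rust sketch that computes the Gini index, entropy, and information gain from integer class labels. The function names `proportions`, `gini`, `entropy`, and `information_gain` are illustrative and are not linfa's API. The sketch assumes labels lie in `0..n_classes` and uses the usual convention $0 \log_2 0 = 0$.

```rust
/// Class proportions p_k of a non-empty node whose labels lie in 0..n_classes.
fn proportions(labels: &[usize], n_classes: usize) -> Vec<f64> {
    let mut counts = vec![0usize; n_classes];
    for &y in labels {
        counts[y] += 1;
    }
    let n = labels.len() as f64;
    counts.iter().map(|&c| c as f64 / n).collect()
}

/// Gini = 1 - sum_k p_k^2
fn gini(labels: &[usize], n_classes: usize) -> f64 {
    1.0 - proportions(labels, n_classes).iter().map(|p| p * p).sum::<f64>()
}

/// Entropy = -sum_k p_k log2 p_k, skipping p_k = 0 (0 * log2 0 = 0)
fn entropy(labels: &[usize], n_classes: usize) -> f64 {
    -proportions(labels, n_classes)
        .iter()
        .filter(|&&p| p > 0.0)
        .map(|&p| p * p.log2())
        .sum::<f64>()
}

/// IG = Entropy(parent) - sum_j (N_j / N) * Entropy(child_j)
fn information_gain(parent: &[usize], children: &[&[usize]], n_classes: usize) -> f64 {
    let n = parent.len() as f64;
    let weighted: f64 = children
        .iter()
        .map(|c| c.len() as f64 / n * entropy(c, n_classes))
        .sum();
    entropy(parent, n_classes) - weighted
}

fn main() {
    let parent: [usize; 8] = [0, 0, 0, 0, 1, 1, 1, 1];
    let (left, right) = parent.split_at(4); // a split that separates the classes perfectly
    println!("Gini(parent)    = {:.3}", gini(&parent, 2)); // 0.500
    println!("Entropy(parent) = {:.3}", entropy(&parent, 2)); // 1.000
    println!("IG              = {:.3}", information_gain(&parent, &[left, right], 2)); // 1.000
}
```

Both children of the perfect split are pure, so their entropy is 0 and the information gain equals the parent's full 1 bit of entropy.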

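For regression, splits minimize variance:

$$\text{Variance} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \bar{y})^2$$

where $\bar{y}$ is the node's mean. The best split maximizes the variance reduction:

$$\text{Variance Reduction} = \text{Variance}_{\text{parent}} - \sum_{j=1}^{J} \frac{N_j}{N}\,\text{Variance}_j$$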
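The regression criterion can be sketched the same way. This is std-only Rust, and the names `variance` and `variance_reduction` are illustrative. The sketch computes each node's population variance (the $1/N$ form above) and subtracts the size-weighted child variances from the parent's.

```rust
/// Population variance: (1/N) * sum_i (y_i - mean)^2, for a non-empty node
fn variance(y: &[f64]) -> f64 {
    let n = y.len() as f64;
    let mean = y.iter().sum::<f64>() / n;
    y.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n
}

/// Variance reduction = Var(parent) - sum_j (N_j / N) * Var(child_j)
fn variance_reduction(parent: &[f64], children: &[&[f64]]) -> f64 {
    let n = parent.len() as f64;
    let weighted: f64 = children.iter().map(|c| c.len() as f64 / n * variance(c)).sum();
    variance(parent) - weighted
}

fn main() {
    // Two clusters of targets; splitting between 3.0 and 10.0 separates them
    let y = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0];
    let (left, right) = y.split_at(3);
    println!("Var(parent) = {:.3}", variance(&y)); // 20.917
    println!("Reduction   = {:.3}", variance_reduction(&y, &[left, right])); // 20.250
}
```

Each child keeps a variance of only 2/3, so almost all of the parent's variance is removed by this split.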

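Under the Hood: Finding the best split means evaluating every candidate threshold for every feature. Sorting one feature's values costs $O(m \log m)$. After that, a single pass with running class counts scores all of that feature's thresholds. Searching one node therefore costs $O(n\,m \log m)$ for $m$ samples and $n$ features. linfa's tree implementation is written in Rust, whose ownership rules prevent memory errors such as dangling pointers during recursive tree construction. Such errors are a risk in C++ code that manages memory manually.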
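The sort-then-sweep strategy can be sketched for a single feature as follows. This is a generic illustration of the technique, not linfa's internal code. The names `best_split` and `gini_from_counts` are hypothetical, and candidate thresholds are assumed to be midpoints between consecutive distinct values. Running it for each of the $n$ features and keeping the best result gives the $O(n\,m \log m)$ node cost.

```rust
/// Gini impurity of a node computed from its class counts.
fn gini_from_counts(counts: &[usize]) -> f64 {
    let n: usize = counts.iter().sum();
    if n == 0 {
        return 0.0;
    }
    1.0 - counts.iter().map(|&c| (c as f64 / n as f64).powi(2)).sum::<f64>()
}

/// Best threshold on one feature: sort once (O(m log m)), then sweep left to right,
/// moving one sample per step from the right child to the left (O(m * K) for K classes).
/// Returns (threshold, size-weighted child Gini), or None if no split is possible.
fn best_split(feature: &[f64], labels: &[usize], n_classes: usize) -> Option<(f64, f64)> {
    let m = feature.len();
    let mut order: Vec<usize> = (0..m).collect();
    order.sort_by(|&a, &b| feature[a].partial_cmp(&feature[b]).unwrap());

    let mut left = vec![0usize; n_classes];
    let mut right = vec![0usize; n_classes];
    for &y in labels {
        right[y] += 1;
    }

    let mut best: Option<(f64, f64)> = None;
    for i in 0..m.saturating_sub(1) {
        let (cur, next) = (order[i], order[i + 1]);
        left[labels[cur]] += 1;
        right[labels[cur]] -= 1;
        // Thresholds only make sense between distinct feature values
        if feature[cur] == feature[next] {
            continue;
        }
        let n_left = (i + 1) as f64;
        let n_right = (m - i - 1) as f64;
        let score = (n_left * gini_from_counts(&left) + n_right * gini_from_counts(&right)) / m as f64;
        if best.map_or(true, |(_, s)| score < s) {
            best = Some(((feature[cur] + feature[next]) / 2.0, score));
        }
    }
    best
}

fn main() {
    // Feature x1 and the labels from this section's lab dataset
    let x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let y: [usize; 10] = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
    if let Some((t, score)) = best_split(&x1, &y, 2) {
        println!("best threshold = {t}, weighted Gini = {score}"); // best threshold = 5.5, weighted Gini = 0
    }
}
```

The running counts are the key design choice. Recomputing child impurities from scratch at every threshold would make each feature cost $O(m^2)$ instead of $O(m \log m)$.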

Random Forests

Decision trees can overfit, memorizing training data. Random forests mitigate this by averaging predictions from multiple trees, each trained on:

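  • Bootstrap Samples: Each tree is trained on m rows drawn with replacement from the m training rows, so some rows repeat and others are left out.
  • Random Feature Subsets: At each split, only a random subset of the features is considered, which reduces the correlation between trees.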
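Both sources of randomness are simple to express. The std-only sketch below relies on a tiny xorshift64 generator, `XorShift64`, which is a hypothetical helper used only because the standard library has no RNG; real code would use the `rand` crate. Bootstrap indices are drawn with replacement. The feature subset is drawn without replacement using a partial Fisher-Yates shuffle.

```rust
/// Minimal xorshift64 PRNG (std-only stand-in for the `rand` crate). Seed must be nonzero.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Pseudo-random index in 0..n (n > 0); modulo bias is negligible for small n.
    fn below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}

/// Bootstrap sample: m row indices drawn from 0..m WITH replacement.
fn bootstrap_indices(m: usize, rng: &mut XorShift64) -> Vec<usize> {
    (0..m).map(|_| rng.below(m)).collect()
}

/// Random feature subset: k distinct indices from 0..n WITHOUT replacement
/// (partial Fisher-Yates shuffle).
fn feature_subset(n: usize, k: usize, rng: &mut XorShift64) -> Vec<usize> {
    let k = k.min(n);
    let mut features: Vec<usize> = (0..n).collect();
    for i in 0..k {
        let j = i + rng.below(n - i);
        features.swap(i, j);
    }
    features.truncate(k);
    features
}

fn main() {
    let mut rng = XorShift64(0x2545_F491_4F6C_DD1D); // any nonzero seed
    let rows = bootstrap_indices(10, &mut rng);
    // About sqrt(n) features per split is a common default for classification
    let feats = feature_subset(16, 4, &mut rng);
    println!("bootstrap rows: {:?}", rows);
    println!("features considered at this split: {:?}", feats);
}
```

On average a bootstrap sample contains about 63% of the distinct training rows. The rows left out of a tree can serve as a built-in validation set for that tree.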

The ensemble prediction is:

  • Classification: Majority vote across trees.
  • Regression: Average of tree predictions.
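Aggregation is the cheapest part of the ensemble. Here is a minimal sketch with illustrative function names. Breaking ties toward the smallest label is an assumption of this sketch, since implementations differ.

```rust
use std::collections::HashMap;

/// Classification: majority vote over the trees' predicted labels.
/// Ties go to the smallest label so the result is deterministic.
fn majority_vote(votes: &[usize]) -> Option<usize> {
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &v in votes {
        *counts.entry(v).or_insert(0) += 1;
    }
    counts
        .into_iter()
        // Higher count wins; on equal counts the smaller label compares as "greater"
        .max_by(|a, b| a.1.cmp(&b.1).then(b.0.cmp(&a.0)))
        .map(|(label, _)| label)
}

/// Regression: mean of the trees' predictions.
fn average(preds: &[f64]) -> Option<f64> {
    if preds.is_empty() {
        None
    } else {
        Some(preds.iter().sum::<f64>() / preds.len() as f64)
    }
}

fn main() {
    // Hypothetical outputs of five trees for a single sample
    println!("{:?}", majority_vote(&[1, 0, 1, 1, 0])); // Some(1)
    println!("{:?}", average(&[200.0, 210.0, 190.0, 205.0, 195.0])); // Some(200.0)
}
```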

Under the Hood: Random forests trade a small increase in bias for a larger reduction in variance, which improves generalization. Averaging helps most when the trees' errors are weakly correlated, and random feature selection is what lowers that correlation. Because each tree is trained independently, forest training parallelizes naturally. Rust's ownership rules let such multi-threaded code (for example, code built on the rayon crate) run without data races.
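A standard variance identity makes the decorrelation argument precise. The identity assumes that each of the $B$ trees' predictions $\hat f_b(\mathbf{x})$ has variance $\sigma^2$ and that every pair of trees has correlation $\rho$, which is an idealization:

```latex
\operatorname{Var}\!\left(\frac{1}{B}\sum_{b=1}^{B}\hat f_b(\mathbf{x})\right)
  = \frac{1}{B^2}\left(B\sigma^2 + B(B-1)\rho\sigma^2\right)
  = \rho\sigma^2 + \frac{1-\rho}{B}\,\sigma^2
```

Adding trees shrinks only the second term. The first term, $\rho\sigma^2$, falls only when the trees become less correlated, which is the job of the random feature subsets.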

Evaluation

Performance is assessed with:

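  • Classification:
    • Accuracy: $\frac{\text{number of correct predictions}}{m}$.
    • Precision, Recall, F1-Score: Useful for imbalanced classes, as in logistic regression.
    • ROC-AUC: Area under the ROC curve.
  • Regression:
    • Mean Squared Error (MSE): $\frac{1}{m} \sum_{i=1}^{m} (y_i - \hat{y}_i)^2$.
    • R-squared ($R^2$): $1 - \frac{\sum_{i} (y_i - \hat{y}_i)^2}{\sum_{i} (y_i - \bar{y})^2}$.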
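All of the metrics above except ROC-AUC can be computed directly from predictions. The std-only sketch below assumes binary labels with `1` as the positive class. It defines precision, recall, and F1 as 0 when their denominators are 0, which is a common convention but not the only one. The function names are illustrative.

```rust
/// Accuracy = correct predictions / m
fn accuracy(y: &[usize], y_hat: &[usize]) -> f64 {
    let correct = y.iter().zip(y_hat).filter(|(a, b)| a == b).count();
    correct as f64 / y.len() as f64
}

/// (precision, recall, F1) for the positive class 1
fn precision_recall_f1(y: &[usize], y_hat: &[usize]) -> (f64, f64, f64) {
    let (mut tp, mut fp, mut fneg) = (0.0, 0.0, 0.0);
    for (&t, &p) in y.iter().zip(y_hat) {
        match (t == 1, p == 1) {
            (true, true) => tp += 1.0,    // true positive
            (false, true) => fp += 1.0,   // false positive
            (true, false) => fneg += 1.0, // false negative
            (false, false) => {}          // true negative
        }
    }
    let ratio = |num: f64, den: f64| if den > 0.0 { num / den } else { 0.0 };
    let precision = ratio(tp, tp + fp);
    let recall = ratio(tp, tp + fneg);
    let f1 = ratio(2.0 * precision * recall, precision + recall);
    (precision, recall, f1)
}

/// MSE = (1/m) * sum_i (y_i - y_hat_i)^2
fn mse(y: &[f64], y_hat: &[f64]) -> f64 {
    y.iter().zip(y_hat).map(|(a, b)| (a - b).powi(2)).sum::<f64>() / y.len() as f64
}

/// R^2 = 1 - SS_res / SS_tot
fn r_squared(y: &[f64], y_hat: &[f64]) -> f64 {
    let mean = y.iter().sum::<f64>() / y.len() as f64;
    let ss_res: f64 = y.iter().zip(y_hat).map(|(a, b)| (a - b).powi(2)).sum();
    let ss_tot: f64 = y.iter().map(|a| (a - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

fn main() {
    let y: [usize; 5] = [0, 1, 1, 0, 1];
    let y_hat: [usize; 5] = [0, 1, 0, 0, 1];
    println!("accuracy = {:.2}", accuracy(&y, &y_hat)); // 0.80
    let (p, r, f1) = precision_recall_f1(&y, &y_hat);
    println!("precision = {p:.2}, recall = {r:.2}, F1 = {f1:.2}"); // 1.00, 0.67, 0.80
    let t = [3.0, 5.0, 7.0];
    let t_hat = [2.5, 5.0, 7.5];
    println!("MSE = {:.4}, R^2 = {:.4}", mse(&t, &t_hat), r_squared(&t, &t_hat)); // 0.1667, 0.9375
}
```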

Under the Hood: Random forests provide feature importance scores, which aid interpretability. The scores are typically computed as the total weighted impurity reduction from all splits on a feature, averaged across trees (mean decrease in impurity). Computing them needs only one pass over each tree's split nodes, so the cost is small compared with training.
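Mean decrease in impurity can be sketched from per-node statistics. The `SplitRecord` type below is hypothetical and stands in for whatever a fitted tree stores at each internal node. The sketch assumes every tree was trained on the same number of samples, which holds for bootstrap samples of size m. Under that assumption, dividing by $N$ and by the number of trees only rescales all scores and cancels in the final normalization. The sketch therefore sums the raw decreases $N_t I_t - N_L I_L - N_R I_R$ and normalizes the scores to sum to 1.

```rust
/// Statistics stored at one internal (split) node; hypothetical record type.
struct SplitRecord {
    feature: usize,
    n_node: usize,       // N_t: samples reaching the node
    impurity: f64,       // I_t: impurity before the split
    n_left: usize,       // N_L
    impurity_left: f64,  // I_L
    n_right: usize,      // N_R
    impurity_right: f64, // I_R
}

/// Mean decrease in impurity, normalized so the scores sum to 1.
fn feature_importance(trees: &[Vec<SplitRecord>], n_features: usize) -> Vec<f64> {
    let mut imp = vec![0.0_f64; n_features];
    for tree in trees {
        for s in tree {
            // Weighted impurity decrease of this split: N_t*I_t - N_L*I_L - N_R*I_R
            imp[s.feature] += s.n_node as f64 * s.impurity
                - s.n_left as f64 * s.impurity_left
                - s.n_right as f64 * s.impurity_right;
        }
    }
    let total: f64 = imp.iter().sum();
    if total > 0.0 {
        for v in imp.iter_mut() {
            *v /= total;
        }
    }
    imp
}

fn main() {
    // Two hypothetical one-split trees on 10 samples (Gini impurities)
    let trees = vec![
        vec![SplitRecord { feature: 0, n_node: 10, impurity: 0.5,
                           n_left: 5, impurity_left: 0.0, n_right: 5, impurity_right: 0.0 }],
        vec![SplitRecord { feature: 1, n_node: 10, impurity: 0.5,
                           n_left: 4, impurity_left: 0.375, n_right: 6, impurity_right: 4.0 / 9.0 }],
    ];
    println!("{:.3?}", feature_importance(&trees, 2)); // [0.857, 0.143]
}
```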

Lab: Random Forest with linfa

You'll train a decision tree classifier, the building block of a random forest, on a small synthetic dataset (two features predicting a binary class). You'll then evaluate its accuracy and feature importance.

  1. Edit src/main.rs in your rust_ml_tutorial project:

    rust
    use linfa::prelude::*;
    use linfa_trees::{DecisionTree, SplitQuality};
    use ndarray::{array, Array1, Array2};
    
    fn main() {
        // Synthetic dataset: features (x1, x2), binary class label (0 or 1)
        let x: Array2<f64> = array![
            [1.0, 2.0], [2.0, 1.0], [3.0, 3.0], [4.0, 5.0], [5.0, 4.0],
            [6.0, 1.0], [7.0, 2.0], [8.0, 3.0], [9.0, 4.0], [10.0, 5.0]
        ];
        // linfa classifiers expect discrete labels (e.g. usize), not f64
        let y: Array1<usize> = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
    
        // Bundle records and targets into a linfa Dataset
        let dataset = Dataset::new(x.clone(), y.clone());
    
        // Train a single decision tree (linfa-trees provides DecisionTree, not a forest)
        let model = DecisionTree::params()
            .split_quality(SplitQuality::Gini)
            .max_depth(Some(3))
            .min_weight_split(2.0)
            .fit(&dataset)
            .expect("decision tree training failed");
    
        // Predict class labels for the training records
        let predictions = model.predict(&x);
        println!("Predictions: {}", predictions);
    
        // Accuracy = correct predictions / number of samples
        let correct = predictions.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
        let accuracy = correct as f64 / y.len() as f64;
        println!("Accuracy: {:.1}", accuracy);
    
        // Each feature's share of the total impurity decrease in the tree
        let importance = model.feature_importance();
        println!("Feature Importance: {:?}", importance);
    }
  2. Ensure Dependencies:

    • Verify Cargo.toml includes:
      toml
      [dependencies]
      linfa = "0.7.1"
      linfa-trees = "0.7.0"
      ndarray = "0.15.0"
    • Run cargo build.
  3. Run the Program:

    bash
    cargo run

    Expected Output (approximate):

    Predictions: [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    Accuracy: 1.0
    Feature Importance: [1.0, 0.0]

Understanding the Results

  • Dataset: Synthetic features (x1, x2) predict binary classes (0 or 1), designed to be separable.
  • Model: A single decision tree, the building block of a random forest, learns splits based on Gini impurity. Every class-0 sample has x1 ≤ 5 and every class-1 sample has x1 ≥ 6, so one split on x1 separates the classes and the tree reaches perfect accuracy. This is training accuracy; measuring generalization would require held-out data.
  • Feature Importance: Each feature's share of the total impurity decrease across the tree's splits. Since the single split on x1 does all the work, x1 should receive all of the importance and x2 none.
  • Under the Hood: Decision trees partition the feature space recursively, with each split chosen to maximize impurity reduction. Rust's ownership and borrowing rules prevent memory corruption during tree construction, a risk in C++ implementations that manage memory manually. Because the trees in a random forest are independent, training them in parallel (for example with Rust's rayon crate) fits naturally with Rust's data-race-free concurrency.
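The lab's model is a single tree. The sketch below shows the forest mechanism end to end on the same data without extra crates. It is a minimal std-only random forest built from depth-1 trees (stumps), and each stump combines a bootstrap sample, one randomly chosen feature, and a majority vote. It illustrates the technique and is not linfa's API. The helper names (`Rng`, `Stump`, `fit_stump`, `predict`) are hypothetical, and the $O(m^2)$ split search is chosen for brevity.

```rust
/// Minimal xorshift64 PRNG so the sketch needs only std (use the `rand` crate in practice).
struct Rng(u64);

impl Rng {
    /// Pseudo-random index in 0..n (n must be > 0; seed must be nonzero).
    fn below(&mut self, n: usize) -> usize {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 % n as u64) as usize
    }
}

/// A depth-1 tree ("stump"): predicts `left` if x[feature] <= threshold, else `right`.
struct Stump {
    feature: usize,
    threshold: f64,
    left: usize,
    right: usize,
}

/// Majority label (binary 0/1) among the given rows; ties and empty sets give 0.
fn majority(rows: &[usize], y: &[usize]) -> usize {
    let ones = rows.iter().filter(|&&r| y[r] == 1).count();
    if 2 * ones > rows.len() { 1 } else { 0 }
}

/// Binary Gini impurity of the given rows: 1 - p^2 - (1 - p)^2.
fn gini(rows: &[usize], y: &[usize]) -> f64 {
    if rows.is_empty() {
        return 0.0;
    }
    let p = rows.iter().filter(|&&r| y[r] == 1).count() as f64 / rows.len() as f64;
    1.0 - p * p - (1.0 - p) * (1.0 - p)
}

/// Fit one stump on a bootstrap sample using one randomly chosen feature.
/// Candidate thresholds are the observed feature values.
fn fit_stump(x: &[[f64; 2]], y: &[usize], rng: &mut Rng) -> Stump {
    let m = x.len();
    let rows: Vec<usize> = (0..m).map(|_| rng.below(m)).collect(); // with replacement
    let feature = rng.below(2); // random feature subset of size 1
    let split = |t: f64| -> (Vec<usize>, Vec<usize>) {
        rows.iter().partition(|&&i| x[i][feature] <= t)
    };
    let mut best = (f64::INFINITY, 0.0); // (weighted Gini, threshold)
    for &r in &rows {
        let t = x[r][feature];
        let (l, rt) = split(t);
        let score = (l.len() as f64 * gini(&l, y) + rt.len() as f64 * gini(&rt, y)) / m as f64;
        if score < best.0 {
            best = (score, t);
        }
    }
    let (l, rt) = split(best.1);
    Stump { feature, threshold: best.1, left: majority(&l, y), right: majority(&rt, y) }
}

/// Forest prediction: majority vote over the stumps (ties give 0).
fn predict(forest: &[Stump], x: &[f64; 2]) -> usize {
    let ones = forest
        .iter()
        .filter(|s| (if x[s.feature] <= s.threshold { s.left } else { s.right }) == 1)
        .count();
    if 2 * ones > forest.len() { 1 } else { 0 }
}

fn main() {
    // The lab's synthetic dataset
    let x = [
        [1.0, 2.0], [2.0, 1.0], [3.0, 3.0], [4.0, 5.0], [5.0, 4.0],
        [6.0, 1.0], [7.0, 2.0], [8.0, 3.0], [9.0, 4.0], [10.0, 5.0],
    ];
    let y: [usize; 10] = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
    let mut rng = Rng(0x9E37_79B9_7F4A_7C15); // any nonzero seed
    let forest: Vec<Stump> = (0..25).map(|_| fit_stump(&x, &y, &mut rng)).collect();
    let preds: Vec<usize> = x.iter().map(|xi| predict(&forest, xi)).collect();
    println!("Forest predictions: {:?}", preds);
}
```

In this dataset x2 carries no class information, so stumps that draw x2 are close to random guesses. The vote relies on the x1 stumps to outweigh them, and the exact predictions depend on the seed.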

This lab introduces tree-based methods, preparing for ensemble techniques.

Next Steps

Continue to Support Vector Machines for advanced classification, or revisit Logistic Regression.
