Decision Trees
Decision trees are versatile machine learning models for classification and regression, recursively splitting data on feature values to make decisions. Random forests, ensembles of many trees, improve robustness. This section provides a comprehensive exploration of their theory, derivations, ensemble methods, and evaluation, with a Rust lab using linfa. We’ll delve into computational details and Rust’s advantages for tree-based algorithms.
Theory
A decision tree represents a series of decisions as a tree structure, with nodes (decision points), branches (feature-based splits), and leaves (predictions). For classification, leaves predict class labels (e.g., spam vs. not spam); for regression, they predict continuous values (e.g., house prices).
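To make this structure concrete, the sketch below shows one way a fitted tree could be represented and traversed in Rust. It is illustrative only; linfa’s internal representation differs.

```rust
// Illustrative tree representation: internal nodes hold a feature index and
// threshold, leaves hold a prediction (class label or regression value).
enum TreeNode {
    Internal {
        feature: usize,  // index of the feature used for the split
        threshold: f64,  // go left if x[feature] <= threshold, else right
        left: Box<TreeNode>,
        right: Box<TreeNode>,
    },
    Leaf {
        prediction: f64, // class label (as f64) or regression value
    },
}

// Traverse the tree to produce a prediction for one sample.
fn predict(node: &TreeNode, x: &[f64]) -> f64 {
    match node {
        TreeNode::Leaf { prediction } => *prediction,
        TreeNode::Internal { feature, threshold, left, right } => {
            if x[*feature] <= *threshold {
                predict(left, x)
            } else {
                predict(right, x)
            }
        }
    }
}

fn main() {
    // A hand-built stump: split on feature 0 at threshold 5.5.
    let tree = TreeNode::Internal {
        feature: 0,
        threshold: 5.5,
        left: Box::new(TreeNode::Leaf { prediction: 0.0 }),
        right: Box::new(TreeNode::Leaf { prediction: 1.0 }),
    };
    println!("Prediction for [7.0, 2.0]: {}", predict(&tree, &[7.0, 2.0]));
}
```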
Given a dataset with features $\mathbf{x}_i \in \mathbb{R}^n$ and targets $y_i$, a tree recursively partitions the feature space into regions, assigning one prediction to each region (leaf).
Derivation: Splitting Criteria
For classification, splits reduce impurity, measured by:
- Gini Index: the probability of misclassifying a sample if labels were assigned at random:

$$\text{Gini} = 1 - \sum_{k=1}^{K} p_k^2$$

where $p_k$ is the proportion of class $k$ in a node, and $K$ is the number of classes.

- Entropy: the uncertainty of the class distribution:

$$H = -\sum_{k=1}^{K} p_k \log_2 p_k$$

The information gain for a split is:

$$\text{IG} = I(\text{parent}) - \sum_{j=1}^{m} \frac{N_j}{N}\, I(\text{child}_j)$$

where $m$ is the number of child nodes, $N_j$ is the number of samples in child $j$, $N$ is the parent’s sample count, and $I$ is the chosen impurity measure (Gini or entropy).

For regression, splits minimize the variance of the targets within each child node:

$$\text{Var} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \bar{y})^2$$

where $\bar{y}$ is the mean target value in the node; the best split minimizes the weighted variance of the children.
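As a concrete check of these definitions, here is a small standalone Rust sketch that computes the Gini index and entropy directly from class labels (not linfa’s internals):

```rust
// Compute class proportions p_k from integer labels 0..n_classes.
fn class_proportions(labels: &[usize], n_classes: usize) -> Vec<f64> {
    let mut counts = vec![0usize; n_classes];
    for &l in labels {
        counts[l] += 1;
    }
    counts.iter().map(|&c| c as f64 / labels.len() as f64).collect()
}

/// Gini index: 1 - sum(p_k^2).
fn gini(labels: &[usize], n_classes: usize) -> f64 {
    1.0 - class_proportions(labels, n_classes)
        .iter()
        .map(|p| p * p)
        .sum::<f64>()
}

/// Entropy: -sum(p_k * log2(p_k)), skipping empty classes.
fn entropy(labels: &[usize], n_classes: usize) -> f64 {
    class_proportions(labels, n_classes)
        .iter()
        .filter(|&&p| p > 0.0)
        .map(|p| -p * p.log2())
        .sum()
}

fn main() {
    let node = [0, 0, 0, 1, 1, 1]; // evenly mixed node
    println!("Gini: {:.3}", gini(&node, 2));       // 0.500
    println!("Entropy: {:.3}", entropy(&node, 2)); // 1.000
}
```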
Under the Hood: Splitting requires evaluating candidate thresholds for every feature, which is computationally intensive. linfa optimizes this with efficient sorting and iteration, and Rust’s memory safety prevents errors during recursive tree construction, unlike C++ where manual memory management risks bugs.
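To illustrate that threshold search, here is a hedged standalone sketch of an exhaustive split search under the regression criterion: sort samples by one feature, try each midpoint, and keep the split with the lowest weighted child variance. linfa’s actual implementation is more optimized, but the idea is the same.

```rust
// Variance of a set of target values (0.0 for an empty set).
fn variance(ys: &[f64]) -> f64 {
    if ys.is_empty() {
        return 0.0;
    }
    let mean = ys.iter().sum::<f64>() / ys.len() as f64;
    ys.iter().map(|y| (y - mean).powi(2)).sum::<f64>() / ys.len() as f64
}

/// Return (threshold, weighted child variance) of the best split on one feature.
fn best_split(feature: &[f64], target: &[f64]) -> (f64, f64) {
    // Sort sample indices by feature value, as tree libraries do internally.
    let mut idx: Vec<usize> = (0..feature.len()).collect();
    idx.sort_by(|&a, &b| feature[a].partial_cmp(&feature[b]).unwrap());

    let mut best = (f64::NAN, f64::INFINITY);
    for i in 1..idx.len() {
        // Candidate threshold: midpoint between consecutive sorted values.
        let threshold = (feature[idx[i - 1]] + feature[idx[i]]) / 2.0;
        let left: Vec<f64> = idx.iter().filter(|&&j| feature[j] <= threshold)
            .map(|&j| target[j]).collect();
        let right: Vec<f64> = idx.iter().filter(|&&j| feature[j] > threshold)
            .map(|&j| target[j]).collect();
        let n = target.len() as f64;
        // Weighted variance of the two children (lower is better).
        let score = (left.len() as f64 / n) * variance(&left)
            + (right.len() as f64 / n) * variance(&right);
        if score < best.1 {
            best = (threshold, score);
        }
    }
    best
}

fn main() {
    let feature = [1.0, 2.0, 3.0, 8.0, 9.0, 10.0];
    let target = [1.1, 0.9, 1.0, 4.9, 5.1, 5.0];
    let (threshold, score) = best_split(&feature, &target);
    println!("best threshold: {threshold}, weighted variance: {score:.4}");
}
```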
Random Forests
Decision trees can overfit, memorizing training data. Random forests mitigate this by averaging predictions from multiple trees, each trained on:
- Bootstrap Samples: Random subsets of the data with replacement.
- Random Feature Subsets: A random subset of features at each split, reducing correlation.
The ensemble prediction is:
- Classification: Majority vote across trees.
- Regression: Average of tree predictions.
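The sketch below shows both aggregation rules in plain Rust, with per-tree predictions hard-coded for illustration:

```rust
// Majority vote over binary per-tree class predictions for one test point.
fn majority_vote(votes: &[usize]) -> usize {
    let ones = votes.iter().filter(|&&v| v == 1).count();
    if ones * 2 > votes.len() { 1 } else { 0 }
}

// Mean of per-tree regression predictions for one test point.
fn average(predictions: &[f64]) -> f64 {
    predictions.iter().sum::<f64>() / predictions.len() as f64
}

fn main() {
    // Five trees vote on one test point (binary labels).
    println!("class: {}", majority_vote(&[1, 0, 1, 1, 0])); // 1
    // Five regression trees predict a continuous value.
    println!("value: {}", average(&[210.0, 195.0, 205.0, 199.0, 201.0])); // 202
}
```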
Under the Hood: Random forests trade a small increase in bias for a large reduction in variance, improving generalization. The randomness in feature selection reduces correlation between trees, enhancing robustness. linfa currently provides single decision trees; an ensemble can be built by fitting several trees on bootstrap samples, and because the trees are independent, Rust’s concurrency model makes this training embarrassingly parallel, with zero-cost abstractions keeping overhead low on large datasets.
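As an illustration of that parallelism, here is a minimal sketch using the rayon crate (assumed as an extra dependency); `train_tree` is a hypothetical stub standing in for fitting one tree on a bootstrap sample:

```rust
use rayon::prelude::*;

// Placeholder "model": in a real forest this would be a fitted tree.
struct Tree {
    id: usize,
}

fn train_tree(id: usize) -> Tree {
    // Fit a decision tree on a bootstrap sample here.
    Tree { id }
}

fn main() {
    // Each tree is independent, so training parallelizes trivially.
    let forest: Vec<Tree> = (0..8usize).into_par_iter().map(train_tree).collect();
    println!("Trained {} trees in parallel", forest.len());
    println!("First tree id: {}", forest[0].id);
}
```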
Evaluation
Performance is assessed with:
- Classification:
  - Accuracy: $\frac{\text{correct predictions}}{\text{total predictions}}$.
  - Precision, Recall, F1-Score: for imbalanced classes, as in logistic regression.
  - ROC-AUC: area under the ROC curve.
- Regression:
  - Mean Squared Error (MSE): $\frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2$.
  - R-squared ($R^2$): $1 - \frac{\sum_i (y_i - \hat{y}_i)^2}{\sum_i (y_i - \bar{y})^2}$.
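These metrics are easy to compute by hand; here is a compact standalone sketch for plain slices:

```rust
// Fraction of predictions that match the true labels.
fn accuracy(pred: &[usize], truth: &[usize]) -> f64 {
    pred.iter().zip(truth).filter(|(p, t)| p == t).count() as f64 / truth.len() as f64
}

// Mean squared error between predictions and targets.
fn mse(pred: &[f64], truth: &[f64]) -> f64 {
    pred.iter().zip(truth).map(|(p, t)| (p - t).powi(2)).sum::<f64>() / truth.len() as f64
}

// R^2 = 1 - SS_res / SS_tot.
fn r_squared(pred: &[f64], truth: &[f64]) -> f64 {
    let mean = truth.iter().sum::<f64>() / truth.len() as f64;
    let ss_res: f64 = pred.iter().zip(truth).map(|(p, t)| (t - p).powi(2)).sum();
    let ss_tot: f64 = truth.iter().map(|t| (t - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

fn main() {
    println!("accuracy: {}", accuracy(&[0, 1, 1, 0], &[0, 1, 0, 0])); // 0.75
    println!("mse: {}", mse(&[2.5, 3.5], &[2.0, 4.0]));               // 0.25
    println!("r2: {:.3}", r_squared(&[2.5, 3.5], &[2.0, 4.0]));       // 0.750
}
```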
Under the Hood: Random forests provide feature importance scores, computed as the average reduction in impurity a feature contributes across all trees, aiding interpretability. linfa tracks these impurity decreases during training with little overhead, whereas some Python libraries can slow down with very large feature sets.
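A minimal sketch of that mean-decrease-in-impurity computation follows; `SplitRecord` is a hypothetical type standing in for the split statistics collected from a fitted forest:

```rust
// One split somewhere in the forest and its weighted impurity reduction.
struct SplitRecord {
    feature: usize,
    weighted_impurity_decrease: f64, // (N_node / N_total) * impurity reduction
}

// Sum each feature's impurity reductions, then normalize to sum to 1.
fn feature_importance(splits: &[SplitRecord], n_features: usize) -> Vec<f64> {
    let mut importance = vec![0.0; n_features];
    for s in splits {
        importance[s.feature] += s.weighted_impurity_decrease;
    }
    let total: f64 = importance.iter().sum();
    if total > 0.0 {
        for v in &mut importance {
            *v /= total;
        }
    }
    importance
}

fn main() {
    // Splits collected from all trees in a (hypothetical) forest.
    let splits = [
        SplitRecord { feature: 0, weighted_impurity_decrease: 0.30 },
        SplitRecord { feature: 0, weighted_impurity_decrease: 0.10 },
        SplitRecord { feature: 1, weighted_impurity_decrease: 0.10 },
    ];
    println!("{:?}", feature_importance(&splits, 2)); // [0.8, 0.2]
}
```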
Lab: Random Forest with linfa
You’ll train a tree-based classifier (a single linfa decision tree standing in for a random forest) on a synthetic dataset of features predicting a binary class, and evaluate accuracy and feature importance.
Edit `src/main.rs` in your `rust_ml_tutorial` project:

```rust
use linfa::prelude::*;
use linfa_trees::DecisionTree;
use ndarray::{array, Array1, Array2};

fn main() {
    // Synthetic dataset: features (x1, x2), binary target (0 or 1)
    let x: Array2<f64> = array![
        [1.0, 2.0], [2.0, 1.0], [3.0, 3.0], [4.0, 5.0], [5.0, 4.0],
        [6.0, 1.0], [7.0, 2.0], [8.0, 3.0], [9.0, 4.0], [10.0, 5.0]
    ];
    let y: Array1<usize> = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

    // Create the dataset (records plus integer class labels)
    let dataset = Dataset::new(x.clone(), y.clone());

    // Train a decision tree (the building block of a random forest)
    let model = DecisionTree::params()
        .max_depth(Some(3))
        .min_weight_split(2.0) // minimum (weighted) samples required to split
        .fit(&dataset)
        .unwrap();

    // Predict classes for the training data
    let predictions = model.predict(&x);
    println!("Predictions: {:?}", predictions);

    // Compute accuracy
    let accuracy = predictions
        .iter()
        .zip(y.iter())
        .filter(|(p, t)| p == t)
        .count() as f64
        / y.len() as f64;
    println!("Accuracy: {}", accuracy);

    // Feature importance (mean decrease in impurity); the exact API may vary
    // across linfa-trees versions.
    let importance = model.feature_importance();
    println!("Feature Importance: {:?}", importance);
}
```
Ensure Dependencies:
- Verify `Cargo.toml` includes:

```toml
[dependencies]
linfa = "0.7.1"
linfa-trees = "0.7.0"
ndarray = "0.15.0"
```

- Run `cargo build`.
Run the Program:

```bash
cargo run
```

Expected Output (approximate):

```
Predictions: [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
Accuracy: 1.0
Feature Importance: [0.0, 0.0]  // placeholder; actual values depend on the fitted tree and linfa version
```
Understanding the Results
- Dataset: Synthetic features ($x_1$, $x_2$) predict binary classes (0 or 1) and are designed to be separable.
- Model: A single decision tree stands in for the random forest here (linfa does not yet bundle a forest implementation). It learns splits based on Gini impurity and achieves perfect accuracy on this small, separable dataset.
- Feature Importance: Support in linfa is still maturing; conceptually, importance reflects how much each feature contributes to impurity reduction across splits. linfa’s tree traversal is memory-efficient, avoiding the redundant allocations that large trees can incur in Python’s scikit-learn.
- Under the Hood: Decision trees partition the feature space recursively, with each split chosen to maximize impurity reduction. linfa uses safe iterators and immutable data structures, preventing memory corruption during tree construction, a risk in hand-written C++ implementations. Parallel training of forest members, when added, can leverage Rust’s rayon crate for scalable multi-threading.
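To preview those ensemble techniques, below is a hedged sketch of a hand-rolled bagged ensemble: several linfa decision trees, each fit on a bootstrap sample, combined by majority vote. It assumes the same linfa, linfa-trees, and ndarray dependencies as the lab plus the rand crate, and it reuses the `DecisionTree::params`/`fit`/`predict` calls from the lab code; treat it as a sketch rather than a canonical linfa forest API.

```rust
use linfa::prelude::*;
use linfa_trees::DecisionTree;
use ndarray::{array, Array1, Array2, Axis};
use rand::Rng;

fn main() {
    let x: Array2<f64> = array![
        [1.0, 2.0], [2.0, 1.0], [3.0, 3.0], [4.0, 5.0], [5.0, 4.0],
        [6.0, 1.0], [7.0, 2.0], [8.0, 3.0], [9.0, 4.0], [10.0, 5.0]
    ];
    let y: Array1<usize> = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

    let mut rng = rand::thread_rng();
    let n = x.nrows();
    let n_trees = 5;

    // Train each tree on a bootstrap sample (rows drawn with replacement).
    let trees: Vec<_> = (0..n_trees)
        .map(|_| {
            let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
            let xb = x.select(Axis(0), &idx);
            let yb = y.select(Axis(0), &idx);
            DecisionTree::params()
                .max_depth(Some(3))
                .fit(&Dataset::new(xb, yb))
                .unwrap()
        })
        .collect();

    // Majority vote across trees for each sample.
    let votes: Vec<Array1<usize>> = trees.iter().map(|t| t.predict(&x)).collect();
    let ensemble: Vec<usize> = (0..n)
        .map(|i| {
            let ones = votes.iter().filter(|v| v[i] == 1).count();
            if ones * 2 > n_trees { 1 } else { 0 }
        })
        .collect();
    println!("Ensemble predictions: {:?}", ensemble);
}
```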
This lab introduces tree-based methods, preparing for ensemble techniques.
Next Steps
Continue to Support Vector Machines for advanced classification, or revisit Logistic Regression.
Further Reading
- An Introduction to Statistical Learning by James et al. (Chapter 8)
- Hands-On Machine Learning by Géron (Chapters 6–7)
- linfa Documentation: github.com/rust-ml/linfa