Recurrent Neural Networks
Recurrent Neural Networks (RNNs) are specialized neural networks for processing sequential data, such as time series or text, by maintaining a hidden state that captures temporal dependencies. This section provides a comprehensive exploration of RNN architecture, backpropagation through time (BPTT), and variants like Long Short-Term Memory (LSTM) units, with a Rust lab using tch-rs. We’ll dive into sequence processing mechanics, gradient computation challenges, and Rust’s performance advantages, building on convolutional neural networks.
Theory
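RNNs process a sequence $x_1, \dots, x_T$, where $x_t$ is the input at time $t$, producing outputs $o_1, \dots, o_T$. The hidden state $h_t$ is updated recursively:

$$h_t = \sigma(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$

where $W_{xh}$, $W_{hh}$, $b_h$ are parameters, and $\sigma$ is a non-linear activation (e.g., $\tanh$). The output is:

$$o_t = W_{hy} h_t + b_y$$

For classification, $o_t$ is passed through a softmax to predict probabilities $\hat{y}_t = \mathrm{softmax}(o_t)$.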
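To make the recurrence concrete, here is a minimal sketch of the forward pass in plain Rust, without tch-rs. The function names (`rnn_step`, `rnn_forward`, `softmax`) and all weight values are illustrative assumptions, and the final hidden state is fed straight into the softmax (as if $W_{hy}$ were the identity) only to keep the example short.

```rust
/// One vanilla RNN step: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h).
fn rnn_step(
    w_xh: &[Vec<f64>],
    w_hh: &[Vec<f64>],
    b_h: &[f64],
    x: &[f64],
    h_prev: &[f64],
) -> Vec<f64> {
    (0..b_h.len())
        .map(|i| {
            let from_x: f64 = w_xh[i].iter().zip(x).map(|(w, v)| w * v).sum();
            let from_h: f64 = w_hh[i].iter().zip(h_prev).map(|(w, v)| w * v).sum();
            (from_x + from_h + b_h[i]).tanh()
        })
        .collect()
}

/// Runs the RNN over a whole sequence (h_0 = 0) and returns every hidden state.
fn rnn_forward(
    w_xh: &[Vec<f64>],
    w_hh: &[Vec<f64>],
    b_h: &[f64],
    xs: &[Vec<f64>],
) -> Vec<Vec<f64>> {
    let mut h = vec![0.0; b_h.len()];
    let mut states = Vec::with_capacity(xs.len());
    for x in xs {
        h = rnn_step(w_xh, w_hh, b_h, x, &h);
        states.push(h.clone());
    }
    states
}

/// Numerically stable softmax over a vector of logits.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    // 2 input features, 2 hidden units; all values are made up for illustration.
    let w_xh = vec![vec![0.5, -0.3], vec![0.1, 0.8]];
    let w_hh = vec![vec![0.2, 0.0], vec![-0.4, 0.3]];
    let b_h = vec![0.0, 0.1];
    let xs = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

    let states = rnn_forward(&w_xh, &w_hh, &b_h, &xs);
    let last = states.last().unwrap();
    // Assumption: use the last hidden state directly as logits (W_hy = I, b_y = 0).
    println!("final hidden state: {:?}", last);
    println!("class probabilities: {:?}", softmax(last));
}
```

The same weights are reused at every time step, which is what lets the network handle sequences of any length.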
Derivation: Backpropagation Through Time
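RNNs are trained to minimize a loss, such as cross-entropy for sequence classification:

$$J(\theta) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{t=1}^{T} \sum_{k=1}^{K} y_{i,t,k} \log \hat{y}_{i,t,k}$$

where $\theta$ includes $W_{xh}$, $W_{hh}$, $W_{hy}$, $b_h$, $b_y$, $y_{i,t,k}$ is 1 if sample $i$ at time $t$ is class $k$ (and 0 otherwise), $\hat{y}_{i,t,k}$ is the predicted softmax probability, and $m$ is the batch size. Backpropagation Through Time (BPTT) computes gradients by unrolling the RNN over $T$ time steps, treating it as a deep feedforward network with shared weights.

For a single sample, the loss at time $t$ is $J_t = -\sum_{k} y_{t,k} \log \hat{y}_{t,k}$. The gradient for $W_{hh}$ is:

$$\frac{\partial J_t}{\partial W_{hh}} = \sum_{s=1}^{t} \frac{\partial J_t}{\partial h_t} \frac{\partial h_t}{\partial h_s} \frac{\partial h_s}{\partial W_{hh}}$$

The error term is:

$$\delta_t = \left( W_{hy}^\top (\hat{y}_t - y_t) + W_{hh}^\top \delta_{t+1} \right) \odot \sigma'(z_t)$$

where $z_t = W_{xh} x_t + W_{hh} h_{t-1} + b_h$ (with $\delta_{T+1} = 0$), and $\sigma'$ is the activation derivative (e.g., for tanh, $\sigma'(z_t) = 1 - h_t^2$ element-wise). The weight gradient is:

$$\frac{\partial J}{\partial W_{hh}} = \sum_{t=1}^{T} \delta_t h_{t-1}^\top$$

Gradients are summed over time steps and averaged over the batch.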
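The backward recursion can be checked numerically on a toy model. The sketch below assumes a scalar RNN $h_t = \tanh(w_x x_t + w_h h_{t-1})$ with a squared-error loss on the final state only (a simplification of the cross-entropy setup above). The function names and values are hypothetical, and the result is compared against central finite differences.

```rust
/// Forward pass of a scalar RNN. hs[0] is h_0 = 0; hs[t] is the state after input t.
fn forward(w_x: f64, w_h: f64, xs: &[f64]) -> Vec<f64> {
    let mut hs = vec![0.0];
    for &x in xs {
        let h_prev = *hs.last().unwrap();
        hs.push((w_x * x + w_h * h_prev).tanh());
    }
    hs
}

/// Squared-error loss on the final hidden state (an assumption made for brevity).
fn loss(w_x: f64, w_h: f64, xs: &[f64], target: f64) -> f64 {
    let hs = forward(w_x, w_h, xs);
    0.5 * (hs[xs.len()] - target).powi(2)
}

/// BPTT: returns (dL/dw_x, dL/dw_h) by walking the unrolled graph backwards.
fn bptt(w_x: f64, w_h: f64, xs: &[f64], target: f64) -> (f64, f64) {
    let hs = forward(w_x, w_h, xs);
    let t_max = xs.len();
    let (mut grad_wx, mut grad_wh) = (0.0, 0.0);
    let mut dh = hs[t_max] - target; // dL/dh_T
    for t in (1..=t_max).rev() {
        let delta = dh * (1.0 - hs[t] * hs[t]); // error term through tanh
        grad_wx += delta * xs[t - 1];
        grad_wh += delta * hs[t - 1];
        dh = w_h * delta; // propagate to h_{t-1}
    }
    (grad_wx, grad_wh)
}

fn main() {
    let xs = [0.5, -1.0, 0.25, 0.8];
    let (w_x, w_h, target) = (0.7, -0.4, 0.3);
    let (gx, gh) = bptt(w_x, w_h, &xs, target);

    // Central finite differences as an independent check.
    let eps = 1e-5;
    let nx = (loss(w_x + eps, w_h, &xs, target) - loss(w_x - eps, w_h, &xs, target)) / (2.0 * eps);
    let nh = (loss(w_x, w_h + eps, &xs, target) - loss(w_x, w_h - eps, &xs, target)) / (2.0 * eps);
    println!("analytic: ({gx:.8}, {gh:.8})");
    println!("numeric:  ({nx:.8}, {nh:.8})");
}
```

Because the weights are shared, each time step adds its own contribution to the same two gradient accumulators.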
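Under the Hood: BPTT unrolls the RNN, creating a deep computational graph, costing $O(T n^2)$ per sample for $T$ time steps and $n$ hidden units. Long sequences cause vanishing gradients (gradients shrink exponentially with the number of steps) or exploding gradients (they grow uncontrollably). tch-rs supports gradient clipping, which bounds exploding gradients; vanishing gradients call for architectural changes such as the LSTM described later in this section. Rust’s ownership model guards tensor handles against the use-after-free and dangling-pointer errors that can crash hand-written C++ training loops. Because tch-rs calls the same libtorch kernels as Python’s PyTorch, any speedup for CPU-bound BPTT comes mainly from removing interpreter overhead in the per-step loop.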
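Gradient clipping itself is a small computation. The sketch below shows clipping by global norm in plain Rust, independent of tch-rs; the function name `clip_by_global_norm` and the example values are assumptions for illustration, not the library’s API.

```rust
/// Rescales all gradients in place so that their combined L2 norm is at most `max_norm`.
/// Returns the norm measured before clipping.
fn clip_by_global_norm(grads: &mut [Vec<f64>], max_norm: f64) -> f64 {
    let total_norm = grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|v| v * v)
        .sum::<f64>()
        .sqrt();
    if total_norm > max_norm {
        let scale = max_norm / total_norm;
        for g in grads.iter_mut() {
            for v in g.iter_mut() {
                *v *= scale;
            }
        }
    }
    total_norm
}

fn main() {
    // Two hypothetical parameter groups with a large combined gradient.
    let mut grads = vec![vec![3.0, 4.0], vec![12.0]];
    let before = clip_by_global_norm(&mut grads, 1.0);
    let after = clip_by_global_norm(&mut grads, 1.0);
    println!("norm before clipping: {before}");
    println!("norm after clipping:  {after}");
    println!("clipped gradients:    {:?}", grads);
}
```

Clipping preserves the direction of the update and only limits its length, which is why it helps with exploding gradients but does nothing for vanishing ones.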
LSTM: Addressing Gradient Issues
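Long Short-Term Memory (LSTM) units address vanishing gradients by introducing gates to control information flow. Here $\sigma$ is the logistic sigmoid, $[h_{t-1}, x_t]$ is the concatenation of the previous hidden state and the current input, $\odot$ is element-wise multiplication, and $o_t$ denotes the output gate:
- Forget Gate: $f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$
- Input Gate: $i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)$
- Cell Update: $\tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c)$
- Cell State: $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$
- Output Gate: $o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)$
- Hidden State: $h_t = o_t \odot \tanh(c_t)$

The cell state $c_t$ retains long-term dependencies, with gates modulating updates. Gradients flow through the additive cell-state update, which greatly reduces (though does not entirely eliminate) vanishing-gradient problems.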
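The gate equations translate almost line for line into code. Below is a minimal sketch of a single LSTM step with scalar input and scalar hidden state; the `Gate` and `LstmCell` types and the weight values are hypothetical and chosen only for readability.

```rust
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Weights of one gate acting on the pair [h_{t-1}, x_t] (both scalars here).
#[derive(Clone, Copy)]
struct Gate {
    w_h: f64,
    w_x: f64,
    b: f64,
}

impl Gate {
    /// Pre-activation w_h * h_{t-1} + w_x * x_t + b.
    fn pre(&self, h_prev: f64, x: f64) -> f64 {
        self.w_h * h_prev + self.w_x * x + self.b
    }
}

struct LstmCell {
    forget: Gate,
    input: Gate,
    cell: Gate,
    output: Gate,
}

impl LstmCell {
    /// One time step. Returns the new (hidden state, cell state).
    fn step(&self, x: f64, h_prev: f64, c_prev: f64) -> (f64, f64) {
        let f = sigmoid(self.forget.pre(h_prev, x)); // forget gate
        let i = sigmoid(self.input.pre(h_prev, x)); // input gate
        let c_tilde = self.cell.pre(h_prev, x).tanh(); // candidate cell update
        let c = f * c_prev + i * c_tilde; // additive cell-state update
        let o = sigmoid(self.output.pre(h_prev, x)); // output gate
        (o * c.tanh(), c)
    }
}

fn main() {
    // Illustrative weights; the large forget bias keeps the cell state almost unchanged.
    let cell = LstmCell {
        forget: Gate { w_h: 0.0, w_x: 0.0, b: 5.0 },
        input: Gate { w_h: 0.1, w_x: 0.5, b: 0.0 },
        cell: Gate { w_h: 0.2, w_x: 1.0, b: 0.0 },
        output: Gate { w_h: 0.0, w_x: 0.3, b: 0.0 },
    };
    let (mut h, mut c) = (0.0, 0.0);
    for (t, x) in [1.0, 0.0, 0.0, 0.0].iter().enumerate() {
        (h, c) = cell.step(*x, h, c);
        println!("t={t} h={h:.4} c={c:.4}");
    }
}
```

With a forget gate close to 1, the cell state written at the first step decays only slowly over the later zero inputs, which is the long-term memory behavior the gates are designed to provide.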
Under the Hood: LSTMs increase computational cost (roughly $O(n^2)$ per gate per time step for $n$ hidden units, with four gate-like transforms instead of one) but improve training stability. tch-rs delegates the gate computations to libtorch’s vectorized LSTM kernels, and Rust’s ownership model frees intermediate tensors deterministically when they go out of scope, which helps keep memory use predictable for long sequences.
Optimization
RNNs are trained with BPTT and optimizers like Adam, minimizing the loss. Regularization (e.g., dropout, $L_2$ penalty) prevents overfitting. Truncated BPTT limits unrolling to $k$ steps, balancing accuracy and computation.
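Under the Hood: Truncated BPTT reduces memory usage but may miss dependencies longer than the truncation window. With tch-rs, truncation is typically implemented by processing the sequence in chunks and detaching the hidden state between them, so the computational graph never grows beyond $k$ steps. Rust’s zero-cost abstractions keep the surrounding loop overhead low, while the tensor math itself runs in libtorch. Note that tch-rs tensors carry their shapes at run time, so shape mismatches surface as run-time errors rather than compile-time errors; Rust’s type system protects ownership and lifetimes of tensors, not their dimensions.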
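The effect of truncation can be seen on a toy model. The sketch below assumes a scalar RNN $h_t = \tanh(w_x x_t + w_h h_{t-1})$ with a squared-error loss at every step; the hidden state is carried across windows as a constant, which plays the role of "detaching" it. All names and numbers are hypothetical.

```rust
/// Total loss: sum over t of 0.5 * (h_t - y_t)^2 for a scalar tanh RNN with h_0 = 0.
fn total_loss(w_x: f64, w_h: f64, xs: &[f64], ys: &[f64]) -> f64 {
    let mut h = 0.0_f64;
    let mut loss = 0.0;
    for (&x, &y) in xs.iter().zip(ys) {
        h = (w_x * x + w_h * h).tanh();
        loss += 0.5 * (h - y).powi(2);
    }
    loss
}

/// dL/dw_h with backpropagation truncated to windows of `k` steps.
fn truncated_bptt_grad_wh(w_x: f64, w_h: f64, xs: &[f64], ys: &[f64], k: usize) -> f64 {
    assert!(k > 0, "window length must be positive");
    let mut grad = 0.0;
    let mut h_carry = 0.0; // state handed to the next window, treated as a constant
    for (x_win, y_win) in xs.chunks(k).zip(ys.chunks(k)) {
        // Forward through the window, remembering states for the backward pass.
        let mut hs = vec![h_carry];
        for &x in x_win {
            let h_prev = *hs.last().unwrap();
            hs.push((w_x * x + w_h * h_prev).tanh());
        }
        // Backward through this window only; nothing flows into h_carry.
        let mut dh_next = 0.0;
        for t in (1..=x_win.len()).rev() {
            let dh = (hs[t] - y_win[t - 1]) + dh_next;
            let delta = dh * (1.0 - hs[t] * hs[t]);
            grad += delta * hs[t - 1];
            dh_next = w_h * delta;
        }
        h_carry = *hs.last().unwrap();
    }
    grad
}

fn main() {
    let xs = [0.5, -1.0, 0.25, 0.8, -0.3, 0.6];
    let ys = [0.1, -0.2, 0.0, 0.4, 0.1, 0.3];
    let (w_x, w_h) = (0.7, 0.9);
    for k in [1, 2, 3, xs.len()] {
        let g = truncated_bptt_grad_wh(w_x, w_h, &xs, &ys, k);
        println!("k = {k}: dL/dw_h = {g:.6}");
    }
}
```

When `k` equals the sequence length, the result is the full BPTT gradient; smaller windows give an approximation that ignores dependencies crossing a window boundary.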
Evaluation
Performance is evaluated with:
- Classification: Accuracy, Precision, Recall, F1-Score, ROC-AUC.
- Regression: MSE, RMSE, MAE.
- Perplexity (for language models): $\exp\left(-\frac{1}{N} \sum_{t=1}^{N} \log p(x_t \mid x_{<t})\right)$.
Under the Hood: Perplexity measures sequence prediction quality, with lower values indicating better models. tch-rs computes metrics efficiently, using GPU acceleration when available, with Rust’s compiled performance reducing overhead compared to Python’s interpreter.
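As a small illustration of two of these metrics, the sketch below computes perplexity from the probabilities a model assigned to the correct tokens, and the F1-score from binary predictions. The function names and input values are assumptions made for the example.

```rust
/// Perplexity: exp of the average negative log-probability assigned to the true tokens.
fn perplexity(true_token_probs: &[f64]) -> f64 {
    let n = true_token_probs.len() as f64;
    let avg_nll = -true_token_probs.iter().map(|p| p.ln()).sum::<f64>() / n;
    avg_nll.exp()
}

/// F1-score for binary labels; returns 0.0 when precision or recall is undefined.
fn f1_score(preds: &[bool], labels: &[bool]) -> f64 {
    let (mut tp, mut fp, mut fn_) = (0.0, 0.0, 0.0);
    for (&p, &l) in preds.iter().zip(labels) {
        match (p, l) {
            (true, true) => tp += 1.0,
            (true, false) => fp += 1.0,
            (false, true) => fn_ += 1.0,
            (false, false) => {}
        }
    }
    if tp == 0.0 {
        return 0.0;
    }
    let precision = tp / (tp + fp);
    let recall = tp / (tp + fn_);
    2.0 * precision * recall / (precision + recall)
}

fn main() {
    // A model that spreads probability uniformly over 4 tokens has perplexity close to 4.
    println!("perplexity: {:.3}", perplexity(&[0.25, 0.25, 0.25, 0.25]));
    let preds = [true, true, false, false];
    let labels = [true, false, true, false];
    println!("F1: {:.3}", f1_score(&preds, &labels));
}
```

Perplexity can be read as the effective number of equally likely choices the model is hesitating between at each step.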
Lab: LSTM with tch-rs
You’ll train an LSTM on a synthetic sequence dataset for binary classification, evaluating accuracy and loss.
- Edit `src/main.rs` in your `rust_ml_tutorial` project:

  ```rust
  use ndarray::{array, Array1, Array3};
  use tch::nn::{self, Module, OptimizerConfig, RNN};
  use tch::{Device, Kind, Reduction, Tensor};

  fn main() -> Result<(), tch::TchError> {
      // Synthetic dataset: 10 sequences, 5 time steps, 2 features (f32 to match the model weights)
      let x: Array3<f32> = array![
          // Class 0: low values
          [[0.1, 0.2], [0.2, 0.3], [0.1, 0.2], [0.3, 0.4], [0.2, 0.3]],
          [[0.2, 0.1], [0.3, 0.2], [0.2, 0.1], [0.4, 0.3], [0.3, 0.2]],
          [[0.1, 0.3], [0.2, 0.4], [0.1, 0.3], [0.3, 0.2], [0.2, 0.1]],
          [[0.3, 0.2], [0.4, 0.3], [0.3, 0.2], [0.2, 0.1], [0.1, 0.2]],
          [[0.2, 0.3], [0.1, 0.2], [0.2, 0.3], [0.1, 0.4], [0.3, 0.2]],
          // Class 1: high values
          [[0.9, 0.8], [0.8, 0.9], [0.9, 0.8], [0.7, 0.9], [0.8, 0.7]],
          [[0.8, 0.9], [0.9, 0.8], [0.8, 0.7], [0.9, 0.8], [0.7, 0.9]],
          [[0.7, 0.8], [0.8, 0.9], [0.9, 0.7], [0.8, 0.9], [0.9, 0.8]],
          [[0.9, 0.7], [0.8, 0.8], [0.7, 0.9], [0.9, 0.8], [0.8, 0.7]],
          [[0.8, 0.9], [0.9, 0.8], [0.8, 0.9], [0.7, 0.8], [0.9, 0.7]],
      ];
      let y: Array1<f32> = array![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0];

      // Convert to tensors: xs is [batch, time, features], ys is [batch, 1]
      let device = Device::Cpu;
      let xs = Tensor::from_slice(x.as_slice().unwrap()).reshape([10, 5, 2]).to_device(device);
      let ys = Tensor::from_slice(y.as_slice().unwrap()).reshape([10, 1]).to_device(device);

      // Define the model: LSTM (2 features -> 10 hidden units) followed by a linear layer
      let vs = nn::VarStore::new(device);
      let root = vs.root();
      let lstm = nn::lstm(&(&root / "lstm"), 2, 10, Default::default());
      let fc = nn::linear(&root / "fc", 10, 1, Default::default());

      // nn::lstm implements the RNN trait (not Module), so it is run with `seq`.
      // The model returns raw logits; the sigmoid is applied inside the loss.
      let forward = |input: &Tensor| -> Tensor {
          let (out, _state) = lstm.seq(input); // out: [batch, time, hidden]
          fc.forward(&out.select(1, -1)) // last time step -> [batch, 1]
      };

      // Optimizer (Adam)
      let mut opt = nn::Adam::default().build(&vs, 0.01)?;

      // Training loop
      for epoch in 1..=100 {
          let logits = forward(&xs);
          let loss = logits.binary_cross_entropy_with_logits::<Tensor>(&ys, None, None, Reduction::Mean);
          opt.backward_step(&loss); // zero_grad + backward + step
          if epoch % 20 == 0 {
              println!("Epoch: {}, Loss: {:.2}", epoch, loss.double_value(&[]));
          }
      }

      // Evaluate accuracy on the training data
      let preds = forward(&xs).sigmoid().ge(0.5).to_kind(Kind::Float);
      let correct = preds.eq_tensor(&ys).sum(Kind::Int64).int64_value(&[]);
      let accuracy = correct as f64 / y.len() as f64;
      println!("Accuracy: {:.2}", accuracy);
      Ok(())
  }
  ```

- Ensure Dependencies:
  - Verify `Cargo.toml` includes:

    ```toml
    [dependencies]
    tch = "0.17.0"
    ndarray = "0.15.0"
    ```

  - Run `cargo build`.

- Run the Program:

  ```bash
  cargo run
  ```

  Expected Output (approximate; exact values vary with the random initialization):

  ```
  Epoch: 20, Loss: 0.50
  Epoch: 40, Loss: 0.35
  Epoch: 60, Loss: 0.25
  Epoch: 80, Loss: 0.18
  Epoch: 100, Loss: 0.12
  Accuracy: 0.90
  ```
Understanding the Results
- Dataset: Synthetic sequences (10 samples, 5 time steps, 2 features) represent two classes (low vs. high values), mimicking time-series data.
- Model: An LSTM with 10 hidden units processes sequences, outputting a class prediction at the last time step, achieving ~90% accuracy.
- Loss: The binary cross-entropy loss decreases steadily (to ~0.12 after 100 epochs), indicating convergence.
- Under the Hood: tch-rs leverages PyTorch’s optimized LSTM routines, with Rust’s memory safety preventing tensor mismanagement during sequence unrolling, a risk in hand-written C++ BPTT. The LSTM’s gates mitigate vanishing gradients, enabling longer sequence modeling than vanilla RNNs. Rust’s compiled code can reduce per-step overhead compared to Python’s PyTorch, especially for CPU-bound tasks with many time steps, although the tensor kernels themselves are the same.
- Evaluation: High accuracy confirms effective sequence learning on the training set, though held-out validation data would be needed to detect overfitting in practice.
This lab introduces sequence modeling, preparing for optimization techniques.
Next Steps
Further Reading
- Deep Learning by Goodfellow et al. (Chapter 10)
- Hands-On Machine Learning by Géron (Chapter 16)
- tch-rs Documentation: github.com/LaurentMazare/tch-rs