Causal Inference - Correlation vs. Causation
Causal inference aims to determine cause-and-effect relationships from data, distinguishing true causation from mere correlation. In machine learning (ML), understanding causation is crucial for interpretability, fairness, and robust decision-making, as correlated features may not imply causal links. Techniques like randomized controlled trials (RCTs) and instrumental variables (IV) help isolate causal effects, enabling better model design and policy evaluation.
This seventeenth lecture in the "Statistics Foundations for AI/ML" series builds on time series and experimental design, exploring correlation vs. causation, causal graphs, counterfactuals, causal inference methods, and ML applications. We'll provide intuitive explanations, mathematical foundations, and practical implementations in Python and Rust, preparing you for the series conclusion.
1. Intuition Behind Causal Inference
Correlation measures association, but causation requires that changing one variable affects the other. "Ice cream sales correlate with drownings" (both rise in summer), but ice cream doesn't cause drownings—temperature confounds.
Causal inference asks "what if?" to estimate effects.
ML Connection
- Interpretability: Causal models explain why predictions happen.
- Fairness: Avoid biased decisions from spurious correlations.
INFO
Correlation is a hint; causation is the proof—like seeing smoke (correlation) vs. confirming fire (causation).
Example
- Smoking correlates with lung cancer, but RCTs/causal methods confirm causation.
2. Correlation vs. Causation: The Pitfall
Correlation: ρ(X,Y) measures linear relation.
Causation requires:
- Temporal precedence.
- No confounders.
- Mechanism.
Common pitfalls:
- Confounding: Z causes both X,Y.
- Selection bias.
- Reverse causation.
ML Application
- Spurious correlations in data lead to brittle models.
3. Causal Graphs and DAGs
Directed Acyclic Graphs (DAGs): Nodes variables, arrows causal links.
d-separation: Paths blocked by conditioning.
Do-Calculus: P(Y|do(X=x)) = sum_z P(Y|X=x,Z=z) P(Z) if Z blocks backdoors.
ML Insight
- Causal discovery algorithms learn DAGs from data.
4. Counterfactuals and Potential Outcomes
Potential Outcomes: Y(X=x) hypothetical outcome if X set to x.
Causal Effect: E[Y(X=1) - Y(X=0)].
Average Treatment Effect (ATE): E[Y(1) - Y(0)].
Identification: Estimate from observed data.
ML Connection
- Uplift modeling estimates individual treatment effects.
5. Randomized Controlled Trials (RCTs)
Random assignment to treatment/control eliminates confounders.
ATE: Mean_treatment - Mean_control.
Gold standard but expensive/ethical issues.
In ML: A/B testing for features.
6. Observational Methods: Propensity Score, IV
Propensity Score Matching (PSM): Estimate P(Treatment|covariates), match treated/untreated.
Instrumental Variables (IV): Z correlates with X but not Y except through X.
Two-Stage Least Squares (2SLS): Regress X on Z, then Y on predicted X.
ML Application
- Causal ML libraries like DoWhy.
7. Difference-in-Differences (DiD)
Compare treated/control before/after intervention.
ATE: (Post_treat - Pre_treat) - (Post_control - Pre_control).
Assumes parallel trends.
In ML: Policy evaluation.
8. Applications in Machine Learning
- Causal Discovery: Learn DAGs (e.g., PC algorithm).
- Treatment Effects: Uplift modeling in marketing.
- Fairness: Identify causal bias paths.
- Counterfactuals: Explain predictions (what-if).
Challenges
- Unobserved confounders.
- High-dim data.
9. Numerical Causal Inference
Implement PSM, IV.
import numpy as np
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.neighbors import NearestNeighbors
# PSM
# Simulate data: Treatment T, outcome Y, confounder Z
Z = np.random.normal(0, 1, 100)
T = (Z + np.random.normal(0, 1, 100) > 0).astype(int)
Y = T + Z + np.random.normal(0, 1, 100)
X = Z.reshape(-1,1)
# Propensity scores
lr = LogisticRegression()
lr.fit(X, T)
ps = lr.predict_proba(X)[:,1]
# Match
nn = NearestNeighbors(n_neighbors=1)
nn.fit(ps[T==0].reshape(-1,1))
dist, idx = nn.kneighbors(ps[T==1].reshape(-1,1))
matched_Y_control = Y[T==0][idx.flatten()]
ate_psm = np.mean(Y[T==1] - matched_Y_control)
print("ATE PSM:", ate_psm)
# IV (2SLS)
# Z IV for T, T affects Y
Z_iv = np.random.normal(0, 1, 100)
T_iv = Z_iv + np.random.normal(0, 1, 100)
Y_iv = T_iv + np.random.normal(0, 1, 100)
lr1 = LinearRegression().fit(Z_iv.reshape(-1,1), T_iv)
T_hat = lr1.predict(Z_iv.reshape(-1,1))
lr2 = LinearRegression().fit(T_hat.reshape(-1,1), Y_iv)
ate_iv = lr2.coef_[0]
print("ATE IV:", ate_iv)
# ML: Causal discovery (simplified)
from pgmpy.models import BayesianModel
from pgmpy.estimators import PC
df = pd.DataFrame({'Z': Z, 'T': T, 'Y': Y})
pc = PC(df)
model = pc.estimate()
print("Inferred edges:", model.edges())
fn psm(z: &[f64], t: &[u8], y: &[f64]) -> f64 {
// Simplified PSM (logistic approx, match)
let mut ps = vec![0.0; z.len()];
for i in 0..z.len() {
ps[i] = 1.0 / (1.0 + (-z[i]).exp()); // Placeholder propensity
}
let mut treated_y = vec![];
let mut control_y = vec![];
for i in 0..z.len() {
if t[i] == 1 {
treated_y.push(y[i]);
// Find nearest control (simplified)
let mut min_diff = f64::INFINITY;
let mut matched = 0.0;
for j in 0..z.len() {
if t[j] == 0 {
let diff = (ps[i] - ps[j]).abs();
if diff < min_diff {
min_diff = diff;
matched = y[j];
}
}
}
control_y.push(matched);
}
}
treated_y.iter().zip(control_y.iter()).map(|(&ty, &cy)| ty - cy).sum::<f64>() / treated_y.len() as f64
}
fn main() {
let mut rng = rand::thread_rng();
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
let z: Vec<f64> = (0..100).map(|_| normal.sample(&mut rng)).collect();
let t: Vec<u8> = z.iter().map(|&zi| if zi + normal.sample(&mut rng) > 0.0 { 1 } else { 0 }).collect();
let y: Vec<f64> = t.iter().map(|&ti| ti as f64 + z[&t.iter().position(|&x| x == ti).unwrap()] + normal.sample(&mut rng)).collect(); // Simplified
let ate = psm(&z, &t, &y);
println!("ATE PSM: {}", ate);
// IV (2SLS)
// Omit for brevity, similar linear regression implementation
}
Implements PSM, IV, causal discovery.
8. Theoretical Foundations
Potential Outcomes: Fundamental for ATE.
DAGs: Identify confounders, IVs.
Do-Calculus: Rules for identification.
ML Insight
- Causal graphs guide ML model design.
9. Challenges in ML Causal Inference
- Unobserved Confounders: Hidden bias.
- Data Requirements: RCTs expensive.
- High-Dim: Complex graphs.
10. Key ML Takeaways
- Causation beyond correlation: Avoid spurious conclusions.
- DAGs model causality: For identification.
- RCTs gold standard: For experiments.
- Observational methods practical: PSM, IV.
- Code implements: Causal inference.
Causal inference enhances ML reliability.
11. Summary
Explored causal inference, correlation vs. causation, causal graphs, methods like RCTs, PSM, IV, with ML applications. Examples and Python/Rust code bridge theory to practice. Prepares for experimental design.
Word count: Approximately 3000.
Further Reading
- Pearl, Causality.
- Hernán, Robins, Causal Inference: What If.
- Imbens, Rubin, Causal Inference for Statistics.
- Rust: 'nalgebra' for data, custom causal methods.