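Decision Trees
The decision tree is one of the most intuitive machine learning algorithms. It performs classification or regression through a sequence of if-else rules, so its results are fully interpretable.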
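Core Concepts
Three Components
- Root node: the starting node, which contains all samples
- Internal node: a test on one feature (e.g., "age < 30?")
- Leaf node: the final prediction (a class label or a numeric value)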
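To make the three components concrete, here is a minimal sketch of a tree as a recursive data structure. The `Node` class, the `predict` helper, and the toy features `age` and `income` (with their thresholds and yes/no labels) are hypothetical and chosen only for illustration:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Internal node: tests x[feature] < threshold. Leaf: only `value` is set.
    feature: Optional[str] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None   # branch taken when the test is true
    right: Optional["Node"] = None  # branch taken when the test is false
    value: Optional[str] = None     # prediction stored at a leaf

    def is_leaf(self) -> bool:
        return self.value is not None


def predict(node: Node, x: dict) -> str:
    # Walk from the root to a leaf, applying one if-else test per internal node
    while not node.is_leaf():
        node = node.left if x[node.feature] < node.threshold else node.right
    return node.value


# Hypothetical toy tree: the root tests "age < 30?", an internal node tests "income < 5000?"
root = Node(
    feature="age", threshold=30,
    left=Node(value="no"),
    right=Node(feature="income", threshold=5000,
               left=Node(value="no"), right=Node(value="yes")),
)

print(predict(root, {"age": 25, "income": 8000}))  # no
print(predict(root, {"age": 40, "income": 8000}))  # yes
```

A prediction is just the chain of tests from the root down to one leaf, which is why every prediction can be explained by the path it took.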
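Splitting Objective
Each split should make the child nodes as "pure" as possible, meaning that samples of the same class end up together.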
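Splitting Criteria
Information Gain (ID3)
Entropy measures impurity; the more a split reduces entropy, the better. Let $p_k$ be the proportion of class $k$ in the sample set $D$, and let feature $a$ partition $D$ into subsets $D^v$, one per value $v$:
$$H(D) = -\sum_{k} p_k \log_2 p_k$$
$$\mathrm{Gain}(D, a) = H(D) - \sum_{v} \frac{|D^v|}{|D|} H(D^v)$$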
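A small NumPy sketch of entropy and information gain as used by ID3. The helper names `entropy` and `information_gain` and the 8-sample label array are illustrative assumptions. A split that separates the two classes perfectly recovers the full 1 bit of parent entropy, while a split that leaves both children 50/50 gains nothing:

```python
import numpy as np


def entropy(y):
    # H(D) = -sum_k p_k * log2(p_k) over the class proportions in y
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))


def information_gain(y, groups):
    # Gain = H(parent) - sum_v |D_v|/|D| * H(D_v); `groups` holds the child label arrays
    n = len(y)
    children = sum(len(g) / n * entropy(g) for g in groups)
    return entropy(y) - children


y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
print(entropy(y))  # 1.0

perfect = information_gain(y, [y[:4], y[4:]])                      # pure children
useless = information_gain(y, [y[[0, 1, 4, 5]], y[[2, 3, 6, 7]]])  # 50/50 children
print(perfect, useless)  # 1.0 0.0
```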
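Information Gain Ratio (C4.5)
Divides the information gain $\mathrm{Gain}(D, a)$ by the split information $\mathrm{IV}(a)$, which is large when feature $a$ breaks $D$ into many small subsets $D^v$; this penalizes splits that produce many branches:
$$\mathrm{IV}(a) = -\sum_{v} \frac{|D^v|}{|D|} \log_2 \frac{|D^v|}{|D|}$$
$$\mathrm{GainRatio}(D, a) = \frac{\mathrm{Gain}(D, a)}{\mathrm{IV}(a)}$$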
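Gain ratio exists because plain information gain favors splits with many branches. In the sketch below (hypothetical helpers, made-up labels), a balanced binary split and an ID-like split that gives every sample its own branch both reach the maximal information gain of 1 bit, but the second split's split information is $\log_2 8 = 3$, so its gain ratio drops to 1/3:

```python
import numpy as np


def entropy(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))


def gain_ratio(y, groups):
    # Assumes at least two non-empty groups, otherwise the split information is 0
    n = len(y)
    weights = np.array([len(g) / n for g in groups])
    gain = entropy(y) - sum(w * entropy(g) for w, g in zip(weights, groups))
    split_info = -np.sum(weights * np.log2(weights))  # IV(a)
    return float(gain / split_info)


y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
binary = gain_ratio(y, [y[:4], y[4:]])                     # 2 branches
per_sample = gain_ratio(y, [y[i:i + 1] for i in range(8)])  # 8 singleton branches
print(binary, per_sample)  # 1.0 0.3333333333333333
```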
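Gini Impurity (CART)
The criterion used by CART trees. It needs no logarithms, so it is cheaper to compute. With $p_k$ the proportion of class $k$ in $D$:
$$\mathrm{Gini}(D) = 1 - \sum_{k} p_k^2$$
A candidate split is scored by the size-weighted Gini impurity of its children, $\sum_{v} \frac{|D^v|}{|D|} \mathrm{Gini}(D^v)$; lower is better.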
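The same kind of sketch for Gini impurity; it mirrors the `_gini` method and the weighted child Gini in the from-scratch tree of the next section. The labels are made up. Splitting `[0, 0, 0, 1, 1, 1]` into `[0, 0, 0, 1]` and `[1, 1]` lowers the impurity from 0.5 to 4/6 · 0.375 = 0.25:

```python
import numpy as np


def gini(y):
    # Gini(D) = 1 - sum_k p_k^2
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))


def weighted_gini(groups):
    # Size-weighted Gini impurity of the children of a split
    n = sum(len(g) for g in groups)
    return sum(len(g) / n * gini(g) for g in groups)


y = np.array([0, 0, 0, 1, 1, 1])
print(gini(y))  # 0.5

left, right = y[:4], y[4:]  # [0, 0, 0, 1] and [1, 1]
print(weighted_gini([left, right]))  # ≈ 0.25 (up to floating-point rounding)
```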
Decision Tree from Scratch
```python
import numpy as np
from collections import Counter


class SimpleDecisionTree:
    """A minimal CART-style classification tree that splits on Gini impurity."""

    def __init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split  # min samples a node needs before it may split
        self.min_samples_leaf = min_samples_leaf    # min samples each child must keep after a split

    def _gini(self, y):
        """Gini impurity: 1 - sum_k p_k^2."""
        _, counts = np.unique(y, return_counts=True)
        probs = counts / len(y)
        return 1 - np.sum(probs ** 2)

    def _best_split(self, X, y):
        """Try every feature and every candidate threshold; return the best split."""
        best_gain = -1
        best_feature, best_threshold = None, None
        n = len(y)
        parent_gini = self._gini(y)
        for feature in range(X.shape[1]):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                left = y[X[:, feature] <= threshold]
                right = y[X[:, feature] > threshold]
                if len(left) < self.min_samples_leaf or len(right) < self.min_samples_leaf:
                    continue
                # Size-weighted Gini impurity of the two children
                child_gini = ((len(left) / n) * self._gini(left)
                              + (len(right) / n) * self._gini(right))
                gain = parent_gini - child_gini
                if gain > best_gain:
                    best_gain, best_feature, best_threshold = gain, feature, threshold
        return best_feature, best_threshold

    def _build_tree(self, X, y, depth):
        # Stop when max depth is reached, the node is pure, or it has too few samples to split
        if depth >= self.max_depth or len(np.unique(y)) == 1 or len(y) < self.min_samples_split:
            return Counter(y).most_common(1)[0][0]  # leaf: majority class
        feature, threshold = self._best_split(X, y)
        if feature is None:  # no valid split exists
            return Counter(y).most_common(1)[0][0]
        left_idx = X[:, feature] <= threshold
        right_idx = X[:, feature] > threshold
        return {
            'feature': feature,
            'threshold': threshold,
            'left': self._build_tree(X[left_idx], y[left_idx], depth + 1),
            'right': self._build_tree(X[right_idx], y[right_idx], depth + 1),
        }

    def fit(self, X, y):
        self.tree = self._build_tree(np.asarray(X), np.asarray(y), depth=0)
        return self

    def _predict_one(self, x, node):
        # Internal nodes are dicts; anything else is a leaf label
        if not isinstance(node, dict):
            return node
        if x[node['feature']] <= node['threshold']:
            return self._predict_one(x, node['left'])
        return self._predict_one(x, node['right'])

    def predict(self, X):
        return np.array([self._predict_one(x, self.tree) for x in np.asarray(X)])
```
Scikit-learn Implementation
```python
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

iris = load_iris()
X, y = iris.data, iris.target
model = DecisionTreeClassifier(max_depth=3, min_samples_split=5)
model.fit(X, y)

# Visualize the decision tree
plt.figure(figsize=(12, 8))
plot_tree(model, filled=True, feature_names=iris.feature_names,
          class_names=list(iris.target_names))
plt.show()
print(f"Feature importances: {model.feature_importances_}")
```
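Because the fitted model is literally a set of if-else rules, scikit-learn can also print it as indented text with `sklearn.tree.export_text`. A small sketch that refits the same kind of tree on iris; `random_state=0` is an assumption added for reproducibility:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

# One line per node; deeper indentation means a deeper level in the tree
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```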
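Pruning
Decision trees overfit easily: a tree grown without limits keeps splitting until its leaves are pure, in the extreme giving every training sample its own leaf. Pruning keeps this in check: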
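| Method | Description |
|---|---|
| Pre-pruning | Stop growth early by limiting max_depth, min_samples_split, min_samples_leaf |
| Post-pruning | Grow the full tree first, then cut redundant branches bottom-up, using a validation set to decide how far to prune (e.g., cost-complexity pruning, CCP) |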
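A sketch of pre-pruning in scikit-learn, comparing an unconstrained tree with one whose growth is limited up front. The breast-cancer dataset, the 70/30 split, and the specific limits are illustrative assumptions, not tuned values; typically the unconstrained tree fits the training data almost perfectly while the pre-pruned one is far smaller:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_bc, y_bc = load_breast_cancer(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_bc, y_bc, test_size=0.3, random_state=0, stratify=y_bc)

# No limits: keeps splitting until leaves are pure (or cannot be split further)
full = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
# Pre-pruned: illustrative limits on depth and on node sizes
pre = DecisionTreeClassifier(max_depth=4, min_samples_split=10, min_samples_leaf=5,
                             random_state=0).fit(X_tr, y_tr)

for name, m in [("unconstrained", full), ("pre-pruned", pre)]:
    print(f"{name:>13}: depth={m.get_depth()}, leaves={m.get_n_leaves()}, "
          f"train acc={m.score(X_tr, y_tr):.3f}, test acc={m.score(X_te, y_te):.3f}")
```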
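```python
# CCP post-pruning in scikit-learn: candidate alphas for `model` on (X, y)
path = model.cost_complexity_pruning_path(X, y)
ccp_alphas = path.ccp_alphas  # ascending; a larger alpha prunes more aggressively
# Pick the best ccp_alpha on a validation set (or by cross-validation),
# then refit with DecisionTreeClassifier(ccp_alpha=best_alpha)
```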
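A fuller sketch of choosing `ccp_alpha` with a hold-out validation split. The dataset, the 70/30 split, and picking the alpha by plain validation accuracy are assumptions for illustration; in practice cross-validation, plus a separate untouched test set, gives a more reliable choice:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_bc, y_bc = load_breast_cancer(return_X_y=True)
X_tr, X_val, y_tr, y_val = train_test_split(
    X_bc, y_bc, test_size=0.3, random_state=0, stratify=y_bc)

# Effective alphas of the pruning sequence, computed from the fully grown training tree
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_tr, y_tr)
# Drop the last alpha (it prunes the tree to a single node); keep alphas non-negative
ccp_alphas = np.maximum(path.ccp_alphas[:-1], 0.0)

# Refit one tree per candidate alpha and score it on the validation split
scores = [
    DecisionTreeClassifier(random_state=0, ccp_alpha=a).fit(X_tr, y_tr).score(X_val, y_val)
    for a in ccp_alphas
]
best_alpha = ccp_alphas[int(np.argmax(scores))]

full = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha).fit(X_tr, y_tr)
print(f"best ccp_alpha = {best_alpha:.5f}")
print(f"leaves: full={full.get_n_leaves()}, pruned={pruned.get_n_leaves()}")
print(f"val acc: full={full.score(X_val, y_val):.3f}, pruned={pruned.score(X_val, y_val):.3f}")
```

Since every candidate is a pruned subtree of the same fully grown tree, the selected model can only have the same number of leaves as the full tree or fewer.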
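Summary
| Property | Description |
|---|---|
| Strengths | Intuitive and interpretable, needs no feature scaling, captures nonlinear relationships |
| Weaknesses | Prone to overfitting; sensitive to rotations of the feature space (splits are axis-aligned); unstable (a small change in the data can produce a completely different tree) |
| Use cases | Settings that require interpretability; base learner inside Random Forest and Gradient Boosting |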