跳到主要内容

决策树

决策树是最直观的机器学习算法,通过一系列 if-else 规则进行分类或回归,结果完全可解释。

核心概念

三个要素

  • 根节点:包含全部样本的起始节点
  • 内部节点:对应一个特征测试(如「年龄 < 30?」)
  • 叶节点:最终的预测结果(类别或数值)

分裂目标

每次分裂要尽可能使子节点「纯净」——同一类样本集中在一起。

分裂标准

信息增益(ID3)

用熵衡量不纯度,分裂后熵减少越多越好:

Entropy(S)=cpclog2pc\text{Entropy}(S) = -\sum_{c} p_c \log_2 p_c InfoGain=Entropy(S)vSvSEntropy(Sv)\text{InfoGain} = \text{Entropy}(S) - \sum_{v} \frac{|S_v|}{|S|} \text{Entropy}(S_v)

信息增益率(C4.5)

对信息增益除以分裂信息,惩罚产生过多分支的分裂:

GainRatio=InfoGainSplitInfo\text{GainRatio} = \frac{\text{InfoGain}}{\text{SplitInfo}}

基尼系数(CART)

CART 树使用的标准,计算更快:

Gini(S)=1cpc2\text{Gini}(S) = 1 - \sum_{c} p_c^2

手写决策树

import numpy as np
from collections import Counter

class SimpleDecisionTree:
def __init__(self, max_depth=5, min_samples_split=2):
self.max_depth = max_depth
self.min_samples_split = min_samples_split

def _gini(self, y):
"""计算基尼系数"""
_, counts = np.unique(y, return_counts=True)
probs = counts / len(y)
return 1 - np.sum(probs ** 2)

def _best_split(self, X, y):
"""穷举所有可能的分裂点,找最优"""
best_gain = -1
best_feature, best_threshold = None, None
n = len(y)
parent_gini = self._gini(y)

for feature in range(X.shape[1]):
thresholds = np.unique(X[:, feature])
for threshold in thresholds:
left = y[X[:, feature] <= threshold]
right = y[X[:, feature] > threshold]
if len(left) < self.min_samples_split or len(right) < self.min_samples_split:
continue

# 加权子节点基尼
child_gini = (len(left) / n) * self._gini(left) + \
(len(right) / n) * self._gini(right)
gain = parent_gini - child_gini

if gain > best_gain:
best_gain, best_feature, best_threshold = gain, feature, threshold

return best_feature, best_threshold

def _build_tree(self, X, y, depth):
# 终止条件
if depth >= self.max_depth or len(np.unique(y)) == 1 or len(y) < self.min_samples_split:
return Counter(y).most_common(1)[0][0]

feature, threshold = self._best_split(X, y)
if feature is None: # 无法再分裂
return Counter(y).most_common(1)[0][0]

left_idx = X[:, feature] <= threshold
right_idx = X[:, feature] > threshold

return {
'feature': feature,
'threshold': threshold,
'left': self._build_tree(X[left_idx], y[left_idx], depth + 1),
'right': self._build_tree(X[right_idx], y[right_idx], depth + 1),
}

def fit(self, X, y):
self.tree = self._build_tree(X, y, depth=0)

def _predict_one(self, x, node):
if not isinstance(node, dict):
return node
if x[node['feature']] <= node['threshold']:
return self._predict_one(x, node['left'])
else:
return self._predict_one(x, node['right'])

def predict(self, X):
return np.array([self._predict_one(x, self.tree) for x in X])

Scikit-learn 实现

from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

X, y = load_iris(return_X_y=True)

model = DecisionTreeClassifier(max_depth=3, min_samples_split=5)
model.fit(X, y)

# 可视化决策树
plt.figure(figsize=(12, 8))
plot_tree(model, filled=True, feature_names=load_iris().feature_names,
class_names=load_iris().target_names)
plt.show()

print(f"特征重要性: {model.feature_importances_}")

剪枝

决策树容易过拟合(无限制生长会把每个样本当成一个叶节点),需要剪枝:

方法说明
预剪枝限制 max_depthmin_samples_splitmin_samples_leaf
后剪枝先建完整树,再用验证集自底向上剪掉冗余分支(CCP 剪枝)
# Scikit-learn 的 CCP 后剪枝
path = model.cost_complexity_pruning_path(X, y)
# 取最优 ccp_alpha

总结

特性说明
优点直观可解释、不需要特征缩放、能处理非线性关系
缺点容易过拟合、对数据旋转敏感、不稳定(小变化可能导致完全不同的树)
适用场景需要可解释性的场景、作为随机森林/Gradient Boosting 的子模块