决策树算法

机器学习专题 · 掌握决策树算法的原理与实现

专题:Python机器学习系统学习

关键词:Python, 机器学习, 决策树, ID3, C4.5, CART, 信息增益, Gini指数, 剪枝, 特征重要性

一、决策树概述

决策树(Decision Tree)是一种基于树形结构的监督学习算法,它通过对数据进行递归划分来建立决策规则,最终形成一棵从根节点到叶节点的决策路径树。决策树既可以处理分类任务,也可以处理回归任务,是机器学习中最直观、最易于解释的算法之一。

决策树的核心思想来源于人类决策过程——当我们做决策时,通常会根据一系列条件逐步判断,最终得出结论。例如,判断一个水果是否为苹果时,我们会先看颜色(是否为红色或绿色),再看形状(是否圆形),然后看大小等。决策树正是将这种人类思维方式模型化。

决策树由以下四个基本部分组成:

决策树的训练过程本质上是一个递归二分(或多元划分)的过程:从根节点出发,选择一个最优特征将数据集划分为若干子集,然后对每个子集递归进行同样的操作,直到满足停止条件(如节点中的样本全部属于同一类别、达到最大深度、或节点中样本数小于阈值等)。

决策树的一个巨大优势在于其可解释性(Interpretability)。一棵训练好的决策树可以直接转换为一组IF-THEN规则,每条从根节点到叶节点的路径对应一条规则。这使得决策树在医疗诊断、信用评估等需要解释模型决策逻辑的领域备受欢迎。

常见的决策树算法包括:ID3(Iterative Dichotomiser 3)、C4.5(ID3的改进版本)、CART(Classification and Regression Tree)。它们的核心区别在于特征选择时使用的度量标准不同。

二、特征选择度量

决策树构建的核心问题是:在每个节点上,应该选择哪个特征进行划分?不同的决策树算法采用不同的特征选择度量(Feature Selection Measure)来评估每个特征的划分质量。选择的特征应能使划分后的子节点数据尽可能"纯净"——即同一子节点中的样本尽可能属于同一类别。

2.1 信息熵(Entropy)

信息熵是信息论中用于衡量不确定性或纯度的概念,由克劳德·香农(Claude Shannon)于1948年提出。熵的值越大,表示系统的不确定性越高,数据越"混乱";熵的值越小,表示数据越"纯净"。对于一个包含K个类别的数据集S,其熵的计算公式如下:

熵的定义:Entropy(S) = -Σᵢ(pᵢ · log₂(pᵢ)),其中i从1到K,pᵢ是类别i在数据集S中的占比。

举例来说,一个二分类问题,如果数据集中正负样本各占一半(p=0.5),则熵为 -0.5·log₂(0.5) - 0.5·log₂(0.5) = 1.0,此时不确定性最大。如果数据集中所有样本都属于同一类别(p=1或p=0),则熵为0,表示完全纯净。当p偏向一侧时,熵介于0和1之间。

2.2 信息增益与ID3算法

信息增益(Information Gain)是ID3算法采用的特征选择度量。它衡量的是使用某个特征进行划分后,熵减少的程度,即不确定性的降低量。信息增益越大,说明该特征对分类的贡献越大。

信息增益的计算公式:

Gain(S, A) = Entropy(S) - Σᵥ(|Sᵥ| / |S|) · Entropy(Sᵥ)

其中,S为当前数据集,A为待选特征,Sᵥ为特征A取值v时的子集,|Sᵥ|/|S|为子集的权重。

具体计算步骤:

  1. 计算当前数据集S的熵Entropy(S)。
  2. 对特征A的每个可能取值v,将S划分为子集Sᵥ,计算每个子集的熵Entropy(Sᵥ)。
  3. 计算加权平均熵:Σᵥ(|Sᵥ|/|S|)·Entropy(Sᵥ)。
  4. 信息增益 = Entropy(S) - 加权平均熵。
  5. 对所有特征重复步骤1-4,选择信息增益最大的特征进行划分。

ID3算法的局限性:信息增益存在一个明显的偏向性——它倾向于选择取值较多的特征。例如,如果有一个"编号"特征(每个样本都有一个唯一编号),按此特征划分后每个子集都只包含一个样本,熵为0,信息增益最大。但这种划分毫无泛化能力。为了解决这个问题,C4.5算法应运而生。

2.3 信息增益率与C4.5算法

C4.5算法是ID3的改进版本,它使用信息增益率(Gain Ratio)替代信息增益,以惩罚取值过多的特征。信息增益率通过引入分裂信息(Split Information)来归一化信息增益。

分裂信息的定义:

SplitInfo(S, A) = -Σᵥ(|Sᵥ| / |S|) · log₂(|Sᵥ| / |S|)

其中,|Sᵥ|/|S|是特征A取值为v的样本占比。分裂信息本质上是特征A的熵——特征取值越多且分布越均匀,分裂信息越大。

增益率的定义:

GainRatio(S, A) = Gain(S, A) / SplitInfo(S, A)

对于取值极多的特征(如"编号"),其SplitInfo非常大,导致增益率很小,从而有效抑制了特征偏向问题。C4.5在信息增益率之外还引入了其他改进:能够处理连续值特征(通过二分法离散化)、能够处理缺失值、支持剪枝、以及将决策树转换为更易于理解的规则集。

需要注意的是,当SplitInfo接近0时(即特征只有一个取值或大部分样本取同一值),增益率会变得不稳定。C4.5的处理方式是:先在候选特征中选择信息增益高于平均水平的特征,然后再从中选择增益率最大的特征。

2.4 Gini不纯度与CART算法

CART(Classification and Regression Tree)算法使用Gini不纯度(Gini Impurity)作为分类任务的特征选择度量。Gini指数衡量的是从数据集中随机抽取两个样本,其类别不一致的概率。Gini指数越小,数据纯度越高。

Gini指数的计算公式:

Gini(S) = 1 - Σᵢ(pᵢ²),其中i从1到K,pᵢ是类别i在数据集S中的占比。

使用特征A划分后的Gini指数:Gini_index(S, A) = Σᵥ(|Sᵥ| / |S|) · Gini(Sᵥ)

CART算法选择使Gini_index最小的特征进行划分。

与熵的对比:

2.5 MSE与CART回归树

当CART用于回归任务时,特征选择度量不再是Gini指数或熵,而是均方误差(Mean Squared Error, MSE)。回归树的每个叶节点存储一个预测值(通常是该叶节点中所有样本目标变量的均值),划分的目标是使子节点中的目标值尽可能相似。

回归树的划分标准:

选择特征A的划分,使得划分后的加权MSE最小化。

MSE(Sᵥ) = (1 / |Sᵥ|) · Σ(y - ȳᵥ)²,其中ȳᵥ是子集Sᵥ中目标变量的均值。

总MSE = Σᵥ(|Sᵥ| / |S|) · MSE(Sᵥ)

回归树还有一个常用替换度量是平均绝对误差(MAE),其鲁棒性优于MSE(对异常值不那么敏感)。

三、决策树的生成

本节详细介绍三种经典决策树算法的生成过程,以及防止过拟合的剪枝策略。

3.1 ID3算法

ID3(Iterative Dichotomiser 3)由Ross Quinlan于1986年提出,是最早的决策树算法之一。ID3使用信息增益作为特征选择标准,只能处理离散特征,不支持剪枝。

ID3算法步骤:

  1. 从根节点开始,包含所有训练样本。
  2. 计算每个特征的信息增益。
  3. 选择信息增益最大的特征作为当前节点的划分特征。
  4. 根据该特征的每个取值创建分支,将样本划分到对应子节点。
  5. 对每个子节点递归执行步骤2-4,直到满足终止条件:所有样本属于同一类别,或所有特征已用完,或没有剩余样本。

ID3的缺点:

3.2 C4.5算法

C4.5是ID3的改进版本,由Quinlan于1993年提出,是ID3的升级替代品。C4.5在多个方面对ID3进行了重要改进。

C4.5的主要改进:

C4.5曾是业界最主流的决策树算法之一,直到被更高效且性能相当的CART算法所取代。

3.3 CART算法

CART(Classification and Regression Tree)由Breiman等人于1984年提出,是当前最广泛使用的决策树算法。scikit-learn中的DecisionTreeClassifier和DecisionTreeRegressor均基于CART算法。

CART的核心特点:

CART生成步骤:

  1. 从根节点开始。
  2. 遍历所有特征的所有可能二分点,计算每个划分的Gini指数(分类)或MSE(回归)。
  3. 选择Gini指数最小(或MSE最小)的划分将当前节点分裂为两个子节点。
  4. 对每个子节点递归执行步骤2-3,直到满足停止条件。
  5. 使用成本复杂度剪枝对生成的完全生长树进行后剪枝。

3.4 三种算法对比

特性ID3C4.5CART
提出时间198619931984
特征选择标准信息增益信息增益率Gini指数 / MSE
树的结构多叉树多叉树二叉树
连续特征处理不支持支持(二分法离散化)支持
缺失值处理不支持支持支持(代理划分)
剪枝策略后剪枝(错误率剪枝)成本复杂度剪枝(CCP)
任务类型分类分类分类 + 回归
流行度(现代)已淘汰较少使用最广泛使用

3.5 剪枝策略

决策树如果完全生长而不加约束,很容易产生过拟合(Overfitting)——树对训练数据学习得过于细致,甚至学习了噪声,导致在测试数据上表现不佳。剪枝(Pruning)是防止决策树过拟合的核心技术,分为预剪枝和后剪枝两大类。

预剪枝(Pre-pruning):在树的生长过程中提前停止分裂,即"边建边剪"。常见策略包括:

预剪枝的优点是效率高(无需先构建完整树再剪枝),但缺点是可能存在"视野局限"——当前看似无用的分裂后续可能会带来显著收益。

后剪枝(Post-pruning):先让树充分生长,然后自底向上将不显著的子树替换为叶节点。CART采用的成本复杂度剪枝(Cost-Complexity Pruning, CCP)是后剪枝的代表方法:

四、决策树的实现(scikit-learn)

scikit-learn提供了简洁高效的决策树实现,支持分类和回归任务。下面通过代码示例展示如何使用scikit-learn构建决策树模型。

4.1 DecisionTreeClassifier 基本用法

from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import accuracy_score # 加载数据 iris = load_iris() X, y = iris.data, iris.target # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.3, random_state=42 ) # 创建决策树分类器 clf = DecisionTreeClassifier( criterion='gini', # 特征选择标准:'gini'或'entropy' max_depth=3, # 最大深度 min_samples_split=10, # 内部节点最小样本数 min_samples_leaf=5, # 叶节点最小样本数 random_state=42 ) # 训练模型 clf.fit(X_train, y_train) # 预测 y_pred = clf.predict(X_test) y_pred_proba = clf.predict_proba(X_test) # 评估 accuracy = accuracy_score(y_test, y_pred) print(f"准确率: {accuracy:.4f}") print(f"预测概率: {y_pred_proba[:3]}")

4.2 完整参数详解

DecisionTreeClassifier的构造参数对模型性能有决定性影响,理解每个参数的作用至关重要:

4.3 特征重要性提取

决策树的一个强大特性是能够输出特征重要性(Feature Importance),反映每个特征在决策过程中的贡献程度。特征重要性的计算基于该特征在所有节点上带来的不纯度减少的总和(经过归一化)。

import matplotlib.pyplot as plt import numpy as np # 训练完成后提取特征重要性 importances = clf.feature_importances_ feature_names = iris.feature_names # 排序 indices = np.argsort(importances)[::-1] print("特征重要性排序:") for i, idx in enumerate(indices): print(f"{i+1}. {feature_names[idx]}: {importances[idx]:.4f}") # 可视化 plt.figure(figsize=(8, 5)) plt.title("特征重要性") plt.bar(range(len(importances)), importances[indices]) plt.xticks(range(len(importances)), [feature_names[i] for i in indices], rotation=45) plt.tight_layout() plt.show()

4.4 决策树可视化

scikit-learn提供了多种决策树可视化方式,帮助理解模型的决策逻辑:

from sklearn.tree import plot_tree, export_text, export_graphviz import matplotlib.pyplot as plt # 方法1:使用 plot_tree(最简单) plt.figure(figsize=(12, 8)) plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True) plt.show() # 方法2:使用 export_text(纯文本表示) text_representation = export_text( clf, feature_names=iris.feature_names, show_weights=True ) print(text_representation) # 方法3:export_graphviz(导出Graphviz格式,可自定义更多样式) import graphviz dot_data = export_graphviz( clf, out_file=None, feature_names=iris.feature_names, class_names=iris.target_names, filled=True, rounded=True, special_characters=True ) graph = graphviz.Source(dot_data) graph.render("iris_decision_tree") # 保存为PDF

4.5 回归树示例

from sklearn.tree import DecisionTreeRegressor from sklearn.datasets import fetch_california_housing # 加载加州房价数据 housing = fetch_california_housing() X, y = housing.data, housing.target # 划分数据集 X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=42 ) # 创建回归树 reg = DecisionTreeRegressor( criterion='squared_error', # MSE max_depth=5, min_samples_leaf=10, random_state=42 ) reg.fit(X_train, y_train) # 评估 from sklearn.metrics import mean_squared_error, r2_score y_pred = reg.predict(X_test) mse = mean_squared_error(y_test, y_pred) r2 = r2_score(y_test, y_pred) print(f"MSE: {mse:.4f}") print(f"R²: {r2:.4f}") # 特征重要性 for name, imp in zip(housing.feature_names, reg.feature_importances_): print(f"{name}: {imp:.4f}")

五、决策树的优缺点

5.1 优点

5.2 缺点

5.3 改进方向

鉴于单棵决策树的局限性,研究者和工程师们提出了多种集成改进方法,这些方法已成为现代机器学习的基石:

六、决策树的应用

决策树及其改进算法在众多领域有着广泛的应用,下面介绍几个典型场景。

6.1 特征重要性排序

在数据探索阶段,训练一棵决策树并分析其特征重要性,是快速理解数据结构和特征贡献度的有效手段。这对于高维数据的特征筛选(Feature Selection)尤其有用——可以保留重要性高的特征、剔除冗余特征,从而降低后续模型的过拟合风险并提高训练效率。

# 使用决策树进行特征筛选 from sklearn.feature_selection import SelectFromModel selector = SelectFromModel( DecisionTreeClassifier(max_depth=5, random_state=42), threshold='median' # 选择重要性高于中位数的特征 ) selector.fit(X_train, y_train) # 获取被选中的特征 selected_features = X_train.columns[selector.get_support()] print(f"选中的特征: {list(selected_features)}") # 转换数据 X_train_selected = selector.transform(X_train) X_test_selected = selector.transform(X_test)

6.2 规则提取

决策树可以直接转换为可读的IF-THEN规则,这在需要向用户解释模型决策的领域(如医疗诊断、信用评分、保险定价)中极为重要。规则提取的基本方法是遍历从根节点到每个叶节点的路径,将路径上的条件连接起来形成一条规则。

from sklearn.tree import export_text # 生成规则文本 rules = export_text(clf, feature_names=list(feature_names), show_weights=True) print("提取的决策规则:") print(rules) # 也可以手动遍历树结构提取规则 def extract_rules(tree, feature_names, node=0, depth=0, conditions=[]): left_child = tree.children_left[node] right_child = tree.children_right[node] if left_child == right_child: # 叶节点 value = tree.value[node] class_idx = value.argmax() print(f"IF {' AND '.join(conditions)} THEN class={class_idx}") if left_child != right_child: feature = feature_names[tree.feature[node]] threshold = tree.threshold[node] # 左分支:特征 <= 阈值 extract_rules(tree, feature_names, left_child, depth+1, conditions + [f"{feature} <= {threshold:.2f}"]) # 右分支:特征 > 阈值 extract_rules(tree, feature_names, right_child, depth+1, conditions + [f"{feature} > {threshold:.2f}"]) extract_rules(clf.tree_, iris.feature_names)

6.3 缺失值处理能力

CART算法具备一定的缺失值处理能力,主要通过代理划分(Surrogate Splits)实现。当某个样本在最优划分特征上缺失时,算法会寻找与该特征最相关的其他特征(即代理划分)来代替。scikit-learn的当前实现不支持代理划分,但许多工业级实现(如R语言的rpart包)和现代树模型(如XGBoost、LightGBM)都内建了缺失值处理机制。

在实际应用中,处理缺失值的常用策略包括:

6.4 实际应用场景

决策树及其集成方法在以下领域有着丰富的成功应用:

学习要点总结:决策树是机器学习入门的必修算法。掌握以下几点即成体系:(1)理解信息熵、Gini指数等核心度量指标的含义和计算;(2)清楚ID3、C4.5、CART三种算法的演进关系和本质差异;(3)熟练使用scikit-learn构建和调优决策树模型;(4)理解剪枝策略和正则化参数的作用;(5)认识到单棵树的局限性,为后续学习随机森林、XGBoost等集成方法打下坚实基础。