当前位置: 首页
AI教程
Python决策树剪枝AI数据分析进阶教程

Python决策树剪枝AI数据分析进阶教程

热心网友 时间:2026-06-26
转载

用Python进行AI数据分析进阶教程49:

决策树的剪枝


关键词:决策树、剪枝、预剪枝、后剪枝、过拟合

摘要:本文系统阐述决策树剪枝的核心概念及其在防止过拟合中的关键作用,重点解析预剪枝与后剪枝两种主流方法。预剪枝通过在树生长过程中提前终止分裂,借助最大深度、最小样本数等参数控制模型复杂度,但可能引发欠拟合;后剪枝则先构建完整决策树,再自底向上进行剪枝,常用代价复杂度剪枝(CCP)配合验证集评估,虽计算开销较大,却可显著提升泛化能力。文中附带基于Scikit-Learn库实现预剪枝与后剪枝的完整Python代码,并以鸢尾花数据集为例展示模型训练与评估流程,凸显剪枝对分类准确率的实际提升效果。


决策树剪枝,本质上是为了防止模型在训练集上表现优异却在测试集上大幅下滑——即过拟合。应对这一难题主要有两大思路:预剪枝和后剪枝。下面逐一深入剖析。

一、预剪枝

1、原理

预剪枝的策略是:在决策树尚未完全生长时,每次打算对某个节点进行分裂,先评估本次分裂是否有实际价值。若信息增益或基尼系数提升不明显,则立即停止分裂。此外,还可设定若干硬性约束,例如树的最大深度、节点最少样本数等,从而提前限定树的生长范围。

2、关键点

  • 评估指标:信息增益、基尼系数,用于衡量节点划分的收益。
  • 限制条件:最大深度、最小样本数、最小信息增益——这些是调节模型复杂度的关键阀门。

3、注意点

  • 欠拟合风险:预剪枝容易“矫枉过正”,若树在尚未充分学习时就停止生长,可能导致模型在训练集上表现也不理想。
  • 参数选择:超参数取值高度依赖具体数据集。同一参数在不同数据上效果可能截然不同,需反复试验调优。

4、示例及代码

Python脚本

# 从 sklearn 库的 datasets 模块中导入 load_iris 函数,用于加载鸢尾花数据集
from sklearn.datasets import load_iris
# 从 sklearn 库的 model_selection 模块中导入 train_test_split 函数,
# 用于将数据集划分为训练集和测试集
from sklearn.model_selection import train_test_split
# 从 sklearn 库的 tree 模块中导入 DecisionTreeClassifier 类,用于创建决策树分类器
from sklearn.tree import DecisionTreeClassifier
# 从 sklearn 库的 metrics 模块中导入 accuracy_score 函数,用于计算分类准确率
from sklearn.metrics import accuracy_score

# 调用 load_iris 函数加载鸢尾花数据集,并将其赋值给变量 iris
# 鸢尾花数据集是一个经典的分类数据集,包含 150 个样本,分为 3 个类别
iris = load_iris()
# 从 iris 数据集中提取特征数据,赋值给变量 X
# 特征数据包含了鸢尾花的一些测量值,如花瓣长度、花瓣宽度等
X = iris.data
# 从 iris 数据集中提取标签数据,赋值给变量 y
# 标签数据表示每个样本所属的类别
y = iris.target

# 使用 train_test_split 函数将特征数据 X 和标签数据 y 划分为训练集和测试集
# test_size=0.3 表示将 30% 的数据作为测试集,70% 的数据作为训练集
# random_state=42 是随机数种子,保证每次划分的结果一致,方便结果复现
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建一个 DecisionTreeClassifier 决策树分类器对象,并设置预剪枝参数
# max_depth=3 表示决策树的最大深度为 3,防止树生长过深导致过拟合
# min_samples_split=5 表示一个节点要进行划分时,
# 至少需要包含 5 个样本,同样是为了防止过拟合
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=5)

# 调用决策树分类器的 fit 方法,使用训练集数据 X_train 和对应的标签 y_train 对模型进行训练
# 训练过程就是让决策树学习特征和标签之间的关系
clf.fit(X_train, y_train)

# 调用训练好的决策树分类器的 predict 方法,使用测试集数据 X_test 进行预测
# 预测结果存储在变量 y_pred 中
y_pred = clf.predict(X_test)

# 调用 accuracy_score 函数,计算测试集真实标签 y_test 和预测标签 y_pred 之间的准确率
# 准确率是分类正确的样本数占总样本数的比例
accuracy = accuracy_score(y_test, y_pred)

# 使用 f-string 格式化输出准确率
# 打印出训练好的决策树模型在测试集上的分类准确率
print(f"Accuracy: {accuracy}")

输出 / 打印结果及注释

代码的输出结果会类似如下形式:

Accuracy: 0.9555555555555556

这里的输出值是一个近似值,实际的准确率会因为浮点数的精度问题显示为多位小数。

  • Accuracy:表示模型在测试集上的分类准确率。这个值越接近 1,说明模型在测试集上的分类效果越好。在这个例子中,准确率约为 0.956,意味着模型在测试集上大约 95.6% 的样本分类正确。不过,每次运行代码时,由于数据集的划分可能会有细微差异(虽然设置了 random_state 尽量保证结果一致,但不同环境可能仍有极微小差别),准确率可能会在一定范围内波动。

重点语句解读

  • DecisionTreeClassifier(max_depth=3, min_samples_split=5):创建决策树分类器,max_depth=3 表示树的最大深度为 3,min_samples_split=5 表示节点划分所需的最小样本数为 5。这些参数的设置可以限制树的生长,避免过拟合。
  • clf.fit(X_train, y_train):使用训练集数据对决策树模型进行训练。
  • clf.predict(X_test):使用训练好的模型对测试集数据进行预测。
  • accuracy_score(y_test, y_pred):计算预测结果的准确率。

二、后剪枝

1、原理

后剪枝的思路恰好相反:先让决策树自由生长至最大限度,然后自底向上逐一检查每个非叶子节点,判断将该子树替换为叶子节点是否能使模型在验证集上的表现更优。若有益,则执行剪枝。最常用的后剪枝方法是代价复杂度剪枝(CCP)。

2、关键点

  • 剪枝评估:借助验证集作为衡量标准——剪枝前测试一次,剪枝后再测试一次,对比性能变化。
  • 剪枝策略:根据评估结果决定是否实施剪枝。

3、注意点

  • 计算成本:后剪枝需要先构建完整的决策树,再对每个节点进行验证评估,耗时明显多于预剪枝。
  • 验证集选择:验证集的划分质量直接影响剪枝效果。若验证集与训练集过于相似,剪枝可能无实质改进;若差异过大,又可能误剪掉有价值的分支。

4、示例及代码

Python脚本

# 从 sklearn 库的 datasets 模块中导入 load_iris 函数,用于加载经典的鸢尾花数据集
from sklearn.datasets import load_iris
# 从 sklearn 库的 model_selection 模块中导入 train_test_split 函数,
# 该函数可将数据集划分为不同子集
from sklearn.model_selection import train_test_split
# 从 sklearn 库的 tree 模块中导入 DecisionTreeClassifier 类,用于创建决策树分类模型
from sklearn.tree import DecisionTreeClassifier
# 从 sklearn 库的 metrics 模块中导入 accuracy_score 函数,用于计算分类模型的准确率
from sklearn.metrics import accuracy_score

# 调用 load_iris 函数加载鸢尾花数据集,并将其赋值给变量 iris
# 鸢尾花数据集包含了 150 个样本,每个样本有 4 个特征,分为 3 个类别
iris = load_iris()
# 从 iris 数据集中提取特征数据,存储在变量 X 中
# 这些特征是鸢尾花的一些属性,如花瓣长度、花瓣宽度等
X = iris.data
# 从 iris 数据集中提取标签数据,存储在变量 y 中
# 标签代表每个样本所属的类别
y = iris.target

# 使用 train_test_split 函数将数据集划分为训练集和临时集
# test_size=0.4 表示将 40% 的数据作为临时集,60% 的数据作为训练集
# random_state=42 是随机数种子,保证每次划分的结果一致,方便结果复现
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)

# 再次使用 train_test_split 函数将临时集划分为验证集和测试集
# test_size=0.5 表示将临时集的 50% 作为测试集,另外 50% 作为验证集
# 这样整体上训练集、验证集、测试集的比例大致为 60%、20%、20%
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# 创建一个 DecisionTreeClassifier 类的实例 clf,即一个决策树分类器
# 这里没有设置额外参数,将使用默认的参数来构建决策树
clf = DecisionTreeClassifier()

# 调用决策树分类器的 fit 方法,
# 使用训练集数据 X_train 和对应的标签 y_train 对模型进行训练
# 训练过程就是让决策树学习特征和标签之间的关系,构建决策树结构
clf.fit(X_train, y_train)

# 调用决策树分类器的 cost_complexity_pruning_path 方法,进行代价复杂度剪枝路径的计算
# 该方法会计算不同剪枝强度(由 ccp_alpha 参数控制)下的子树及其对应的不纯度
# 结果存储在 path 对象中
path = clf.cost_complexity_pruning_path(X_train, y_train)

# 从 path 对象中提取不同的 ccp_alpha 值,存储在 ccp_alphas 变量中
# ccp_alpha 是代价复杂度剪枝的参数,值越大,剪枝越严重
ccp_alphas, impurities = path.ccp_alphas, path.impurities

# 初始化最优的 ccp_alpha 值为 None,用于后续存储找到的最优值
best_alpha = None
# 初始化最优准确率为 0,用于后续比较和更新
best_accuracy = 0

# 遍历所有的 ccp_alpha 值
for ccp_alpha in ccp_alphas:
    # 创建一个新的决策树分类器 pruned_clf,并设置当前的 ccp_alpha 值
    # 这样就得到了一个使用该剪枝强度的决策树模型
    pruned_clf = DecisionTreeClassifier(ccp_alpha=ccp_alpha)
    # 使用训练集数据 X_train 和对应的标签 y_train 对剪枝后的决策树模型进行训练
    pruned_clf.fit(X_train, y_train)
    # 调用剪枝后决策树模型的 predict 方法,使用验证集数据 X_val 进行预测
    # 预测结果存储在 y_val_pred 中
    y_val_pred = pruned_clf.predict(X_val)
    # 调用 accuracy_score 函数,
    # 计算验证集真实标签 y_val 和预测标签 y_val_pred 之间的准确率
    accuracy = accuracy_score(y_val, y_val_pred)
    # 如果当前的准确率大于之前记录的最优准确率
    if accuracy > best_accuracy:
        # 更新最优准确率为当前准确率
        best_accuracy = accuracy
        # 更新最优的 ccp_alpha 值为当前的 ccp_alpha 值
        best_alpha = ccp_alpha

# 创建一个新的决策树分类器 final_clf,并使用最优的 ccp_alpha 值
# 这样就得到了一个使用最优剪枝强度的决策树模型
final_clf = DecisionTreeClassifier(ccp_alpha=best_alpha)
# 使用训练集数据 X_train 和对应的标签 y_train 对使用最优剪枝强度的决策树模型进行训练
final_clf.fit(X_train, y_train)

# 调用使用最优剪枝强度的决策树模型的 predict 方法,使用测试集数据 X_test 进行预测
# 预测结果存储在 y_test_pred 中
y_test_pred = final_clf.predict(X_test)
# 调用 accuracy_score 函数,
# 计算测试集真实标签 y_test 和预测标签 y_test_pred 之间的准确率
test_accuracy = accuracy_score(y_test, y_test_pred)

# 使用 f-string 格式化输出测试集上的准确率
print(f"Test Accuracy: {test_accuracy}")

输出 / 打印结果及注释

代码运行后的输出可能类似如下:

Test Accuracy: 0.9666666666666667
  • Test Accuracy:这是使用最优剪枝强度的决策树模型在测试集上的分类准确率。它反映了经过代价复杂度剪枝后,模型在未参与训练和验证的数据上的泛化能力。该值越接近 1 越好,这里约为 0.967,意味着模型在测试集上大约 96.7% 的样本分类正确。不同的运行可能会因为随机种子的影响(虽然设置了 random_state 尽量保证一致,但环境差异等因素仍可能有极小波动)导致准确率略有不同。

重点语句解读

  • clf.cost_complexity_pruning_path(X_train, y_train):计算决策树的代价复杂度剪枝路径,返回不同 ccp_alpha 值对应的子树的不纯度。
  • ccp_alphas, impurities = path.ccp_alphas, path.impurities:获取不同 ccp_alpha 值和对应的不纯度。
  • DecisionTreeClassifier(ccp_alpha=ccp_alpha):创建一个使用指定 ccp_alpha 值的决策树分类器,ccp_alpha 是一个用于控制剪枝强度的参数,值越大,剪枝越严重。
  • final_clf = DecisionTreeClassifier(ccp_alpha=best_alpha):使用最优的 ccp_alpha 值重新训练决策树模型。

通过预剪枝和后剪枝,可以有效地防止决策树过拟合,提高模型的泛化能力。在实际应用中,可以根据具体情况选择合适的剪枝方法。

——The END——

来源:https://blog.csdn.net/imewe/article/details/149061789

游乐网为非赢利性网站,所展示的游戏/软件/文章内容均来自于互联网或第三方用户上传分享,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系youleyoucom@outlook.com。

同类文章
更多
Claude MCP模型爆火 AI Agent圈万能插头让Cursor工作流效率提升10倍

Claude MCP模型爆火 AI Agent圈万能插头让Cursor工作流效率提升10倍

坦白说,MCP这波热度来得有点突然。去年11月Anthropic推出的时候,没引起太大波澜;但最近几天,X上几乎所有人都在讨论MCP服务器,大有把它捧成AI应用碘伏者的架势。 MCP(模型上下文协议)是去年11月由Claude的母公司Anthropic推出的一项开放标准协议,目标是为大语言模型与外部

时间:2026-06-27 14:04
凯文凯利万字预言未来十年AI世界发展趋势

凯文凯利万字预言未来十年AI世界发展趋势

凯文·凯利授课现场 1、关于AI的未来图景 感谢邀请,有机会跟嘉宾商学的企业家校友们面对面聊聊。大家这次访学行程很硬核——从拉斯维加斯CES到硅谷,接触了不少最新的AI技术和理论。所以今天的分享,想提供一个不太一样的视角:关于AI正在发生什么,以及接下来会发生什么。会用一种叫“情景规划”的方式来展开

时间:2026-06-27 14:04
从Manus到GO-1:AI逐步走进物理世界

从Manus到GO-1:AI逐步走进物理世界

2025年3月,中国AI领域投下的重磅冲击波,可不止一枚。Manus通用AI Agent以“全球首款执行级智能体”之姿闪亮登场,紧接着,GO-1通用具身基座大模型宣布开源,扬言要“重新定义人机交互边界”。这两件事,让不少人开始认真琢磨:当AI不再满足于云端聊天,开始伸手触碰物理世界,真正的智能革命,

时间:2026-06-27 14:03
Manus AI是通用Agent革命还是精巧缝合怪

Manus AI是通用Agent革命还是精巧缝合怪

先说一个基本判断:昨天,Manus至少在中文媒体圈里刷屏了。 自媒体的反应相当狂热,“通用Agent终于实现了!”“这是继DeepSeek之后的又一技术革命!”这样的说法遍地都是。从Benchmark来看,Manus的表现确实亮眼——在GAIA测试中,它超越了此前的各种Agent以及OpenAI的D

时间:2026-06-27 14:03
Ubuntu从零部署OpenClaw完整教程(本地模型与DeepSeek)

Ubuntu从零部署OpenClaw完整教程(本地模型与DeepSeek)

0 前言 OpenClaw(圈内常称“龙虾”)是一套开源、支持自托管的 AI 助手平台,原生兼容 Ollama 本地模型与 DeepSeek 等云端 API,让您在隐私保护与性能体验之间灵活切换——需要安全就用本地,追求强大则上云端。本文记录了我在 Ubuntu 系统上从零搭建 OpenClaw

时间:2026-06-27 14:03
热门专题
更多
刀塔传奇破解版无限钻石下载大全 刀塔传奇破解版无限钻石下载大全
洛克王国正式正版手游下载安装大全 洛克王国正式正版手游下载安装大全
思美人手游下载专区 思美人手游下载专区
好玩的阿拉德之怒游戏下载合集 好玩的阿拉德之怒游戏下载合集
不思议迷宫手游下载合集 不思议迷宫手游下载合集
百宝袋汉化组游戏最新合集 百宝袋汉化组游戏最新合集
jsk游戏合集30款游戏大全 jsk游戏合集30款游戏大全
宾果消消消原版下载大全 宾果消消消原版下载大全
  • 日榜
  • 周榜
  • 月榜