当前位置: 首页 > news >正文

机器学习——随机森林

一·概念解释

集成学习

集成学习的应用:

  1. 分类问题集成。
  2. 回归问题集成。
  3. 特征选取集成。

随机森林介绍

(一)随机森林优点
  1. 具有极高的准确率。
  2. 随机性的引入,使得随机森林的抗噪声能力很强。
  3. 随机性的引入,使得随机森林不容易过拟合。
  4. 能够处理很高维度的数据,不用做特征选择。
  5. 容易实现并行化计算。
(二)随机森林缺点
  1. 当随机森林中的决策树个数很多时,训练时需要的空间和时间会较大。
  2. 随机森林模型还有许多不好解释的地方,有点算个黑盒模型。

随机森林参数

随机森林 API 文档:

class sklearn.ensemble.RandomForestClassifier(n_estimators='warn', criterion='gini', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features='auto', max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, bootstrap=True, oob_score=False, n_jobs=None, random_state=None, verbose=0, warm_start=False, class_weight=None)

随机森林重要的一些参数:

  1. n_estimators:(随机森林独有)随机森林中决策树的个数。在 0.20 版本中默认是 10 个决策树;在 0.22 版本中默认是 100 个决策树。
  2. criterion:(同决策树)节点分割依据,默认为基尼系数,可选【entropy:信息增益】
  3. max_depth:(同决策树)【重要】default = (None) 设置决策树的最大深度,默认为 None。
    • (1)数据少或者特征少的时候,可以不用管这个参数,按照默认的不限制生长即可
    • (2)如果数据比较多特征也比较多的情况下,可以限制这个参数,范围在 10 ~ 100 之间比较好
  4. min_samples_split:(同决策树)【重要】这个值限制了子树继续划分的条件,如果某节点的样本数少于设定值,则不会再继续分裂。默认是 2. 如果样本量不大,不需要管这个值。如果样本量数量级非常大,则建议增大这个值。
  5. min_samples_leaf:(同决策树)【重要】这个值限制了叶子节点最少的样本数,如果某叶子节点数目小于样本数,则会和兄弟节点一起被剪枝。默认是 1, 可以输入最少的样本数的整数,或者最少样本数占样本总数的百分比。如果样本量不大,不需要管这个值。如果样本量数量级非常大,则推荐增大这个值。
    • 【叶是决策树的末端节点。较小的叶子使模型更容易捕捉训练数据中的噪声。一般来说,更偏向于将最小叶子节点数目设置为大于 50。在实际情况中,应该尽量尝试多种叶子大小种类,以找到最优的那个。】
    • 【比如,设定为 50,此时,上一个节点(100 个样本)进行分裂,分裂为两个节点,其中一个节点的样本数小于 50 个,那么这两个节点都会被剪枝】
  6. min_weight_fraction_leaf:(同决策树) 这个值限制了叶子节点所有样本权重和的最小值,如果小于这个值,则会和兄弟节点一起被剪枝。默认是 0,就是不考虑权重问题。一般来说,如果我们有较多样本有缺失值,或者分类树样本的分布类别偏差很大,就会引入样本权重,这时我们就要注意这个值了。【一般不需要注意】
  7. max_features:(随机森林独有)【重要】随机森林允许单个决策树使用特征的最大数量。选择最适属性时划分的特征不能超过此值。
    • 当为整数时,即最大特征数;当为小数时,训练集特征数 * 小数;
    • if “auto”, then max_features = sqrt(n_features).
    • If “sqrt”, then max_features = sqrt(n_features).
    • If “log2”, then max_features = log2(n_features).
    • If None, then max_features = n_features.
    • 【增加 max_features 一般能提高模型的性能,因为在每个节点上,我们有更多的选择可以考虑。然而,这未必完全是对的,因为它降低了单个树的多样性,而这正是随机森林独特的优点。但是,可以肯定,你通过增加 max_features 会降低算法的速度。因此,你需要适当的平衡和选择最佳 max_features。】
  8. n_jobs :并行 job 个数。这个在 bagging 训练过程中有重要作用,可以并行从而提高性能。1 = 不并行;n:n 个并行;-1:CPU 有多少 core,就启动多少 job。

二·代码

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier# 数据读取与划分
df = pd.read_csv('spambase.csv')
X = df.iloc[:, :-1]
y = df.iloc[:, -1]
xtrain, xtest, ytrain, ytest = train_test_split(X, y, test_size=0.2, random_state=100)# 随机森林模型构建与训练
rf = RandomForestClassifier(n_estimators=100,max_features=0.8,random_state=0
)
rf.fit(xtrain, ytrain)# 导入评估所需的库
from sklearn import metrics
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt# 混淆矩阵可视化函数
def cm_plot(y, yp):cm = confusion_matrix(y, yp)plt.matshow(cm, cmap=plt.cm.Blues)plt.colorbar()for x in range(len(cm)):for y_idx in range(len(cm)):plt.annotate(cm[x, y_idx], xy=(y_idx, x), horizontalalignment='center',verticalalignment='center')plt.ylabel('True label')plt.xlabel('Predicted label')return plt# 训练集预测与评估
train_predicted = rf.predict(xtrain)
print(metrics.classification_report(ytrain, train_predicted, digits=9))
cm_plot(ytrain, train_predicted).show()# 测试集预测与评估
test_predicted = rf.predict(xtest)
print(metrics.classification_report(ytest, test_predicted))
cm_plot(ytest, test_predicted).show()# 特征重要性分析
importances = rf.feature_importances_
im = pd.DataFrame(importances, columns=["importances"])
clos = df.columns
clos_1 = clos.values
clos_2 = clos_1.tolist()
clos = clos_2[:-1]
im['clos'] = closim = im.sort_values(by=['importances'], ascending=False)[:10]plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = Falseindex = range(len(im))
plt.yticks(index, im['clos'])
plt.barh(index, im['importances'])
plt.show()

http://www.dtcms.com/a/316829.html

相关文章:

  • Python高级编程与实践:Python高级数据结构与编程技巧
  • 【C++】Stack and Queue and Functor
  • C++二级考试核心知识点【内附操作题真题及解析】
  • Juc高级篇:可见性,有序性,cas,不可变,设计模式
  • SpringMVC(一)
  • Design Compiler:布图规划探索(ICC)
  • 《失落王国》v1.2.8中文版,单人或联机冒险的低多边形迷宫寻宝游戏
  • Modbus tcp 批量写线圈状态
  • centos7上如何安装Mysql5.5数据库?
  • 跨域场景下的Iframe事件监听
  • 【机器学习深度学习】模型量化
  • OSPF作业
  • Linux 基础
  • vue3 计算方式
  • GPS信号捕获尝试(上)
  • 【android bluetooth 协议分析 01】【HCI 层介绍 30】【hci_event和le_meta_event如何上报到btu层】
  • 【三个数公因数】2022-10-7
  • MySQL CONV()函数
  • 永磁同步电机无速度算法--基于二自由度结构的反推观测器TSBO
  • JAVA学习笔记 自增与自减的使用-006
  • 哲学中的主体性:历史演进、理论范式与当代重构
  • 【Unity】背包系统 + 物品窗口管理系统(中)
  • RC和RR的区别
  • Pytorch实现婴儿哭声检测和识别
  • 【web自动化测试】实战
  • Coze Studio开源,企业用户多了一种选择,也需多几分考量
  • 如何通过 5 种方式将照片从 iPad 传输到电脑
  • 埋点技术进阶:如何构建高效的数据采集架构
  • 默认二级路由(React-Router 6)
  • linux-系统日志查看指令systemctl