当前位置: 首页 > news >正文

Python Day11

@浙大疏锦行 Python Day11

内容:

  • 参数:手动设置(超参数),模型学习(内参数)
  • 模型 = 算法 + 参数
  • 寻找参数框架:网格搜索(爆搜),随机搜索,贝叶斯搜索(优化随机搜索)
  • `time`库计时
import timestart_time = time.time()
# proc...
end_time = time.time()print(f"time is {end_time - start_time}")

 代码:

import pandas as pd  # 用于数据处理和分析,可处理表格数据。
import numpy as np  # 用于数值计算,提供了高效的数组操作。
import matplotlib.pyplot as plt  # 用于绘制各种类型的图表
import seaborn as sns  # 基于matplotlib的高级绘图库,能绘制更美观的统计图形。
from sklearn.svm import SVC #支持向量机分类器
from sklearn.neighbors import KNeighborsClassifier #K近邻分类器
from sklearn.linear_model import LogisticRegression #逻辑回归分类器
import xgboost as xgb #XGBoost分类器
import lightgbm as lgb #LightGBM分类器
from sklearn.ensemble import RandomForestClassifier #随机森林分类器
from catboost import CatBoostClassifier #CatBoost分类器
from sklearn.tree import DecisionTreeClassifier #决策树分类器
from sklearn.naive_bayes import GaussianNB #高斯朴素贝叶斯分类器
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score # 用于评估分类器性能的指标
from sklearn.metrics import classification_report, confusion_matrix #用于生成分类报告和混淆矩阵
import warnings #用于忽略警告信息
warnings.filterwarnings("ignore") # 忽略所有警告信息
from sklearn.model_selection import train_test_split# 设置中文字体(解决中文显示问题)
plt.rcParams['font.sans-serif'] = ['SimHei']  # Windows系统常用黑体字体
plt.rcParams['axes.unicode_minus'] = False  # 正常显示负号data = pd.read_csv("./data/heart.csv")# 这里不需要处理离散值以及缺失值
X = data.drop(['target'], axis=1)
y = data['target']X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)print("----------默认参数的随机森林-------------")
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print(f"预测精度为:{accuracy_score(y_pred, y_test)}")
print("---------默认参数的随机森林结束------------")
print("----------网格搜索的随机森林-------------")
from sklearn.model_selection import GridSearchCV
import time
# 定义要搜索的参数网格
param_grid = {'n_estimators': [50, 100, 200],'max_depth': [None, 10, 20, 30],'min_samples_split': [2, 5, 10],'min_samples_leaf': [1, 2, 4]
}
grid_search = GridSearchCV(estimator=RandomForestClassifier(random_state=42), # 随机森林分类器param_grid=param_grid, # 参数网格cv=5, # 5折交叉验证n_jobs=-1, # 使用所有可用的CPU核心进行并行计算scoring='accuracy') # 使用准确率作为评分标准
start_time = time.time()
grid_search.fit(X_train, y_train)
end_time = time.time()
print(f"网格搜索耗时: {end_time - start_time:.4f} 秒")
print("最佳参数: ", grid_search.best_params_) #best_params_属性返回最佳参数组合
# 使用最佳参数的模型进行预测
best_model = grid_search.best_estimator_ # 获取最佳模型
best_pred = best_model.predict(X_test) # 在测试集上进行预测
print("\n网格搜索优化后的随机森林 在测试集上的分类报告:")
print(classification_report(y_test, best_pred))
print("网格搜索优化后的随机森林 在测试集上的混淆矩阵:")
print(confusion_matrix(y_test, best_pred))
print("-------------网格搜索随机森林结束-----------------")
print("-------------贝叶斯参数随机森林-------------------")
from skopt import BayesSearchCV
from skopt.space import Integer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
import time# 定义要搜索的参数空间
search_space = {'n_estimators': Integer(50, 200),'max_depth': Integer(10, 30),'min_samples_split': Integer(2, 10),'min_samples_leaf': Integer(1, 4)
}# 创建贝叶斯优化搜索对象
bayes_search = BayesSearchCV(estimator=RandomForestClassifier(random_state=42),search_spaces=search_space,n_iter=32,  # 迭代次数,可根据需要调整cv=5, # 5折交叉验证,这个参数是必须的,不能设置为1,否则就是在训练集上做预测了n_jobs=-1,scoring='accuracy'
)start_time = time.time()
# 在训练集上进行贝叶斯优化搜索
bayes_search.fit(X_train, y_train)
end_time = time.time()print(f"贝叶斯优化耗时: {end_time - start_time:.4f} 秒")
print("最佳参数: ", bayes_search.best_params_)# 使用最佳参数的模型进行预测
best_model = bayes_search.best_estimator_
best_pred = best_model.predict(X_test)print("\n贝叶斯优化后的随机森林 在测试集上的分类报告:")
print(classification_report(y_test, best_pred))
print("贝叶斯优化后的随机森林 在测试集上的混淆矩阵:")
print(confusion_matrix(y_test, best_pred))print("-------------贝叶斯结束---------------")
http://www.dtcms.com/a/276763.html

相关文章:

  • Agent任务规划
  • 【PMP备考】敏捷思维:驾驭不确定性的项目管理之道
  • QT中设计qss字体样式但是没有用【已解决】
  • 文件系统(精讲)
  • JVM与系统性能监控工具实战指南:从JVM到系统的全链路分析
  • 【每日刷题】阶乘后的零
  • SOEM build on ubuntu
  • Golang实战:使用 Goroutine 实现数字与字母的交叉打印
  • 使用bp爆破模块破解pikachu登录密码
  • 使用frp内网穿透:将本地服务暴露到公网
  • 张量类型转换
  • 深入探讨Java的ZGC垃圾收集器:原理、实战与优缺点
  • 格密码--数学基础--08最近向量问题(CVP)与格陪集
  • Mentor软件模块复杂,如何分角色授权最合理?
  • 【PTA数据结构 | C语言版】阶乘的递归实现
  • 串口屏的小记哦
  • 鸿蒙进程通信的坑之ServiceExtensionAbility
  • Datomic数据库简介(TBC)
  • Ntfs!LfsFlushLfcb函数分析之Ntfs!_LFCB->LbcbWorkque的背景分析3个restart页面一个普通页面的一个例子
  • 如何在IEEETrans格式的latex标题页插入图像
  • CCS-MSPM0G3507-4-基础篇-串口通讯-实现收和发
  • Java SE--抽象类和接口
  • 面试150 对称二叉树
  • Waiting for server response 和 Content Download
  • 嵌入式程序调试工具
  • 《人件》阅读笔记
  • 【Flask】基础入门
  • 华为业务变革项目IPD基本知识
  • nodejs获取可用cpu数
  • 前端弹性布局全解析