当前位置: 首页 > news >正文

机器学习案例——预测矿物类型(模型训练)

模型训练类完整实现

1. 类初始化 (__init__)

def __init__(self, model='mode'):"""模型训练类初始化参数:model -- 模型/数据集名称,对应数据文件名前缀属性:train_x -- 训练集特征 (pd.DataFrame)train_y -- 训练集标签 (pd.Series)test_x -- 测试集特征 (pd.DataFrame) test_y -- 测试集标签 (pd.Series)result -- 存储各模型评估结果的字典"""self.model = model  # 模型/数据集名称self.train_x = None  # 训练集特征self.train_y = None  # 训练集标签self.test_x = None  # 测试集特征self.test_y = None  # 测试集标签self.result = {}  # 存储各模型评估结果self.extract_data()  # 调用数据加载方法

2. 数据加载 (extract_data)

def extract_data(self):"""从Excel文件加载训练集和测试集数据文件路径格式:- 训练集: '数据/{model}[训练集].xlsx'- 测试集: '数据/{model}[测试集].xlsx'数据格式要求:- 最后一列为标签列- 其他列为特征列异常处理:- 文件不存在时抛出ValueError"""try:# 加载训练集train_data = pd.read_excel(f'数据/{self.model}[训练集].xlsx')self.train_x = train_data.iloc[:, :-1]  # 除最后一列外所有列作为特征self.train_y = train_data.iloc[:, -1]   # 最后一列作为标签# 加载测试集test_data = pd.read_excel(f'数据/{self.model}[测试集].xlsx')self.test_x = test_data.iloc[:, :-1]    # 除最后一列外所有列作为特征self.test_y = test_data.iloc[:, -1]     # 最后一列作为标签except FileNotFoundError:raise ValueError(f"错误: 无法找到{self.model}对应的数据文件")

3. 模型训练方法

3.1 逻辑回归 (logistic_regression)

def logistic_regression(self):"""逻辑回归模型训练与评估使用sklearn的LogisticRegression通过GridSearchCV进行超参数调优参数网格:- C: 正则化强度的倒数- solver: 优化算法- max_iter: 最大迭代次数- multi_class: 多分类策略评估指标:- 各类别召回率- 整体准确率"""param_grid = {'C': np.logspace(-2, 3, 9),'solver': ['saga','sag','lbfgs','newton-cg'],'max_iter': [100,200,500],'multi_class': ['ovr','multinomial','multinomial']}# 创建模型model = LogisticRegression()grid = GridSearchCV(model, param_grid, cv=5, scoring='accuracy')grid.fit(self.train_x, self.train_y)# 预测与评估y_pred = grid.predict(self.test_x)report = metrics.classification_report(self.test_y, y_pred)print(report)# 解析评估报告report = report.split()lr_result = {'recall_0': float(report[6]),   # 类别0召回率'recall_1': float(report[11]),  # 类别1召回率'recall_2': float(report[16]),  # 类别2召回率'recall_3': float(report[21]),  # 类别3召回率'acc': float(report[25])        # 整体准确率}self.result['LR'] = lr_result

3.2 随机森林 (random_forest)

def random_forest(self):"""随机森林模型训练与评估使用sklearn的RandomForestClassifier通过GridSearchCV进行超参数调优参数网格:- n_estimators: 树的数量- max_depth: 树的最大深度- min_samples_split: 分裂节点所需最小样本数- min_samples_leaf: 叶节点所需最小样本数- max_features: 寻找最佳分割时考虑的特征数- bootstrap: 是否使用bootstrap采样"""param_grid = {'n_estimators': [50, 100],'max_depth': [15, 20],'min_samples_split': [2,5],'min_samples_leaf': [1,2],'max_features': ['auto'],'bootstrap': [True]}model = RandomForestClassifier()grid = GridSearchCV(model, param_grid, cv=5, scoring='accuracy')grid.fit(self.train_x, self.train_y)y_pred = grid.predict(self.test_x)report = metrics.classification_report(self.test_y, y_pred)print(report)report = report.split()rf_result = {'recall_0': float(report[6]),'recall_1': float(report[11]),'recall_2': float(report[16]),'recall_3': float(report[21]),'acc': float(report[25])}self.result['RF'] = rf_result

3.3 多项式朴素贝叶斯 (multi_naive_bayes)

def multi_naive_bayes(self):"""多项式朴素贝叶斯模型训练与评估使用sklearn的MultinomialNB处理负值数据: 将所有特征值平移至非负通过GridSearchCV进行超参数调优参数网格:- alpha: 平滑参数- fit_prior: 是否学习类别先验概率"""# 处理负值数据train_x = self.train_x - self.train_x.min()test_x = self.test_x - self.test_x.min()param_grid = {'alpha': [0.1, 0.5, 1.0, 2.0],'fit_prior': [True, False]}model = MultinomialNB()grid = GridSearchCV(model, param_grid, cv=5, scoring='accuracy')grid.fit(train_x, self.train_y)y_pred = grid.predict(test_x)report = metrics.classification_report(self.test_y, y_pred)print(report)report = report.split()nb_result = {'recall_0': float(report[6]),'recall_1': float(report[11]),'recall_2': float(report[16]),'recall_3': float(report[21]),'acc': float(report[25])}self.result['NB'] = nb_result

3.4 XGBoost (xgboost)

def xgboost(self):"""XGBoost模型训练与评估使用xgboost的XGBClassifier通过GridSearchCV进行超参数调优参数网格:- learning_rate: 学习率- n_estimators: 提升迭代次数- num_class: 类别数- max_depth: 树的最大深度- min_child_weight: 子节点所需最小权重和- gamma: 分裂所需最小损失函数下降值- subsample: 训练样本采样比例- colsample_bytree: 特征采样比例"""param_grid = {'learning_rate': [0.05],'n_estimators': [200],'num_class': [5],'max_depth': [7],'min_child_weight': [1, 3],'gamma': [0,0.1],'subsample': [0.8, 1.0],'colsample_bytree': [0.8, 1.0]}model = XGBClassifier()grid = GridSearchCV(model, param_grid, cv=5, scoring='accuracy')grid.fit(self.train_x, self.train_y)y_pred = grid.predict(self.test_x)report = metrics.classification_report(self.test_y, y_pred)print(report)report = report.split()xgb_result = {'recall_0': float(report[6]),'recall_1': float(report[11]),'recall_2': float(report[16]),'recall_3': float(report[21]),'acc': float(report[25])}self.result['XGB'] = xgb_result

3.5 支持向量机 (svc)

def svc(self):"""支持向量机模型训练与评估使用sklearn的SVC通过GridSearchCV进行超参数调优参数网格:- C: 惩罚参数- kernel: 核函数类型- gamma: 核函数系数- degree: 多项式核的阶数"""param_grid = {'C': [0.1, 1, 10, 100],'kernel': ['linear', 'rbf', 'poly'],'gamma': ['scale', 'auto', 0.1, 1],'degree': [2, 3, 4]}model = SVC()grid = GridSearchCV(model, param_grid, cv=5, scoring='accuracy')grid.fit(self.train_x, self.train_y)y_pred = grid.predict(self.test_x)report = metrics.classification_report(self.test_y, y_pred)print(report)report = report.split()svc_result = {'recall_0': float(report[6]),'recall_1': float(report[11]),'recall_2': float(report[16]),'recall_3': float(report[21]),'acc': float(report[25])}self.result['SVC'] = svc_result

3.6 K近邻 (knn)

def knn(self):"""K近邻模型训练与评估使用sklearn的KNeighborsClassifier通过GridSearchCV进行超参数调优参数网格:- n_neighbors: 邻居数量- weights: 权重计算方式- p: 距离度量参数"""param_grid = {'n_neighbors': [3, 5, 7, 10],'weights': ['uniform', 'distance'],'p': [1, 2]}model = KNeighborsClassifier()grid = GridSearchCV(model, param_grid, cv=5, scoring='accuracy')grid.fit(self.train_x, self.train_y)y_pred = grid.predict(self.test_x)report = metrics.classification_report(self.test_y, y_pred)print(report)report = report.split()knn_result = {'recall_0': float(report[6]),'recall_1': float(report[11]),'recall_2': float(report[16]),'recall_3': float(report[21]),'acc': float(report[25])}self.result['KNN'] = knn_result

3.7 决策树 (decisiontree)

def decisiontree(self):"""决策树模型训练与评估使用sklearn的DecisionTreeClassifier当前版本未设置参数网格"""param_grid = {}model = DecisionTreeClassifier()grid = GridSearchCV(model, param_grid, cv=5, scoring='accuracy')grid.fit(self.train_x, self.train_y)y_pred = grid.predict(self.test_x)report = metrics.classification_report(self.test_y, y_pred)print(report)report = report.split()dt_result = {'recall_0': float(report[6]),'recall_1': float(report[11]),'recall_2': float(report[16]),'recall_3': float(report[21]),'acc': float(report[25])}self.result['DT'] = dt_result

3.8 AdaBoost (adaboost)

def adaboost(self):"""AdaBoost模型训练与评估使用sklearn的AdaBoostClassifier通过GridSearchCV进行超参数调优参数网格:- n_estimators: 弱学习器数量- learning_rate: 学习率- algorithm: 算法实现方式"""param_grid = {'n_estimators': [50, 100, 200],'learning_rate': [0.01, 0.1, 1.0],'algorithm': ['SAMME', 'SAMME.R']}model = AdaBoostClassifier()grid = GridSearchCV(model, param_grid, cv=5, scoring='accuracy')grid.fit(self.train_x, self.train_y)y_pred = grid.predict(self.test_x)report = metrics.classification_report(self.test_y, y_pred)print(report)report = report.split()ab_result = {'recall_0': float(report[6]),'recall_1': float(report[11]),'recall_2': float(report[16]),'recall_3': float(report[21]),'acc': float(report[25])}self.result['AdaBoost'] = ab_result

4. 批量训练入口 (train_process)

def train_process(self, method=''):"""批量训练入口方法参数:method -- 逗号分隔的模型简称字符串支持的模型简称:- lr: 逻辑回归- rf: 随机森林- nb: 朴素贝叶斯- xg: XGBoost- svc: 支持向量机- knn: K近邻- tree: 决策树- adaboost: AdaBoost示例:train_process('lr,rf,xg')  # 训练逻辑回归、随机森林和XGBoost"""# 解析输入的方法字符串met = [m.strip() for m in method.split(',')]# 方法映射字典method_map = {'lr': self.logistic_regression,'rf': self.random_forest,'nb': self.multi_naive_bayes,'xg': self.xgboost,'svc': self.svc,'knn': self.knn,'tree': self.decisiontree,'adaboost': self.adaboost,}# 遍历所有指定的方法for i in met:if i in method_map:method_map[i]()  # 调用对应的训练方法else:print(f"警告:忽略无效方法名 '{i}'")
  • 导入依赖和初始化
import json 
from process_data import MakeModel# 使用json模块保存模型评估结果
# 从process_data模块导入MakeModel类,用于数据预处理和模型训练

模型训练并保存结果

  • 测试不同预处理方法
key_list = ['drop','mean','median','mode','line','forest']# 定义6种不同的缺失值处理方法:
# 'drop' - 直接删除缺失值样本
# 'mean' - 用特征均值填充缺失值
# 'median' - 用特征中位数填充缺失值  
# 'mode' - 用特征众数填充缺失值
# 'line' - 用线性插值填充缺失值
# 'forest' - 用随机森林预测填充缺失值
  • 模型训练和评估
mak = MakeModel(key) 
mak.train_process('lr,rf,xg,nb,adaboost')# 对每种预处理方法(key),训练以下5种机器学习模型:
# 'lr' - 逻辑回归(Logistic Regression)
# 'rf' - 随机森林(Random Forest)
# 'xg' - XGBoost
# 'nb' - 朴素贝叶斯(Naive Bayes)
# 'adaboost' - AdaBoost# 训练流程包括:
# 1. 数据预处理(根据key指定的方法处理缺失值)
# 2. 特征工程(如标准化、编码等)
# 3. 模型训练
# 4. 交叉验证评估# 评估结果保存在mak.result字典中,包含各模型的准确率、召回率等指标
  1. 结果保存
with open(f'./数据/result.json', 'w', encoding='utf-8') as f:json.dump(results, f, ensure_ascii=False, indent=4)# 将所有预处理方法和模型的评估结果保存为JSON文件:
# - 文件路径为'./数据/result.json'
# - encoding='utf-8'确保支持中文路径和内容
# - ensure_ascii=False保证中文字符正常显示
# - indent=4使输出格式更易读(带缩进)
# - 文件内容包含各预处理方法下各模型的评估指标

可视化结果

import json
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import font_manager# 设置中文字体(三种方案任选其一)
# 方案1:使用系统自带字体(推荐)
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']  # 微软雅黑,Windows系统默认中文字体
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示为方块的问题# 方案2:指定具体字体路径(如果方案1无效,比如在Linux系统下)
# font_path = "C:/Windows/Fonts/simhei.ttf"  # 黑体路径示例
# font_prop = font_manager.FontProperties(fname=font_path)
# plt.rcParams['font.family'] = font_prop.get_name()# 方案3:使用SimHei字体(需确保已安装)
# plt.rcParams['font.sans-serif'] = ['SimHei']  # 适用于所有平台的通用中文字体解决方案# 加载数据文件
with open('./数据/result.json', 'r', encoding='utf-8') as f:data = json.load(f)  # 读取JSON格式的模型评估结果# 准备数据
models = ['LR', 'rf', 'xb', 'nb', 'adaboost']  # 定义要比较的模型列表:逻辑回归、随机森林、XGBoost、朴素贝叶斯、AdaBoost
fill_methods = list(data.keys())  # 获取所有数据填充方法的名称# 提取准确率数据
acc_data = {model: [data[method][model]['acc'] for method in fill_methods]for model in models  # 构建字典,存储每个模型在不同填充方法下的准确率
}# 创建画布
plt.figure(figsize=(12, 7), dpi=100)  # 设置图表大小和分辨率# 颜色和标记样式设置
colors = plt.cm.tab10(np.linspace(0, 1, len(models)))  # 使用tab10色系为不同模型分配颜色
markers = ['o', 's', '^', 'D', 'v']  # 定义不同的标记形状:圆形、方形、三角形、菱形、倒三角形# 为每个模型绘制点线图
for idx, model in enumerate(models):plt.plot(fill_methods, acc_data[model],marker=markers[idx],  # 设置标记形状color=colors[idx],  # 设置线条颜色linestyle='-',  # 实线连接linewidth=2.5,  # 线条粗细markersize=10,  # 标记大小markeredgewidth=2,  # 标记边缘宽度markeredgecolor='white',  # 标记边缘颜色label=model)  # 图例标签# 添加图表标题和坐标轴标签
plt.title('不同填充方法下模型准确率对比',fontsize=15, pad=20)  # 标题字体大小和与图表顶部的距离
plt.xlabel('填充方法', fontsize=13, labelpad=10)  # X轴标签
plt.ylabel('准确率', fontsize=13, labelpad=10)  # Y轴标签
plt.ylim(0.35, 0.75)  # 设置Y轴范围
plt.yticks(np.arange(0.35, 0.76, 0.05))  # 设置Y轴刻度间隔为0.05# 网格和边框设置
plt.grid(True, linestyle='--', alpha=0.4, which='both')  # 添加虚线网格
ax = plt.gca()  # 获取当前坐标轴
for spine in ax.spines.values():  # 隐藏所有边框spine.set_visible(False)# 添加数据标签
for model in models:for x, y in zip(fill_methods, acc_data[model]):plt.text(x, y-0.02, f'{y:.2f}',  # 在数据点下方显示两位小数的准确率ha='center',  # 水平居中va='top',  # 垂直对齐fontsize=10,  # 字体大小bbox=dict(boxstyle='round,pad=0.2',  # 添加圆角矩形背景fc='white',  # 背景颜色ec='none',  # 无边框alpha=0.7))  # 透明度# 图例设置
legend = plt.legend(loc='upper left',  # 图例位置bbox_to_anchor=(1, 1),  # 图例框位置frameon=True,  # 显示图例框title='模型类型',  # 图例标题title_fontsize=12)  # 标题字体大小
legend.get_frame().set_facecolor('#f5f5f5')  # 设置图例背景色# X轴标签旋转
plt.xticks(rotation=35, ha='right')  # 将X轴标签旋转35度,右对齐plt.tight_layout()  # 自动调整子图参数,使之填充整个图像区域# 保存图片(使用英文文件名避免编码问题)
plt.savefig('./model_accuracy_comparison.png',  # 输出文件名dpi=300,  # 图片分辨率bbox_inches='tight')  # 保存时包含所有元素
plt.show()  # 显示图表

http://www.dtcms.com/a/347819.html

相关文章:

  • DS18B20温度传感器详解
  • 电阻的功率
  • 多光谱相机检测石油石化行业的跑冒滴漏的可行性分析
  • 电蚊拍的原理及电压电容参数深度解析:从高频振荡到倍压整流的完整技术剖析
  • 决策树基础学习教育第二课:量化最优分裂——信息熵与基尼系数
  • 01_Python的in运算符判断列表等是否包含特定元素
  • [Vid-LLM] 创建和训练Vid-LLMs的各种方法体系
  • crypto.randomUUID is not a function
  • 一个备份、去除、新增k8s的node标签脚本
  • Redis(八股二弹)
  • 玳瑁的嵌入式日记D24-0823(数据结构)
  • 每日一题8.23
  • Day26 树的层序遍历 哈希表 排序算法 内核链表
  • 线程池理解
  • CMake安装教程
  • 传统 AI 与生成式 AI:IT 领导者指南
  • 10.Shell脚本修炼手册---脚本的条件测试与比较
  • 如何查看MySQL 的执行计划?
  • 引领GEO优化服务新潮流 赋能企业数字化转型
  • 信贷模型域——信贷获客模型(获客模型)
  • AI大模型 限时找我领取
  • Transformer核心概念I-token
  • Java:对象的浅拷贝与深拷贝
  • 获取高德地图经纬度解析地址的免费API接口(无调用限制)
  • JWT实现Token登录验证
  • 任务型Agent:执行计划详细设计
  • 计算机组成原理(11) 第二章 - 存储系统的基本概念
  • Introduction to GIS ——Chapter 1(Introduction)
  • 控制建模matlab练习15:线性状态反馈控制器-④最优化控制LQR
  • 动态内存详解