当前位置: 首页 > news >正文

机器学习模型训练模块技术文档

一、模块结构概览

import numpy as np
from sklearn.model_selection import cross_validate, learning_curve
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import make_scorer, accuracy_score, recall_score, f1_score
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
import os

依赖说明

  • numpy:处理数值计算

  • sklearn:提供机器学习算法和工具

  • matplotlib:可视化学习曲线

  • os:处理文件路径操作

二、核心类定义

2.1 类初始化

class ModelTrainer:def __init__(self):pass

功能:创建模型训练器的基础类,当前无需特殊初始化参数 

 

2.2 主训练方法 train_model

2.2.1 数据准备阶段
def train_model(self, X, y, output_dir="model_plots"):# 创建输出文件夹os.makedirs(output_dir, exist_ok=True)# 数据分割X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.2,        # 20%测试集stratify=y,           # 保持类别分布random_state=42       # 可重复性种子)# 数据标准化scaler = StandardScaler()X_train_scaled = scaler.fit_transform(X_train)  # 训练集拟合+转换X_test_scaled = scaler.transform(X_test)        # 测试集仅转换# 合并标准化数据X_scaled = np.concatenate([X_train_scaled, X_test_scaled])y = np.concatenate([y_train, y_test])
 

关键技术点

  • stratify=y 保证分割后的数据保持原始类别分布

  • 标准化处理防止特征尺度差异影响模型性能

  • 合并数据集用于交叉验证

2.2.2 模型配置
models = {"Random Forest": RandomForestClassifier(n_estimators=200,  # 增加树数量提升模型容量max_depth=8,        # 限制深度防止过拟合n_jobs=-1          # 使用全部CPU核心),"Linear SVM": SVC(kernel='rbf',       # 选择径向基函数核C=0.5,             # 正则化强度参数gamma='auto',      # 自动计算gamma参数probability=True   # 启用概率估计),"KNN": KNeighborsClassifier(n_neighbors=3,     # 使用3近邻n_jobs=-1          # 并行计算)
}scoring = {'accuracy': make_scorer(accuracy_score),'recall': make_scorer(recall_score, average='macro'),  # 多分类宏平均'f1': make_scorer(f1_score, average='macro')
}
 

参数调优说明

  • 随机森林:通过限制max_depth平衡偏差-方差

  • SVM:调整C值控制正则化强度

  • KNN:小邻域数适合高维度数据

2.2.3 交叉验证流程
best_score = -1
best_model_name = ""
best_model = Nonefor name, model in models.items():# 交叉验证cv_results = cross_validate(model, X_scaled, y, cv=3,              # 3折交叉验证scoring=scoring    # 使用自定义指标)# 指标计算acc = np.mean(cv_results['test_accuracy'])rec = np.mean(cv_results['test_recall'])f1 = np.mean(cv_results['test_f1'])# 模型比较if f1 > best_score:best_score = f1best_model_name = namebest_model = model# 生成学习曲线self.plot_learning_curve(model, X_scaled, y, name, output_dir)

评估策略

  • 使用3折交叉验证降低数据划分敏感性

  • 以F1宏平均作为模型选择标准

  • 同步输出各模型指标的标准差

2.3 学习曲线绘制 plot_learning_curve

2.3.1 数据计算

def plot_learning_curve(self, model, X, y, model_name, output_dir):train_sizes, train_scores, test_scores = learning_curve(model, X, y, cv=3,               # 3折交叉验证scoring='accuracy', # 使用准确率指标n_jobs=-1          # 并行计算)# 统计量计算train_mean = np.mean(train_scores, axis=1)train_std = np.std(train_scores, axis=1)test_mean = np.mean(test_scores, axis=1)test_std = np.std(test_scores, axis=1)
2.3.2 可视化实现
    plt.figure(figsize=(8, 6))plt.fill_between(train_sizes,train_mean - train_std,train_mean + train_std,alpha=0.1, color="r")plt.plot(train_sizes, train_mean, 'o-', color="r", label="Training score")# 测试集曲线同理...plt.title(f"Learning Curve ({model_name})")plt.xlabel("Training Examples")plt.ylabel("Accuracy Score")plt.legend(loc="best")# 保存图像output_path = os.path.join(output_dir, f"{model_name}_learning_curve.png")plt.savefig(output_path)plt.close()
 

可视化分析

  • 阴影区域表示±1标准差范围

  • 训练曲线(红色)与验证曲线(绿色)对比

  • 图像尺寸设为8x6英寸保证可读性

三、使用流程示例

# 示例数据
X, y = load_your_data()  # 需自定义数据加载方法# 初始化训练器
trainer = ModelTrainer()# 执行训练
best_model = trainer.train_model(X, y,output_dir="my_models"  # 指定输出目录
)# 使用最佳模型预测
predictions = best_model.predict(new_data)

四、输出文件结构


model_plots/
├── Random Forest_learning_curve.png
├── Linear SVM_learning_curve.png
└── KNN_learning_curve.png

图像展示模型的学习过程,帮助诊断欠/过拟合问题

 

相关文章:

  • XZ03_Overleaf使用教程
  • 名词解释DCDC
  • Wannier90文件与参数
  • Three.js + React 实战系列 - 项目展示区开发详解 Projects 组件(3D 模型 + 动效 + 状态切换)✨
  • DeepSeek技术发展详细时间轴与技术核心解析
  • 【KWDB 创作者计划】基于 ESP32 + KWDB 的智能环境监测系统实战
  • 人工智能浪潮中Python的核心作用与重要地位
  • DeepSeek成本控制的三重奏
  • 学习路线(工业自动化软件架构)
  • 【将你的IDAPython插件迁移到IDA 9.x:核心API变更与升级指南】
  • suna工具调用可视化界面实现原理分析(一)
  • 2025系统架构师---论面向对象的软件设计
  • S100平台调试RS485/RS232
  • JavaSE笔记--反射篇
  • 位运算-详细总结
  • 前端-Vue的项目流程
  • 【Unity】一个AssetBundle热更新的使用小例子
  • 2023年408真题及答案
  • transformer读后感
  • QT6 源(77):阅读与注释滚动条 QScrollBar 的源码,其是基类QAbstractSlider 的子类,
  • 娱见 | 为了撕番而脱粉,内娱粉丝为何如此在乎番位
  • 印尼巴厘岛多地停电,疑似海底电缆发生故障
  • 波兰斯基最新回忆录追述“二战”童年往事
  • 甘肃公布校园食品安全专项整治案例,有食堂涉腐败变质食物
  • 多地景区发公告称售票达接待峰值,有景区暂停网络和线下售票
  • 据报特斯拉寻找新CEO,马斯克财报会议上表态:把更多时间投入特斯拉