当前位置: 首页 > news >正文

训练+评估流程

训练+评估流程

  • 1、要求
  • 2、训练+评估(PyTorch + TensorBoard )
    • 完整代码(单文件示例)
    • 运行方法
    • 功能对应表
  • 3、pytorch自定义评估要继承哪个类?
  • 4、HF Trainer和SB3
  • 5、 汇总
    • 1. PyTorch Lightning + TensorBoard + ModelCheckpoint + EarlyStopping
      • 核心代码示例:
    • 2. TensorFlow / Keras + TensorBoard + ModelCheckpoint + EarlyStopping
    • 3. Stable Baselines3 (强化学习)
    • 4. Huggingface Trainer(NLP)
    • 5. 结合Weights & Biases(W\&B)
    • 总结推荐

1、要求

训练:
- 可视化
- 回调
- 自定义loss
- 自定义metric评估:
- 可视化
- 自定义数据:
- 自定义模型:
- 自定义

训练:训练可视化,自定义回调,评估指标
评估:评估可视化,自定义评估指标

数据:自定义数据
https://lightning.ai/docs/pytorch/stable/data/datamodule.html

模型:自定义模型

回调集合:参考sb3
稳定基线为您提供了一组常见的回调:
定期保存模型(CheckpointCallback)
定期评估模型并保存最佳模型(EvalCallback)
链接回调 (CallbackList)
触发事件回调(事件回调、EveryNTimesteps)
每 N 个时间步长记录数据 (LogEveryNTimesteps)
根据奖励阈值提前停止训练 (StopTrainingOnRewardThreshold)

2、训练+评估(PyTorch + TensorBoard )

官方api,https://docs.pytorch.ac.cn/docs/stable/tensorboard.html
https://blog.csdn.net/Python_Ai_Road/article/details/107704530
https://github.com/haoxizhong/pytorch-worker
https://github.com/onepanelio/onepanel

提问:

pytorch tensorboard实现下面要求:
训练:
- 可视化
- 回调
- 自定义metric评估:
- 可视化
- 自定义数据:
- 自定义模型:
- 自定义

好的,我帮你写一个PyTorch + TensorBoard 完整可运行的示例,
满足你所有要求:

  • 训练
    ✅ 可视化(loss/accuracy/自定义metric曲线)
    ✅ 回调(EarlyStopping)
    ✅ 自定义 metric(比如 F1-score)

  • 评估
    ✅ 可视化(混淆矩阵、曲线)
    ✅ 自定义指标(precision、recall、F1等)

  • 数据
    ✅ 自定义 Dataset / DataLoader

  • 模型
    ✅ 自定义 nn.Module


完整代码(单文件示例)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import time# ===================== 1. 自定义数据集 =====================
class MyDataset(Dataset):def __init__(self, size=500):np.random.seed(0)self.x = np.random.rand(size, 10).astype(np.float32)self.y = (self.x.sum(axis=1) > 5).astype(np.int64)  # 二分类def __len__(self):return len(self.x)def __getitem__(self, idx):return self.x[idx], self.y[idx]# ===================== 2. 自定义模型 =====================
class MyModel(nn.Module):def __init__(self, input_dim=10, hidden_dim=32, num_classes=2):super().__init__()self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, num_classes))def forward(self, x):return self.net(x)# ===================== 3. 回调(早停) =====================
class EarlyStopping:def __init__(self, patience=5, save_path="best_model.pth"):self.patience = patienceself.counter = 0self.best_loss = float("inf")self.save_path = save_pathself.early_stop = Falsedef __call__(self, val_loss, model):if val_loss < self.best_loss:self.best_loss = val_lossself.counter = 0torch.save(model.state_dict(), self.save_path)else:self.counter += 1if self.counter >= self.patience:self.early_stop = True# ===================== 4. 自定义 metric =====================
def compute_metrics(y_true, y_pred):precision = precision_score(y_true, y_pred, zero_division=0)recall = recall_score(y_true, y_pred, zero_division=0)f1 = f1_score(y_true, y_pred, zero_division=0)return precision, recall, f1# ===================== 5. 训练(可视化 + 回调 + 自定义metric) =====================
def train_model():train_set = MyDataset(500)val_set = MyDataset(100)train_loader = DataLoader(train_set, batch_size=32, shuffle=True)val_loader = DataLoader(val_set, batch_size=32)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = MyModel().to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)log_dir = f"runs/train_{time.strftime('%Y%m%d-%H%M%S')}"writer = SummaryWriter(log_dir=log_dir)early_stopping = EarlyStopping(patience=5)for epoch in range(50):# ===== 训练 =====model.train()total_loss, y_true_train, y_pred_train = 0, [], []for x, y in train_loader:x, y = x.to(device), y.to(device)optimizer.zero_grad()outputs = model(x)loss = criterion(outputs, y)loss.backward()optimizer.step()total_loss += loss.item()preds = outputs.argmax(1)y_true_train.extend(y.cpu().numpy())y_pred_train.extend(preds.cpu().numpy())train_loss = total_loss / len(train_loader)train_precision, train_recall, train_f1 = compute_metrics(y_true_train, y_pred_train)train_acc = np.mean(np.array(y_true_train) == np.array(y_pred_train))# ===== 验证 =====model.eval()val_loss, y_true_val, y_pred_val = 0, [], []with torch.no_grad():for x, y in val_loader:x, y = x.to(device), y.to(device)outputs = model(x)loss = criterion(outputs, y)val_loss += loss.item()preds = outputs.argmax(1)y_true_val.extend(y.cpu().numpy())y_pred_val.extend(preds.cpu().numpy())val_loss /= len(val_loader)val_precision, val_recall, val_f1 = compute_metrics(y_true_val, y_pred_val)val_acc = np.mean(np.array(y_true_val) == np.array(y_pred_val))# ===== 写入 TensorBoard =====writer.add_scalar("Loss/train", train_loss, epoch)writer.add_scalar("Loss/val", val_loss, epoch)writer.add_scalar("Accuracy/train", train_acc, epoch)writer.add_scalar("Accuracy/val", val_acc, epoch)writer.add_scalar("F1/train", train_f1, epoch)writer.add_scalar("F1/val", val_f1, epoch)print(f"Epoch {epoch+1} | "f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} F1: {train_f1:.4f} | "f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} F1: {val_f1:.4f}")# 回调early_stopping(val_loss, model)if early_stopping.early_stop:print("Early stopping!")breakwriter.close()print(f"训练日志保存在 {log_dir}")# ===================== 6. 评估(可视化 + 自定义metric) =====================
def evaluate_model():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = MyModel().to(device)model.load_state_dict(torch.load("best_model.pth"))model.eval()test_set = MyDataset(100)test_loader = DataLoader(test_set, batch_size=32)y_true, y_pred = [], []with torch.no_grad():for x, y in test_loader:x, y = x.to(device), y.to(device)outputs = model(x)preds = outputs.argmax(1)y_true.extend(y.cpu().numpy())y_pred.extend(preds.cpu().numpy())acc = np.mean(np.array(y_true) == np.array(y_pred))precision, recall, f1 = compute_metrics(y_true, y_pred)print(f"Test Accuracy: {acc:.4f}")print(f"Test Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}")# ===== 混淆矩阵可视化 =====cm = confusion_matrix(y_true, y_pred)fig, ax = plt.subplots()ax.matshow(cm, cmap=plt.cm.Blues)for i in range(cm.shape[0]):for j in range(cm.shape[1]):ax.text(j, i, cm[i, j], va='center', ha='center')plt.xlabel("Predicted")plt.ylabel("True")plt.title("Confusion Matrix")plt.show()if __name__ == "__main__":train_model()evaluate_model()

运行方法

# 安装依赖
pip install torch tensorboard scikit-learn matplotlib# 运行
python demo.py# 查看TensorBoard
tensorboard --logdir runs

功能对应表

模块功能实现方式
训练可视化 loss / acc / F1SummaryWriter.add_scalar()
回调 EarlyStoppingEarlyStopping
自定义 metric(precision/recall/F1)compute_metrics()
评估可视化混淆矩阵matplotlib + confusion_matrix
自定义 metric(precision/recall/F1)compute_metrics()
数据自定义 DatasetMyDataset
模型自定义 nn.ModuleMyModel

我建议你如果后面项目要长期用,可以帮你改成模块化结构(dataset.py、model.py、train.py、evaluate.py 分开),
这样扩展性更好。

你要我下一步帮你拆成模块化版本吗?这样更适合大型项目维护。

3、pytorch自定义评估要继承哪个类?

pytorch自定义评估要继承哪个类

PyTorch 里,如果你是纯用 torch(而不是 PyTorch Lightning、Hugging Face Trainer 等封装框架),自定义评估其实不用继承任何固定的类,因为 PyTorch 本身没有“评估器”这一内置抽象,评估通常就是自己写一个函数:

def evaluate(model, dataloader, criterion):model.eval()total_loss = 0.0correct = 0total = 0with torch.no_grad():for inputs, targets in dataloader:outputs = model(inputs)loss = criterion(outputs, targets)total_loss += loss.item()# 分类任务示例_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()avg_loss = total_loss / len(dataloader)accuracy = 100. * correct / totalreturn avg_loss, accuracy

但是,如果你是在 某个训练框架 里做自定义评估,就会有对应的基类:

框架自定义评估类要继承的基类
PyTorch Lightningpl.Callbackpl.LightningModule(重写 validation_step
Hugging Face TransformersTrainerCallback 或直接在 Trainer 里传 compute_metrics
TorchMetricstorchmetrics.Metric
Catalystcatalyst.metrics.MetricCallback
Igniteignite.metrics.Metric

所以我需要确认一下,你的“自定义评估”是:

  • 纯 PyTorch 写的训练循环
  • 还是基于 Lightning / Transformers / 其他训练框架

我猜你可能是想在 PPO 或 RL 里加自定义评估指标,如果是那种情况,类继承方式又不一样。
你要我帮你按 RL + PPO 场景写一个自定义评估类吗?这样就直接能接到你现有环境里。

4、HF Trainer和SB3

详细参考下面笔记:
https://blog.csdn.net/weixin_44986037/article/details/151153156?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/151225587?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/150427299?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/148544185?spm=1001.2014.3001.5501

建议用针对性框架:
Stable Baselines3 : 强化学习
Huggingface Trainer :NLP及Transformer模型
PyTorch Lightning:不追求训练速度,建议用tensorboard,PyTorch Lightning坑有点多

5、 汇总

提问:现成的ai模型训练评估框架,包括tersorboard查看训练数据、筛选模型、评估效果

提问:

简单概述主流现成的AI模型训练+评估框架,包含:训练过程中的数据监控(比如通过 TensorBoard 可视化损失、指标)模型筛选(如自动保存最优模型、早停)评估效果(自动计算指标,方便比较模型)

你想要的是一个现成的AI模型训练+评估框架,包含:

  • 训练过程中的数据监控(比如通过 TensorBoard 可视化损失、指标)
  • 模型筛选(如自动保存最优模型、早停)
  • 评估效果(自动计算指标,方便比较模型)

下面给你推荐几个主流且功能完整的框架/工具组合,覆盖你需求的这几个点:


1. PyTorch Lightning + TensorBoard + ModelCheckpoint + EarlyStopping

  • 训练监控:Lightning内置对TensorBoard的支持,训练时自动记录loss、metric到TensorBoard。
  • 模型筛选:内置 ModelCheckpoint,可自动保存指标最优模型(如val_loss最小、准确率最高)。
  • 早停:内置 EarlyStopping,当验证指标不再提升时自动停止训练。
  • 评估效果:支持训练、验证、测试阶段指标计算和记录。

核心代码示例:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStoppingcheckpoint_callback = ModelCheckpoint(monitor='val_loss',    # 监控验证损失mode='min',           # 取最小值对应模型save_top_k=1,verbose=True,
)early_stop_callback = EarlyStopping(monitor='val_loss',patience=5,verbose=True,mode='min'
)trainer = pl.Trainer(max_epochs=50,callbacks=[checkpoint_callback, early_stop_callback],logger=pl.loggers.TensorBoardLogger('tb_logs/')
)trainer.fit(model, train_dataloader, val_dataloader)

启动后,在终端运行 tensorboard --logdir=tb_logs/ 即可实时查看训练过程。


2. TensorFlow / Keras + TensorBoard + ModelCheckpoint + EarlyStopping

  • TensorFlow内置支持TensorBoard,训练时会自动写入日志。
  • ModelCheckpoint可保存验证指标最优模型。
  • EarlyStopping自动停止训练避免过拟合。

示例:

callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./logs'),tf.keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True),tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
]model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=50, callbacks=callbacks)

3. Stable Baselines3 (强化学习)

  • 自带 EvalCallback 用于训练过程中自动评估模型表现。
  • CheckpointCallback 用于保存模型。
  • 训练日志可通过 TensorBoard 监控。

示例:

from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallbackeval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',log_path='./logs/', eval_freq=10000, deterministic=True)
checkpoint_callback = CheckpointCallback(save_freq=5000, save_path='./logs/')model.learn(total_timesteps=100000, callback=[eval_callback, checkpoint_callback])

运行 tensorboard --logdir=./logs 可查看指标变化。


4. Huggingface Trainer(NLP)

  • 内置日志记录(支持TensorBoard、WandB等)
  • 自动保存最优模型和早停(EarlyStoppingCallback)
  • 评估指标自动计算和记录

示例:

from transformers import Trainer, TrainingArguments, EarlyStoppingCallbacktraining_args = TrainingArguments(output_dir='./results',evaluation_strategy='epoch',save_strategy='epoch',logging_dir='./logs',logging_steps=10,load_best_model_at_end=True,
)trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)trainer.train()

然后使用TensorBoard查看 ./logs 文件夹。


5. 结合Weights & Biases(W&B)

(即WandB)

  • 支持各种框架(PyTorch、TensorFlow、Lightning、SB3、HF等)
  • 自动记录训练日志、模型、超参
  • 支持在线实时可视化和对比多个实验
  • 支持模型版本管理和筛选

官网:https://wandb.ai/site


总结推荐

框架/工具训练监控(TensorBoard)自动保存最优模型早停评估指标简单易用适用范围
PyTorch LightningPyTorch通用
TensorFlow / KerasTensorFlow/Keras
Stable Baselines3部分强化学习
Huggingface TrainerNLP及Transformer模型
Weights & Biases ((WandB))各类深度学习框架

如果你告诉我你用的是哪种框架,我可以帮你写具体的训练+评估+TensorBoard配置示例,或者帮你搭建完整的训练流水线。你看?


文章转载自:

http://hVnf9QWH.Lkjzz.cn
http://E2r0LWVC.Lkjzz.cn
http://h1aGeyiN.Lkjzz.cn
http://QVD9ztxL.Lkjzz.cn
http://p3i7oRRQ.Lkjzz.cn
http://E5n9tN6x.Lkjzz.cn
http://Ga13Zfq8.Lkjzz.cn
http://VWWFyvCb.Lkjzz.cn
http://QENu6Wqj.Lkjzz.cn
http://O4iAuIUU.Lkjzz.cn
http://2dFaQQO9.Lkjzz.cn
http://cTKQy3bI.Lkjzz.cn
http://2NDEe1T7.Lkjzz.cn
http://flLTHg39.Lkjzz.cn
http://04RoCz0l.Lkjzz.cn
http://ZFBx0nyV.Lkjzz.cn
http://9qGW3zOq.Lkjzz.cn
http://FJ7yqwyM.Lkjzz.cn
http://JlP1Yeok.Lkjzz.cn
http://mebA9ORM.Lkjzz.cn
http://IQuUZjJk.Lkjzz.cn
http://hpVT1aE5.Lkjzz.cn
http://clHEIEPD.Lkjzz.cn
http://BDaz4jjr.Lkjzz.cn
http://QX2IaxDd.Lkjzz.cn
http://jpcLQQp3.Lkjzz.cn
http://esW2DU3O.Lkjzz.cn
http://Jmx3MLh9.Lkjzz.cn
http://yQgMFxUP.Lkjzz.cn
http://nIJbGkeT.Lkjzz.cn
http://www.dtcms.com/a/372234.html

相关文章:

  • 【数学建模】烟幕干扰弹投放策略优化:模型与算法整合框架
  • PHP云课堂在线网课系统 多功能网校系统 在线教育系统源码
  • redis的高可用(哨兵)
  • Redis之分布式锁与缓存设计
  • pip常用指令小结
  • Python中进行时区转换和处理
  • CTFshow系列——PHP特性Web97-100
  • Python快速入门专业版(九):字符串进阶:常用方法(查找、替换、分割、大小写转换)
  • MySQL 8.0+ 内核剖析:架构、事务与数据管理
  • 11.2.1.项目整体架构和技术选型及部署
  • [C++刷怪笼]:set/map--优质且易操作的容器
  • zotero扩容
  • 20250907_梳理异地备份每日自动巡检py脚本逻辑流程+安装Python+PyCharm+配置自动运行
  • UserManagement.vue和Profile.vue详细解释
  • Python进阶编程:文件操作、系统命令与函数设计完全指南
  • 【redis 基础】redis 的常用数据结构及其核心操作
  • 美团大模型“龙猫”登场,能否重塑本地生活新战局?
  • nats消息队列处理
  • k8s镜像推送到阿里云,使用ctr推送镜像到阿里云
  • Ubuntu Qt x64平台搭建 arm64 编译套件
  • IO性能篇(一):文件系统是怎么工作的
  • SQL Server——基本操作
  • nginx详解
  • 硬件开发1-51单片机4-DS18B20
  • 【LLIE专题】LYT-Net:一种轻量级 YUV Transformer 低光图像增强网络
  • 数据库造神计划第二天---数据库基础操作
  • TypeORM 入门教程之 `@OneToOne` 关系详解
  • 嵌入式解谜日志之数据结构—基本概念
  • make_shared的使用
  • 《九江棒球》未来十年棒垒球发展规划·棒球1号位