当前位置: 首页 > wzjs >正文

浙江网站建设服务人工智能培训机构排名

浙江网站建设服务,人工智能培训机构排名,WordPress完整虚拟资源,建设网站需要申请报告涵盖断点续训、早停机制、定期保存检查点等 检查点保存逻辑 检查点保存分为两种: 最新检查点:每次训练都会保存为 latest.pth,用于恢复训练。 最佳模型:仅在验证损失达到新低时保存为 model_best.pth。 定期保存(…

涵盖断点续训、早停机制、定期保存检查点等

检查点保存逻辑

检查点保存分为两种:

最新检查点:每次训练都会保存为 latest.pth,用于恢复训练。

最佳模型:仅在验证损失达到新低时保存为 model_best.pth。

定期保存(每 checkpoint_interval 轮)也确保了即使训练中断,也能恢复到最近的状态。

2. 早停机制

当验证损失连续 early_stop_patience 轮未改善时,触发早停,避免过拟合或浪费计算资源。

这是一个非常实用的功能,特别是在超参数调试阶段。

3. 命令行参数支持

使用 argparse 支持通过命令行指定恢复训练的检查点路径,提升了脚本的灵活性。

4. CUDA基准模式

启用 torch.backends.cudnn.benchmark = True 可以加速卷积操作,尤其是在输入尺寸固定的情况下。

"""
一个黑客创业者:年龄预测模型完整训练(支持CPU/GPU、断点续训、早停机制)
执行方式:
1. 训练:python train_age.py
2. 恢复:python train_age.py --resume ./checkpoints/latest.pth
"""import os
import time
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm# 配置参数
CONFIG = {# 数据路径"train_age_list": r"D:\daku\性别\megaage_asian\list\train_age.txt","val_age_list": r"D:\daku\性别\megaage_asian\list\test_age.txt","train_image_dir": r"D:\daku\性别\megaage_asian\train","val_image_dir": r"D:\daku\性别\megaage_asian\val",# 训练参数"batch_size": 64,"num_workers": 4 if torch.cuda.is_available() else 2,"learning_rate": 3e-4,"num_epochs": 100,"input_size": 224,# 系统参数"checkpoint_dir": "./checkpoints","checkpoint_interval": 1,"early_stop_patience": 7,"use_amp": torch.cuda.is_available(),  # 自动判断是否启用混合精度"resume": None
}class AgeDataset(Dataset):"""处理序号命名图片和年龄列表的数据集"""def __init__(self, age_list_path, image_dir, transform=None):self.image_dir = image_dirself.transform = transform# 加载年龄数据with open(age_list_path, 'r') as f:self.ages = []line_count = 0for line in f:line_count += 1line = line.strip()try:age = float(line)if 0 <= age <= 120:self.ages.append(age)else:print(f"行 {line_count}: 异常年龄值 {age},已过滤")except ValueError:print(f"行 {line_count}: 无效年龄值 '{line}',已跳过")# 加载并排序图片文件self.image_files = sorted([f for f in os.listdir(image_dir)if f.lower().endswith(('.jpg', '.jpeg', '.png'))],key=lambda x: int(os.path.splitext(x)[0]))# 对齐数据长度self.num_samples = min(len(self.ages), len(self.image_files))if len(self.ages) != len(self.image_files):print(f"警告: 年龄数({len(self.ages)})与图片数({len(self.image_files)})不一致,使用前{self.num_samples}个样本")def __len__(self):return self.num_samplesdef __getitem__(self, idx):# 生成图片路径img_name = self.image_files[idx]img_path = os.path.join(self.image_dir, img_name)# 加载图片try:with Image.open(img_path) as img:image = img.convert('RGB')except Exception as e:print(f"图片加载失败: {img_path},错误: {str(e)}")return self[(idx + 1) % len(self)]  # 跳过错误样本# 获取年龄age = torch.tensor(self.ages[idx], dtype=torch.float32)if self.transform:image = self.transform(image)return image, agedef create_data_loaders():"""创建数据加载器"""train_transform = transforms.Compose([transforms.Resize(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 创建数据集train_set = AgeDataset(CONFIG["train_age_list"], CONFIG["train_image_dir"], train_transform)val_set = AgeDataset(CONFIG["val_age_list"], CONFIG["val_image_dir"], val_transform)print(f"\n数据集统计:")print(f"训练样本: {len(train_set)} | 验证样本: {len(val_set)}")# 创建数据加载器train_loader = DataLoader(train_set,batch_size=CONFIG["batch_size"],shuffle=True,num_workers=CONFIG["num_workers"],pin_memory=torch.cuda.is_available(),persistent_workers=torch.cuda.is_available())val_loader = DataLoader(val_set,batch_size=CONFIG["batch_size"],shuffle=False,num_workers=CONFIG["num_workers"],pin_memory=torch.cuda.is_available())return train_loader, val_loaderclass AgeRegressor(nn.Module):"""年龄回归模型"""def __init__(self, pretrained=True):super().__init__()base_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT if pretrained else None)self.feature_extractor = nn.Sequential(*list(base_model.children())[:-1])self.regressor = nn.Sequential(nn.Linear(base_model.fc.in_features, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 1))def forward(self, x):features = self.feature_extractor(x).flatten(1)return self.regressor(features).squeeze(1)def initialize_training(resume_path=None):"""初始化训练环境"""device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"\n训练设备: {device}")# 初始化模型model = AgeRegressor().to(device)optimizer = optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"], weight_decay=1e-4)criterion = nn.HuberLoss()# 自动处理AMPscaler = torch.cuda.amp.GradScaler(enabled=CONFIG["use_amp"]) if torch.cuda.is_available() else Nonescheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3)# 训练状态start_epoch = 0best_loss = float('inf')no_improve = 0# 断点续训if resume_path and os.path.exists(resume_path):checkpoint = torch.load(resume_path, map_location=device)model.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])scheduler.load_state_dict(checkpoint['scheduler'])start_epoch = checkpoint['epoch'] + 1best_loss = checkpoint['best_loss']no_improve = checkpoint['no_improve']if scaler and 'scaler' in checkpoint:scaler.load_state_dict(checkpoint['scaler'])print(f"成功恢复训练状态,从第 {start_epoch} 轮开始")return {"device": device,"model": model,"optimizer": optimizer,"criterion": criterion,"scaler": scaler,"scheduler": scheduler,"start_epoch": start_epoch,"best_loss": best_loss,"no_improve": no_improve}def train_epoch(model, device, train_loader, optimizer, criterion, scaler):"""训练单个epoch"""model.train()total_loss = 0.0with tqdm(train_loader, desc="训练", unit="batch") as pbar:for images, labels in pbar:images = images.to(device, non_blocking=True)labels = labels.to(device, non_blocking=True)optimizer.zero_grad(set_to_none=True)# 混合精度训练with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu',enabled=CONFIG["use_amp"]):outputs = model(images)loss = criterion(outputs, labels)# 反向传播if scaler:scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()else:loss.backward()optimizer.step()total_loss += loss.item() * images.size(0)pbar.set_postfix(loss=loss.item())return total_loss / len(train_loader.dataset)def validate(model, device, val_loader, criterion):"""验证循环"""model.eval()total_loss = 0.0with torch.no_grad(), tqdm(val_loader, desc="验证", unit="batch") as pbar:for images, labels in pbar:images = images.to(device, non_blocking=True)labels = labels.to(device, non_blocking=True)outputs = model(images)loss = criterion(outputs, labels)total_loss += loss.item() * images.size(0)pbar.set_postfix(loss=loss.item())return total_loss / len(val_loader.dataset)def save_checkpoint(state, filename, is_best=False):"""保存检查点"""os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)filepath = os.path.join(CONFIG["checkpoint_dir"], filename)# 保存完整状态torch.save(state, filepath)# 保存最佳模型if is_best:best_path = os.path.join(CONFIG["checkpoint_dir"], "model_best.pth")torch.save(state["model"], best_path)def main():# 初始化train_loader, val_loader = create_data_loaders()training_env = initialize_training(CONFIG["resume"])# 解包训练环境device = training_env["device"]model = training_env["model"]optimizer = training_env["optimizer"]criterion = training_env["criterion"]scaler = training_env["scaler"]scheduler = training_env["scheduler"]start_epoch = training_env["start_epoch"]best_loss = training_env["best_loss"]no_improve = training_env["no_improve"]# 训练循环for epoch in range(start_epoch, CONFIG["num_epochs"]):print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")start_time = time.time()# 训练与验证train_loss = train_epoch(model, device, train_loader, optimizer, criterion, scaler)val_loss = validate(model, device, val_loader, criterion)scheduler.step(val_loss)# 统计信息epoch_time = time.time() - start_timelr = optimizer.param_groups[0]['lr']print(f"耗时: {epoch_time // 60:.0f}m{epoch_time % 60:.0f}s | LR: {lr:.1e} | "f"训练损失: {train_loss:.4f} | 验证损失: {val_loss:.4f}")# 保存检查点is_best = val_loss < best_lossif is_best:best_loss = val_lossno_improve = 0else:no_improve += 1checkpoint = {'epoch': epoch,'model': model.state_dict(),'optimizer': optimizer.state_dict(),'scheduler': scheduler.state_dict(),'scaler': scaler.state_dict() if scaler else None,'best_loss': best_loss,'no_improve': no_improve,'config': CONFIG}# 定期保存if is_best or (epoch + 1) % CONFIG["checkpoint_interval"] == 0:save_checkpoint(checkpoint, f"epoch_{epoch + 1}.pth", is_best)# 保存最新检查点save_checkpoint(checkpoint, "latest.pth")# 早停机制if no_improve >= CONFIG["early_stop_patience"]:print(f"\n早停触发: 验证损失连续 {CONFIG['early_stop_patience']} 轮未提升")breakif __name__ == "__main__":# 命令行参数parser = argparse.ArgumentParser()parser.add_argument('--resume', help='恢复训练的检查点路径')args = parser.parse_args()if args.resume:CONFIG["resume"] = args.resume# 设置CUDA基准模式if torch.cuda.is_available():torch.backends.cudnn.benchmark = True# 启动训练main()

http://www.dtcms.com/wzjs/482608.html

相关文章:

  • 广州黄埔做网站公司哪家好交换链接营销实现方式解读
  • 北京文化传媒有限公司网站建设考证培训机构报名网站
  • 苹果软件下载网站付费内容网站
  • 哪个网站推广做的好地推app
  • 外贸网站做的作用是什么小说榜单首页百度搜索风云榜
  • 网站首页确认书太原互联网推广公司
  • 做企业平台的网站有哪些方面使用网站模板快速建站
  • 网站建设社区三门峡网站seo
  • 微信公众好第三方网站怎么做石家庄新闻网头条新闻
  • 网站网页翻页设计营销广告
  • 网站的建设框架网站的seo优化报告
  • 渐江建工水利水电建设有限公司网站最新的即时比分
  • 哦咪咖网站建设网络营销技巧
  • 一家公司做两个网站吗长春seo网站管理
  • 做宣传 为什么要做网站那网站可以自己做吗
  • 电子商务网站开发策划案快速排名推荐
  • 做网站广告多少钱宁德seo公司
  • 做网站源代码需要买吗微信seo什么意思
  • 做网站公示游戏优化大师官方下载
  • 网站开发 卓优科技产品推广软文500字
  • 北京建网站公司企业文化建设方案
  • 做网站服务器哪种好网络推广外包怎么接单
  • 海兴网站建设搜索引擎优化的报告
  • iis7.5 部署网站长春网站制作计划
  • 网站建设详细设计关键词搜索量查询工具
  • 企业网站建设的要求2021谷歌搜索入口
  • 如何添加网站 ico图标茂名seo顾问服务
  • 网页设计作业成品代码和文字深圳百度网站排名优化
  • 企业网站备案代理公司太原做网站哪家好
  • 梧州网站推广外包服务电商网站排名