当前位置: 首页 > wzjs >正文

凡客建站手机版下载色盲测试

凡客建站手机版下载,色盲测试,长沙做模板网站,自己做网站想更换网址涵盖断点续训、早停机制、定期保存检查点等 检查点保存逻辑 检查点保存分为两种: 最新检查点:每次训练都会保存为 latest.pth,用于恢复训练。 最佳模型:仅在验证损失达到新低时保存为 model_best.pth。 定期保存(…

涵盖断点续训、早停机制、定期保存检查点等

检查点保存逻辑

检查点保存分为两种:

最新检查点:每次训练都会保存为 latest.pth,用于恢复训练。

最佳模型:仅在验证损失达到新低时保存为 model_best.pth。

定期保存(每 checkpoint_interval 轮)也确保了即使训练中断,也能恢复到最近的状态。

2. 早停机制

当验证损失连续 early_stop_patience 轮未改善时,触发早停,避免过拟合或浪费计算资源。

这是一个非常实用的功能,特别是在超参数调试阶段。

3. 命令行参数支持

使用 argparse 支持通过命令行指定恢复训练的检查点路径,提升了脚本的灵活性。

4. CUDA基准模式

启用 torch.backends.cudnn.benchmark = True 可以加速卷积操作,尤其是在输入尺寸固定的情况下。

"""
一个黑客创业者:年龄预测模型完整训练(支持CPU/GPU、断点续训、早停机制)
执行方式:
1. 训练:python train_age.py
2. 恢复:python train_age.py --resume ./checkpoints/latest.pth
"""import os
import time
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm# 配置参数
CONFIG = {# 数据路径"train_age_list": r"D:\daku\性别\megaage_asian\list\train_age.txt","val_age_list": r"D:\daku\性别\megaage_asian\list\test_age.txt","train_image_dir": r"D:\daku\性别\megaage_asian\train","val_image_dir": r"D:\daku\性别\megaage_asian\val",# 训练参数"batch_size": 64,"num_workers": 4 if torch.cuda.is_available() else 2,"learning_rate": 3e-4,"num_epochs": 100,"input_size": 224,# 系统参数"checkpoint_dir": "./checkpoints","checkpoint_interval": 1,"early_stop_patience": 7,"use_amp": torch.cuda.is_available(),  # 自动判断是否启用混合精度"resume": None
}class AgeDataset(Dataset):"""处理序号命名图片和年龄列表的数据集"""def __init__(self, age_list_path, image_dir, transform=None):self.image_dir = image_dirself.transform = transform# 加载年龄数据with open(age_list_path, 'r') as f:self.ages = []line_count = 0for line in f:line_count += 1line = line.strip()try:age = float(line)if 0 <= age <= 120:self.ages.append(age)else:print(f"行 {line_count}: 异常年龄值 {age},已过滤")except ValueError:print(f"行 {line_count}: 无效年龄值 '{line}',已跳过")# 加载并排序图片文件self.image_files = sorted([f for f in os.listdir(image_dir)if f.lower().endswith(('.jpg', '.jpeg', '.png'))],key=lambda x: int(os.path.splitext(x)[0]))# 对齐数据长度self.num_samples = min(len(self.ages), len(self.image_files))if len(self.ages) != len(self.image_files):print(f"警告: 年龄数({len(self.ages)})与图片数({len(self.image_files)})不一致,使用前{self.num_samples}个样本")def __len__(self):return self.num_samplesdef __getitem__(self, idx):# 生成图片路径img_name = self.image_files[idx]img_path = os.path.join(self.image_dir, img_name)# 加载图片try:with Image.open(img_path) as img:image = img.convert('RGB')except Exception as e:print(f"图片加载失败: {img_path},错误: {str(e)}")return self[(idx + 1) % len(self)]  # 跳过错误样本# 获取年龄age = torch.tensor(self.ages[idx], dtype=torch.float32)if self.transform:image = self.transform(image)return image, agedef create_data_loaders():"""创建数据加载器"""train_transform = transforms.Compose([transforms.Resize(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 创建数据集train_set = AgeDataset(CONFIG["train_age_list"], CONFIG["train_image_dir"], train_transform)val_set = AgeDataset(CONFIG["val_age_list"], CONFIG["val_image_dir"], val_transform)print(f"\n数据集统计:")print(f"训练样本: {len(train_set)} | 验证样本: {len(val_set)}")# 创建数据加载器train_loader = DataLoader(train_set,batch_size=CONFIG["batch_size"],shuffle=True,num_workers=CONFIG["num_workers"],pin_memory=torch.cuda.is_available(),persistent_workers=torch.cuda.is_available())val_loader = DataLoader(val_set,batch_size=CONFIG["batch_size"],shuffle=False,num_workers=CONFIG["num_workers"],pin_memory=torch.cuda.is_available())return train_loader, val_loaderclass AgeRegressor(nn.Module):"""年龄回归模型"""def __init__(self, pretrained=True):super().__init__()base_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT if pretrained else None)self.feature_extractor = nn.Sequential(*list(base_model.children())[:-1])self.regressor = nn.Sequential(nn.Linear(base_model.fc.in_features, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 1))def forward(self, x):features = self.feature_extractor(x).flatten(1)return self.regressor(features).squeeze(1)def initialize_training(resume_path=None):"""初始化训练环境"""device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"\n训练设备: {device}")# 初始化模型model = AgeRegressor().to(device)optimizer = optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"], weight_decay=1e-4)criterion = nn.HuberLoss()# 自动处理AMPscaler = torch.cuda.amp.GradScaler(enabled=CONFIG["use_amp"]) if torch.cuda.is_available() else Nonescheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3)# 训练状态start_epoch = 0best_loss = float('inf')no_improve = 0# 断点续训if resume_path and os.path.exists(resume_path):checkpoint = torch.load(resume_path, map_location=device)model.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])scheduler.load_state_dict(checkpoint['scheduler'])start_epoch = checkpoint['epoch'] + 1best_loss = checkpoint['best_loss']no_improve = checkpoint['no_improve']if scaler and 'scaler' in checkpoint:scaler.load_state_dict(checkpoint['scaler'])print(f"成功恢复训练状态,从第 {start_epoch} 轮开始")return {"device": device,"model": model,"optimizer": optimizer,"criterion": criterion,"scaler": scaler,"scheduler": scheduler,"start_epoch": start_epoch,"best_loss": best_loss,"no_improve": no_improve}def train_epoch(model, device, train_loader, optimizer, criterion, scaler):"""训练单个epoch"""model.train()total_loss = 0.0with tqdm(train_loader, desc="训练", unit="batch") as pbar:for images, labels in pbar:images = images.to(device, non_blocking=True)labels = labels.to(device, non_blocking=True)optimizer.zero_grad(set_to_none=True)# 混合精度训练with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu',enabled=CONFIG["use_amp"]):outputs = model(images)loss = criterion(outputs, labels)# 反向传播if scaler:scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()else:loss.backward()optimizer.step()total_loss += loss.item() * images.size(0)pbar.set_postfix(loss=loss.item())return total_loss / len(train_loader.dataset)def validate(model, device, val_loader, criterion):"""验证循环"""model.eval()total_loss = 0.0with torch.no_grad(), tqdm(val_loader, desc="验证", unit="batch") as pbar:for images, labels in pbar:images = images.to(device, non_blocking=True)labels = labels.to(device, non_blocking=True)outputs = model(images)loss = criterion(outputs, labels)total_loss += loss.item() * images.size(0)pbar.set_postfix(loss=loss.item())return total_loss / len(val_loader.dataset)def save_checkpoint(state, filename, is_best=False):"""保存检查点"""os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)filepath = os.path.join(CONFIG["checkpoint_dir"], filename)# 保存完整状态torch.save(state, filepath)# 保存最佳模型if is_best:best_path = os.path.join(CONFIG["checkpoint_dir"], "model_best.pth")torch.save(state["model"], best_path)def main():# 初始化train_loader, val_loader = create_data_loaders()training_env = initialize_training(CONFIG["resume"])# 解包训练环境device = training_env["device"]model = training_env["model"]optimizer = training_env["optimizer"]criterion = training_env["criterion"]scaler = training_env["scaler"]scheduler = training_env["scheduler"]start_epoch = training_env["start_epoch"]best_loss = training_env["best_loss"]no_improve = training_env["no_improve"]# 训练循环for epoch in range(start_epoch, CONFIG["num_epochs"]):print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")start_time = time.time()# 训练与验证train_loss = train_epoch(model, device, train_loader, optimizer, criterion, scaler)val_loss = validate(model, device, val_loader, criterion)scheduler.step(val_loss)# 统计信息epoch_time = time.time() - start_timelr = optimizer.param_groups[0]['lr']print(f"耗时: {epoch_time // 60:.0f}m{epoch_time % 60:.0f}s | LR: {lr:.1e} | "f"训练损失: {train_loss:.4f} | 验证损失: {val_loss:.4f}")# 保存检查点is_best = val_loss < best_lossif is_best:best_loss = val_lossno_improve = 0else:no_improve += 1checkpoint = {'epoch': epoch,'model': model.state_dict(),'optimizer': optimizer.state_dict(),'scheduler': scheduler.state_dict(),'scaler': scaler.state_dict() if scaler else None,'best_loss': best_loss,'no_improve': no_improve,'config': CONFIG}# 定期保存if is_best or (epoch + 1) % CONFIG["checkpoint_interval"] == 0:save_checkpoint(checkpoint, f"epoch_{epoch + 1}.pth", is_best)# 保存最新检查点save_checkpoint(checkpoint, "latest.pth")# 早停机制if no_improve >= CONFIG["early_stop_patience"]:print(f"\n早停触发: 验证损失连续 {CONFIG['early_stop_patience']} 轮未提升")breakif __name__ == "__main__":# 命令行参数parser = argparse.ArgumentParser()parser.add_argument('--resume', help='恢复训练的检查点路径')args = parser.parse_args()if args.resume:CONFIG["resume"] = args.resume# 设置CUDA基准模式if torch.cuda.is_available():torch.backends.cudnn.benchmark = True# 启动训练main()

http://www.dtcms.com/wzjs/52258.html

相关文章:

  • 自己做的网站出现左右滑动条网络营销优化推广公司
  • 武安网站设计公司无锡百度快速优化排名
  • 大兴区住房和城乡建设部网站南通seo网站优化软件
  • 做网站需要什么框架bt兔子磁力搜索
  • 网站添加新关键词营销策划方案模板
  • 免费空间 网站搜狗seo查询
  • 毕业设计可以做哪些网站除了小红书还有什么推广平台
  • 期货做程序化回测的网站子域名网址查询
  • 石家庄市网站建设培训班网络营销的模式有哪些
  • 江油市规划和建设局网站商务网站建设
  • 昆明做网站网站seo排名培训
  • 物联网系统设计方案2021百度新算法优化
  • 使用bootstrap做网站的视频网页制作软件有哪些
  • 中国建设积分商城网站海外广告投放渠道
  • 嘉兴网站开发seo优缺点
  • 成都单位网站设计精准营销系统价值
  • 网站建设所需服务器费用目前推广平台都有哪些
  • 最专业的车网站建设时事政治2023最新热点事件
  • 高端网站建设怎么报名app注册推广任务平台
  • 无锡网站建设外贸太原seo排名公司
  • 做电子书网站网站关键词优化价格
  • 温州红酒网站建设竞价关键词优化软件
  • wordpress 网页飘窗怎么做网站优化排名
  • 劳务公司网站怎么做搜索关键词软件
  • 网站建设近义词2024年2月疫情又开始了吗
  • 网站建设收费价目表seo优化的主要任务包括
  • 做网站卖东西赚钱在线seo超级外链工具
  • 白银网站建设熊掌号现在感染症状有哪些
  • 网站快速排名技巧网站关键词公司
  • 企业建站划算吗seo网站排名优化公司哪家