当前位置: 首页 > news >正文

Day51 复习日-模型改进

Day 43 中在自己找的数据集上用简单 CNN 进行了训练；本节改用预训练模型，并尝试加入注意力机制（SE / CBAM）等改进。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np# 设置中文字体支持
# Plot styling: SimHei so the Chinese axis labels/titles render, and disable
# the unicode minus glyph so negative tick labels still draw with that font.
plt.rcParams.update({"font.family": ["SimHei"], "axes.unicode_minus": False})

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"使用设备: {device}")
# 1. Data preprocessing
def calculate_mean_std(dataloader):
    """Return the per-channel mean and standard deviation over every pixel in `dataloader`.

    Intended to be run once, offline, on an un-augmented loader to obtain the
    constants for ``transforms.Normalize``. Example::

        temp_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
        temp_dataset = datasets.ImageFolder(root=your_data_root, transform=temp_transform)
        temp_loader = DataLoader(temp_dataset, batch_size=32, shuffle=False)
        mean, std = calculate_mean_std(temp_loader)
        print(f"数据集均值:{mean},标准差:{std}")

    The std is the population standard deviation of all pixels of a channel
    (computed as sqrt(E[x^2] - E[x]^2)). Averaging per-image stds instead would
    ignore the variance *between* images and underestimate the dataset std.

    Args:
        dataloader: iterable of ``(images, labels)`` pairs where ``images`` is a
            ``(B, C, H, W)`` tensor; labels are ignored.

    Returns:
        ``(mean, std)``: float32 tensors of shape ``(C,)``.

    Raises:
        ValueError: if the loader yields no pixels at all.
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for images, _ in dataloader:
        # (B, C, H, W) -> (C, B*H*W): each row holds every pixel of one channel.
        # float64 accumulation keeps the E[x^2] - E[x]^2 formula numerically stable.
        pixels = images.transpose(0, 1).reshape(images.size(1), -1).to(torch.float64)
        if channel_sum is None:
            channel_sum = torch.zeros(pixels.size(0), dtype=torch.float64, device=pixels.device)
            channel_sq_sum = torch.zeros_like(channel_sum)
        channel_sum += pixels.sum(dim=1)
        channel_sq_sum += pixels.pow(2).sum(dim=1)
        pixel_count += pixels.size(1)

    if pixel_count == 0:
        raise ValueError("calculate_mean_std: dataloader yielded no pixels")

    mean = channel_sum / pixel_count
    # Clamp tiny negative rounding errors before taking the square root.
    std = (channel_sq_sum / pixel_count - mean.pow(2)).clamp(min=0).sqrt()
    return mean.float(), std.float()
# Per-channel normalisation statistics of this dataset. Presumably produced by
# a mean/std pass over un-augmented images — TODO confirm. The ImageNet values
# would be (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225).
# Both pipelines share these constants so train and test inputs are normalised
# identically (previously both Normalize calls had been swallowed by comments).
DATASET_MEAN = (0.4790, 0.4813, 0.4370)
DATASET_STD = (0.2123, 0.2066, 0.2085)

# Training pipeline: random crop/flip/color jitter/rotation for augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(DATASET_MEAN, DATASET_STD),
])

# Evaluation pipeline: deterministic resize only, same normalisation.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(DATASET_MEAN, DATASET_STD),
])
# 2. Load the custom dataset.
# Forward slashes resolve on both Windows and POSIX; the original backslash
# path only worked on Windows.
DATA_ROOT = "BengaliFishImages/fish_images"

full_dataset = datasets.ImageFolder(root=DATA_ROOT, transform=train_transform)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, held_out_split = random_split(full_dataset, [train_size, test_size])

# random_split subsets share their parent's transform, so the held-out images
# would otherwise be evaluated with random training augmentation and
# test_transform would never be used. Re-index the same held-out samples into
# a second ImageFolder that applies test_transform (ImageFolder enumerates
# classes and files in sorted order, so indices line up across instances).
eval_dataset = datasets.ImageFolder(root=DATA_ROOT, transform=test_transform)
test_dataset = torch.utils.data.Subset(eval_dataset, held_out_split.indices)

# 3. Data loaders.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 4. Attention modules — SE (squeeze-and-excitation)
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools each channel of a ``(b, c, h, w)`` input, passes the
    ``(b, c)`` descriptor through a bottleneck MLP (c -> hidden -> c) ending in
    a sigmoid, and rescales the input channels by those gates.

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck ratio; hidden width is ``channel // reduction``,
            clamped to at least 1 so that channel < reduction does not create a
            zero-width hidden layer (which would leave the gate unable to learn).
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
# CBAM attention: channel attention followed by spatial attention
class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    For a ``(b, c, h, w)`` input returns sigmoid weights of shape
    ``(b, c, 1, 1)``: a shared 1x1-conv MLP (c -> hidden -> c) is applied to
    both the average-pooled and max-pooled descriptors and the results summed.
    The hidden width ``in_planes // ratio`` is clamped to at least 1 so small
    channel counts still get a working bottleneck.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        hidden = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _shared_mlp(self, pooled):
        """Apply the shared bottleneck MLP to one pooled descriptor."""
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x):
        avg_out = self._shared_mlp(self.avg_pool(x))
        max_out = self._shared_mlp(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate: sigmoid weights of shape ``(b, 1, h, w)``.

    Channel-wise mean and max maps are stacked and convolved with a single
    ``kernel_size`` x ``kernel_size`` filter.

    Raises:
        ValueError: if kernel_size is even. ``padding = kernel_size // 2`` only
            preserves h and w for odd kernels; an even kernel would produce an
            (h+1, w+1) map that cannot be multiplied with the input.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        if kernel_size % 2 != 1:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        stacked = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv1(stacked))


class CBAMBlock(nn.Module):
    """CBAM: rescale the input by channel attention, then by spatial attention."""

    def __init__(self, channel, ratio=16, kernel_size=7):
        super(CBAMBlock, self).__init__()
        self.channel_attention = ChannelAttention(channel, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.channel_attention(x)
        x = x * self.spatial_attention(x)
        return x
# 5. Custom CNN with optional SE or CBAM attention
class ImprovedCNN(nn.Module):
    """Four-stage CNN classifier with optional SE/CBAM attention after each stage.

    Each stage is conv3x3 -> BatchNorm -> ReLU -> MaxPool(2), optionally
    followed by an attention block. An adaptive average pool fixes the last
    feature map at 8x8 before the fully connected head. 128x128 inputs already
    reach exactly 8x8 and are unaffected. The 224x224 inputs this script
    produces reach 14x14, which the previous hard-coded
    ``view(-1, 256 * 8 * 8)`` mis-reshaped: it gave the wrong batch dimension
    or raised a view error. The extra pooling/activation modules have no
    parameters, so existing checkpoints still load.

    Args:
        num_classes: number of output logits.
        attention_type: None, 'se' or 'cbam'.

    Raises:
        ValueError: if attention_type is not one of the supported values
            (previously forward() failed later with an AttributeError).
    """

    def __init__(self, num_classes=20, attention_type=None):
        super(ImprovedCNN, self).__init__()
        if attention_type not in (None, 'se', 'cbam'):
            raise ValueError(f"不支持的注意力类型: {attention_type}")
        self.attention_type = attention_type

        # Stage 1: 3 -> 32 channels, spatial size halved
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.att1 = self._make_attention(32)

        # Stage 2: 32 -> 64 channels, spatial size halved
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.att2 = self._make_attention(64)

        # Stage 3: 64 -> 128 channels, spatial size halved
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(2)
        self.att3 = self._make_attention(128)

        # Stage 4: 128 -> 256 channels, spatial size halved
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(2)
        self.att4 = self._make_attention(256)

        # Head: fixed 8x8 map so fc1's input size is independent of image size.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 8))
        self.fc1 = nn.Linear(256 * 8 * 8, 512)
        self.fc_relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def _make_attention(self, channels):
        """Return the configured attention module for `channels`, or None."""
        if self.attention_type == 'se':
            return SEBlock(channels)
        if self.attention_type == 'cbam':
            return CBAMBlock(channels)
        return None

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.relu1, self.pool1, self.att1),
            (self.conv2, self.bn2, self.relu2, self.pool2, self.att2),
            (self.conv3, self.bn3, self.relu3, self.pool3, self.att3),
            (self.conv4, self.bn4, self.relu4, self.pool4, self.att4),
        )
        for conv, bn, relu, pool, attention in stages:
            x = pool(relu(bn(conv(x))))
            if attention is not None:
                x = attention(x)

        x = self.adaptive_pool(x)
        # flatten(x, 1) keeps the batch dimension intact, unlike view(-1, ...).
        x = torch.flatten(x, 1)
        x = self.dropout(self.fc_relu(self.fc1(x)))
        return self.fc2(x)
# 6. Classifier built on a pretrained backbone
def create_pretrained_model(model_name, num_classes=20, freeze_feature=True, attention_type=None):
    """Build a classifier on top of a pretrained torchvision backbone.

    Args:
        model_name: 'resnet50', 'vgg16' or 'mobilenet_v2'.
        num_classes: number of output classes of the new head.
        freeze_feature: if True, set requires_grad=False on the pretrained
            weights: all of them for ResNet-50, the ``features`` part for
            VGG-16 / MobileNetV2. Modules added afterwards stay trainable.
        attention_type: None, 'se' or 'cbam'. For ResNet-50 the block is
            inserted right after ``layer4[0].conv1``. For VGG-16 and
            MobileNetV2 it is appended to the end of ``features``.

    Returns:
        The assembled nn.Module (not moved to any device).

    Raises:
        ValueError: if model_name or attention_type is unsupported. Both are
            checked before any pretrained weights are requested. Previously an
            unknown attention_type was ignored for ResNet-50 but treated as
            'cbam' for the other two backbones.
    """
    if model_name not in ('resnet50', 'vgg16', 'mobilenet_v2'):
        raise ValueError(f"不支持的模型名称: {model_name}")
    if attention_type not in (None, 'se', 'cbam'):
        raise ValueError(f"不支持的注意力类型: {attention_type}")

    def make_attention(channels):
        # Created after freezing, so the attention weights remain trainable.
        if attention_type == 'se':
            return SEBlock(channels)
        if attention_type == 'cbam':
            return CBAMBlock(channels)
        return None

    if model_name == 'resnet50':
        model = models.resnet50(pretrained=True)
        if freeze_feature:
            for param in model.parameters():
                param.requires_grad = False
        # 512 matches the output channels of ResNet-50's layer4[0].conv1.
        attention = make_attention(512)
        if attention is not None:
            model.layer4[0].conv1 = nn.Sequential(model.layer4[0].conv1, attention)
        # Replace the final fully connected layer with a small MLP head.
        num_ftrs = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )
    elif model_name == 'vgg16':
        model = models.vgg16(pretrained=True)
        if freeze_feature:
            for param in model.features.parameters():
                param.requires_grad = False
        attention = make_attention(512)
        if attention is not None:
            model.features = nn.Sequential(*list(model.features.children()), attention)
        # Replace only the last classifier layer.
        num_ftrs = model.classifier[6].in_features
        model.classifier[6] = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )
    else:  # 'mobilenet_v2' (validated above)
        model = models.mobilenet_v2(pretrained=True)
        if freeze_feature:
            for param in model.features.parameters():
                param.requires_grad = False
        attention = make_attention(1280)
        if attention is not None:
            model.features = nn.Sequential(*list(model.features.children()), attention)
        # Replace the whole classifier.
        num_ftrs = model.classifier[1].in_features
        model.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )
    return model
# 7. Training and evaluation
def _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, epochs,
                     iter_losses, iter_indices):
    """Run one optimisation pass over train_loader; return (mean batch loss, accuracy %).

    Appends each batch loss to iter_losses and its global 1-based iteration
    number to iter_indices, and prints progress every 100 batches.
    """
    # Must be re-enabled every epoch: the evaluation pass switches the model to
    # eval mode, and previously Dropout/BatchNorm stayed in inference behaviour
    # for every epoch after the first.
    model.train()
    num_batches = len(train_loader)
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        iter_loss = loss.item()
        iter_losses.append(iter_loss)
        iter_indices.append(epoch * num_batches + batch_idx + 1)
        running_loss += iter_loss

        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        if (batch_idx + 1) % 100 == 0:
            print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{num_batches} '
                  f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')

    return running_loss / num_batches, 100. * correct / total


def _evaluate(model, loader, criterion, device):
    """Return (mean batch loss, accuracy %) of model over loader, without gradients."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    return total_loss / len(loader), 100. * correct / total


def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):
    """Train model for `epochs` epochs, evaluating on test_loader after each one.

    After every epoch the scheduler is stepped with the mean test loss (the
    interface of ReduceLROnPlateau, which main() uses). When training ends, the
    per-iteration loss curve and the per-epoch accuracy/loss curves are plotted.

    Returns:
        Test accuracy (%) after the final epoch.

    Raises:
        ValueError: if epochs < 1 (there would be no accuracy to report).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    all_iter_losses = []
    iter_indices = []
    train_acc_history = []
    test_acc_history = []
    train_loss_history = []
    test_loss_history = []

    for epoch in range(epochs):
        epoch_train_loss, epoch_train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, epochs,
            all_iter_losses, iter_indices,
        )
        train_acc_history.append(epoch_train_acc)
        train_loss_history.append(epoch_train_loss)

        epoch_test_loss, epoch_test_acc = _evaluate(model, test_loader, criterion, device)
        test_acc_history.append(epoch_test_acc)
        test_loss_history.append(epoch_test_loss)

        scheduler.step(epoch_test_loss)
        print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')

    plot_iter_losses(all_iter_losses, iter_indices)
    plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
    return test_acc_history[-1]
# 8. Plotting helpers
def _decorate_axes(x_label, y_label, title):
    """Label the current axes, then add a legend and a grid."""
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.grid(True)


def plot_iter_losses(losses, indices):
    """Show the training loss of every iteration against its global batch number."""
    plt.figure(figsize=(10, 4))
    plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
    _decorate_axes('Iteration(Batch序号)', '损失值', '每个 Iteration 的训练损失')
    plt.tight_layout()
    plt.show()


def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):
    """Show per-epoch accuracy (left) and loss (right) for the train and test sets."""
    epoch_axis = range(1, len(train_acc) + 1)
    # Each panel: (train series, train label, test series, test label, y label, title)
    panels = (
        (train_acc, '训练准确率', test_acc, '测试准确率', '准确率 (%)', '训练和测试准确率'),
        (train_loss, '训练损失', test_loss, '测试损失', '损失值', '训练和测试损失'),
    )
    plt.figure(figsize=(12, 4))
    for slot, (tr_values, tr_label, te_values, te_label, y_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.plot(epoch_axis, tr_values, 'b-', label=tr_label)
        plt.plot(epoch_axis, te_values, 'r-', label=te_label)
        _decorate_axes('Epoch', y_label, title)
    plt.tight_layout()
    plt.show()
# 9. Training configuration and entry point
def main():
    """Build the configured model, train it on the fish dataset and save its weights.

    Uses the module-level full_dataset, train_loader, test_loader and device.
    """
    # Model: 'custom' (ImprovedCNN), 'resnet50', 'vgg16' or 'mobilenet_v2'
    model_type = 'resnet50'
    # Attention: None, 'se' or 'cbam'
    attention_type = 'cbam'

    epochs = 30  # pretrained backbones usually need fewer epochs
    # Derive the class count from the dataset instead of hard-coding 20, so the
    # output layer always matches the labels ImageFolder produces.
    num_classes = len(full_dataset.classes)

    if model_type == 'custom':
        print(f"使用自定义CNN模型,注意力机制: {attention_type}")
        model = ImprovedCNN(num_classes=num_classes, attention_type=attention_type).to(device)
    else:
        # First fine-tuning stage: train only the new head on a frozen backbone,
        # with attention deliberately disabled. Record that in attention_type so
        # the log line and checkpoint name describe the model actually trained
        # (it was previously saved as "..._cbam_..." although no CBAM was used).
        # For full fine-tuning, pass freeze_feature=False instead.
        attention_type = None
        print(f"使用预训练{model_type}模型,注意力机制: {attention_type}")
        model = create_pretrained_model(
            model_name=model_type,
            num_classes=num_classes,
            freeze_feature=True,
            attention_type=attention_type,
        ).to(device)

    criterion = nn.CrossEntropyLoss()
    # Small learning rate for fine-tuning; hand the optimizer only the
    # parameters that are actually trainable.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=5, factor=0.5, min_lr=1e-6
    )

    print("开始训练...")
    final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
    print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

    model_filename = f"{model_type}_{attention_type if attention_type else 'no_att'}_fish_model.pth"
    torch.save(model.state_dict(), model_filename)
    print(f"模型已保存为: {model_filename}")


if __name__ == "__main__":
    main()

@浙大疏锦行 

 

http://www.dtcms.com/a/267035.html

相关文章:

  • Python 的内置函数 reversed
  • 系统移植基础部分
  • Resource punkt_tab not found. NLTK
  • Docker Desktop 安装到D盘(包括镜像下载等)+ 汉化
  • JxBrowser 7.43.3 版本发布啦!
  • 数据结构---线性表理解(一)
  • 【unitrix】 4.16 类型级别左移运算实现解析(shl.rs)
  • spring-ai-alibaba 1.0.0.2 学习(十)——各种工具调用方式对比
  • Python 闭包(Closure)实战总结
  • 【网络与系统安全】强制访问控制——BLP模型
  • PortSwigger Labs SQLInjection LAB6-7
  • 汽车功能安全【ISO 26262】概述1
  • Python-GUI-wxPython-布局
  • 黑马python(二十五)
  • hello判断
  • 斜线投影几何分割公式 h = rx·ry/(rx+ry) 的推导方法
  • 【github】想fork的项目变为私有副本
  • boost--io_service/io_context
  • FFmpeg 升级指北
  • 【网络与系统安全】强制访问控制——Biba模型
  • AI生成式软件工程正处在从“辅助编程”到“AI原生开发”的范式转移
  • 使用坚果云扩容Zotero同步空间的简单快捷方法
  • Vue3-组件化-Vue核心思想之一
  • Python 中的余数运算及数论中的同余定理
  • 五层协议介绍
  • 指针篇(7)- 指针运算笔试题(阿里巴巴)
  • CSS——圆形头像外嵌光圈
  • springsecurity02
  • js中的FileReader对象
  • ESP32CAM通过ESPHome接入HomeAssistant(含无线刷固件等)