# Python打卡 day 44
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18 # 修改导入
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
# Data preprocessing. Resize(224) upsamples the CIFAR-10 images to the input
# size ResNet-18 was pretrained at, and Normalize uses the ImageNet channel
# statistics that torchvision's pretrained weights expect.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Load CIFAR-10 (downloaded to ./data on first run).
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

# ImageNet-pretrained ResNet-18.
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=ResNet18_Weights.DEFAULT`; switch once older versions are dropped.
model = resnet18(pretrained=True)

# Freeze every pretrained layer so the first phase trains only the new head.
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully-connected layer with a 10-way head for CIFAR-10.
# It is created after the freeze loop, so its parameters stay trainable.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

# Loss and optimizer. All parameters are registered with SGD on purpose:
# frozen ones get no gradient and are skipped by step(), and once train()
# unfreezes them they are updated without rebuilding the optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Per-epoch history: appended to by train(), read by plot_training_curves().
train_losses = []
train_accs = []
test_accs = []
def train(model, train_loader, criterion, optimizer, epochs=10, is_frozen=True):
    """Train ``model`` for ``epochs`` epochs, evaluating on the test set after each.

    After every epoch, three values are appended to the module-level history
    lists and printed:
    - the mean training loss, to ``train_losses``;
    - the training accuracy (%), to ``train_accs``;
    - the test accuracy (%) from ``evaluate(model, test_loader)``, to ``test_accs``.

    Args:
        model: network to train; each batch is moved to the module-level ``device``.
        train_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping the model's parameters.
        epochs: number of passes over ``train_loader``.
        is_frozen: if True, this call is the frozen-backbone phase. After its
            final epoch, every parameter gets ``requires_grad = True`` so that
            a later call fine-tunes the whole network.
    """
    for epoch in range(epochs):
        # evaluate() puts the model in eval mode at the end of each epoch, so
        # training mode (e.g. BatchNorm using batch statistics) must be
        # re-enabled at the start of every epoch, not just once.
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        train_losses.append(epoch_loss)
        train_accs.append(epoch_acc)

        # Evaluate the current model on the test set.
        test_acc = evaluate(model, test_loader)
        test_accs.append(test_acc)
        print(f'Epoch {epoch+1} - Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%, Test Acc: {test_acc:.2f}%')

    # End of the frozen phase: unfreeze all layers for the fine-tuning phase.
    if is_frozen and epochs > 0:
        for param in model.parameters():
            param.requires_grad = True
        print("解冻")
# Evaluation function
def evaluate(model, test_loader, device=None):
    """Return the classification accuracy of ``model`` on ``test_loader``, in percent.

    Args:
        model: classifier producing per-class scores along dim 1.
        test_loader: iterable of ``(inputs, labels)`` batches.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter.

    The model's train/eval mode is restored on exit, so calling this inside a
    training loop does not leave the model stuck in eval mode.
    """
    if device is None:
        device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    return 100 * correct / total
# Plot the training curves
def plot_training_curves():
    """Show training loss, training accuracy and test accuracy per epoch.

    Reads the module-level ``train_losses``, ``train_accs`` and ``test_accs``
    lists and draws them as three side-by-side panels.
    """
    panels = [
        (train_losses, 'Train Loss', 'Training Loss', 'Loss'),
        (train_accs, 'Train Accuracy', 'Training Accuracy', 'Accuracy (%)'),
        (test_accs, 'Test Accuracy', 'Test Accuracy', 'Accuracy (%)'),
    ]
    plt.figure(figsize=(12, 4))
    for idx, (values, label, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 3, idx)
        # Number epochs from 1 so the x-axis matches the printed training log.
        plt.plot(range(1, len(values) + 1), values, label=label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
    plt.tight_layout()
    plt.show()
# Train and test the model.
# Phase 1: backbone frozen, only the new classification head is trained.
FROZEN_EPOCHS = 5
print(f"冻结卷积层训练{FROZEN_EPOCHS}轮")
train(model, train_loader, criterion, optimizer, epochs=FROZEN_EPOCHS, is_frozen=True)
# Output:
# 冻结卷积层训练5轮
# Epoch 1/5: 100%|██████████| 782/782 [01:57<00:00, 6.68it/s]
# Epoch 1 - Loss: 0.9700, Train Acc: 70.31%, Test Acc: 77.73%
# Epoch 2/5: 100%|██████████| 782/782 [01:50<00:00, 7.05it/s]
# Epoch 2 - Loss: 0.6483, Train Acc: 78.63%, Test Acc: 79.28%
# Epoch 3/5: 100%|██████████| 782/782 [01:50<00:00, 7.05it/s]
# Epoch 3 - Loss: 0.6020, Train Acc: 79.80%, Test Acc: 79.40%
# Epoch 4/5: 100%|██████████| 782/782 [01:50<00:00, 7.07it/s]
# Epoch 4 - Loss: 0.5790, Train Acc: 80.49%, Test Acc: 80.06%
# Epoch 5/5: 100%|██████████| 782/782 [01:50<00:00, 7.08it/s]
# Epoch 5 - Loss: 0.5636, Train Acc: 80.77%, Test Acc: 80.26%
# 解冻
# Phase 2: fine-tune the whole network (train() unfroze every layer at the
# end of the frozen phase).
UNFROZEN_EPOCHS = 25
print(f"\n解冻卷积层训练{UNFROZEN_EPOCHS}轮")
train(model, train_loader, criterion, optimizer, epochs=UNFROZEN_EPOCHS, is_frozen=False)
# Output:
# 解冻卷积层训练25轮
# Epoch 1/25: 100%|██████████| 782/782 [02:43<00:00, 4.78it/s]
# Epoch 1 - Loss: 0.2894, Train Acc: 89.91%, Test Acc: 93.00%
# Epoch 2/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 2 - Loss: 0.1951, Train Acc: 93.24%, Test Acc: 93.07%
# Epoch 3/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 3 - Loss: 0.1205, Train Acc: 95.91%, Test Acc: 93.57%
# Epoch 4/25: 100%|██████████| 782/782 [02:41<00:00, 4.83it/s]
# Epoch 4 - Loss: 0.0733, Train Acc: 97.44%, Test Acc: 93.46%
# Epoch 5/25: 100%|██████████| 782/782 [02:41<00:00, 4.84it/s]
# Epoch 5 - Loss: 0.0466, Train Acc: 98.34%, Test Acc: 94.26%
# Epoch 6/25: 100%|██████████| 782/782 [02:41<00:00, 4.84it/s]
# Epoch 6 - Loss: 0.0273, Train Acc: 99.06%, Test Acc: 94.17%
# Epoch 7/25: 100%|██████████| 782/782 [02:41<00:00, 4.84it/s]
# Epoch 7 - Loss: 0.0145, Train Acc: 99.57%, Test Acc: 95.00%
# Epoch 8/25: 100%|██████████| 782/782 [02:41<00:00, 4.84it/s]
# Epoch 8 - Loss: 0.0093, Train Acc: 99.73%, Test Acc: 94.70%
# Epoch 9/25: 100%|██████████| 782/782 [02:41<00:00, 4.84it/s]
# Epoch 9 - Loss: 0.0066, Train Acc: 99.80%, Test Acc: 95.25%
# Epoch 10/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 10 - Loss: 0.0012, Train Acc: 99.99%, Test Acc: 95.42%
# Epoch 11/25: 100%|██████████| 782/782 [02:41<00:00, 4.84it/s]
# Epoch 11 - Loss: 0.0002, Train Acc: 100.00%, Test Acc: 95.57%
# Epoch 12/25: 100%|██████████| 782/782 [02:40<00:00, 4.86it/s]
# Epoch 12 - Loss: 0.0001, Train Acc: 100.00%, Test Acc: 95.50%
# Epoch 13/25: 100%|██████████| 782/782 [02:40<00:00, 4.86it/s]
# Epoch 13 - Loss: 0.0001, Train Acc: 100.00%, Test Acc: 95.54%
# Epoch 14/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 14 - Loss: 0.0001, Train Acc: 100.00%, Test Acc: 95.51%
# Epoch 15/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 15 - Loss: 0.0001, Train Acc: 100.00%, Test Acc: 95.50%
# Epoch 16/25: 100%|██████████| 782/782 [02:40<00:00, 4.86it/s]
# Epoch 16 - Loss: 0.0001, Train Acc: 100.00%, Test Acc: 95.52%
# Epoch 17/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 17 - Loss: 0.0001, Train Acc: 100.00%, Test Acc: 95.50%
# Epoch 18/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 18 - Loss: 0.0001, Train Acc: 100.00%, Test Acc: 95.49%
# Epoch 19/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 19 - Loss: 0.0000, Train Acc: 100.00%, Test Acc: 95.50%
# Epoch 20/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 20 - Loss: 0.0000, Train Acc: 100.00%, Test Acc: 95.51%
# Epoch 21/25: 100%|██████████| 782/782 [02:41<00:00, 4.84it/s]
# Epoch 21 - Loss: 0.0000, Train Acc: 100.00%, Test Acc: 95.50%
# Epoch 22/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 22 - Loss: 0.0000, Train Acc: 100.00%, Test Acc: 95.50%
# Epoch 23/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 23 - Loss: 0.0000, Train Acc: 100.00%, Test Acc: 95.51%
# Epoch 24/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 24 - Loss: 0.0000, Train Acc: 100.00%, Test Acc: 95.51%
# Epoch 25/25: 100%|██████████| 782/782 [02:41<00:00, 4.85it/s]
# Epoch 25 - Loss: 0.0000, Train Acc: 100.00%, Test Acc: 95.52%
# Plot the loss/accuracy curves recorded by both train() calls above.
plot_training_curves()
# @浙大疏锦行