当前位置: 首页 > news >正文

站长之家商城太原建设网站制作

站长之家商城,太原建设网站制作,共享ip做网站,求推荐专业的网站建设开发如果想在加载预训练权重后,对所有层的参数都进行继续训练(Fine‑tuning),只需要做两件事: 不要冻结任何参数,也就是保证所有参数的 requires_gradTrue; 在定义优化器时,将它指向 mo…

如果想在加载预训练权重后,对所有层的参数都进行继续训练(Fine‑tuning),只需要做两件事:

  1. 不要冻结任何参数,也就是保证所有参数的 requires_grad=True

  2. 在定义优化器时,将它指向 model.parameters() 而不是仅最后一层。

下面以CIFAR10经典数据集为案例,对torch,hub加载预训练模型进行自定义训练进行一个简单的讲解:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset# 0. 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 1. 使用 torch.hub 加载预训练模型
try:model = torch.hub.load('pytorch/vision:stable', 'resnet18', pretrained=True)
except Exception:print("Failed to load stable, trying v0.13.0")model = torch.hub.load('pytorch/vision:v0.13.0', 'resnet18', pretrained=True)# print("Original model structure:")
# print(model)# 2. 修改模型参数/结构 (如果需要)
# 假设我们的新任务有 10 个类别 (例如 CIFAR-10)
num_classes_new = 10# 获取原始全连接层的输入特征数
num_ftrs = model.fc.in_features
# 用新的全连接层替换 (即使你想训练所有层,如果类别数变了,这一步通常是必要的)
model.fc = nn.Linear(num_ftrs, num_classes_new)
print(f"\nReplaced fc layer. New fc layer: {model.fc}")
# 新添加的 model.fc 层的参数默认 requires_grad=True# 3. 确保所有参数都可训练
# 当加载预训练模型 (pretrained=True) 时,默认所有参数的 requires_grad 就是 True。
# 当你替换一个层时 (如 model.fc = nn.Linear(...)),新层的参数默认也是 requires_grad=True。
# 所以,理论上你可能不需要显式设置。
# 但为了确保,或者如果你之前有代码冻结了层,可以显式地将所有参数的 requires_grad 设为 True:
for param in model.parameters():param.requires_grad = True# 检查哪些参数是可训练的 (应该会打印出模型的所有参数名)
print("\nTrainable parameters (should be all parameters):")
trainable_params_count = 0
for name, param in model.named_parameters():if param.requires_grad:# print(name) # 取消注释以查看所有可训练参数的名称trainable_params_count += param.numel()
print(f"Total trainable parameters: {trainable_params_count}")model = model.to(device) # 将模型移动到设备# 4. 准备数据 (与之前相同)
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])try:train_dataset_full = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)test_dataset_full = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
except Exception as e:print(f"Failed to download CIFAR10 automatically: {e}")print("Using dummy data for demonstration as CIFAR10 download failed.")class DummyDataset(torch.utils.data.Dataset):def __init__(self, num_samples, transform):self.num_samples = num_samplesself.transform = transformdef __len__(self):return self.num_samplesdef __getitem__(self, idx):dummy_image = torch.rand(3, 32, 32)dummy_label = torch.randint(0, num_classes_new, (1,)).item()pil_image = transforms.ToPILImage()(dummy_image)transformed_image = self.transform(pil_image)return transformed_image, dummy_labeltrain_dataset_full = DummyDataset(1000, transform)test_dataset_full = DummyDataset(200, transform)train_subset_indices = list(range(0, min(1000, len(train_dataset_full))))
train_dataset = Subset(train_dataset_full, train_subset_indices)
test_subset_indices = list(range(0, min(200, len(test_dataset_full))))
test_dataset = Subset(test_dataset_full, test_subset_indices)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)# 5. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()# 关键:将模型的所有参数传递给优化器
# 对于完全微调,学习率通常需要设置得比从头训练或只训练分类头时更小,
# 因为预训练的特征已经比较好了,我们不希望用太大的学习率破坏它们。
# 常见的值可能是 1e-4, 1e-5 等。
optimizer = optim.Adam(model.parameters(), lr=1e-4) # 或者使用 optim.SGD# (可选) 使用不同的学习率进行微调 (Discriminative Learning Rates)
# 有时,你会希望对预训练的层使用更小的学习率,对新添加的层使用较大的学习率。
# params_to_update = []
# # 预训练的卷积层和BN层
# for name, param in model.named_parameters():
#     if "fc" not in name: # 假设除了fc层都是预训练的
#         params_to_update.append({'params': param, 'lr': 1e-5}) # 较小的学习率
# # 新的fc层
# params_to_update.append({'params': model.fc.parameters(), 'lr': 1e-3}) # 较大的学习率
# optimizer = optim.Adam(params_to_update)# 6. 再次训练(微调)模型
num_epochs = 3 # 训练轮数,为了演示设为较小值
print(f"\nStarting full fine-tuning for {num_epochs} epochs...")for epoch in range(num_epochs):model.train() # 设置模型为训练模式running_loss = 0.0correct_train = 0total_train = 0for i, (inputs, labels) in enumerate(train_loader):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)_, predicted = torch.max(outputs.data, 1)total_train += labels.size(0)correct_train += (predicted == labels).sum().item()if (i + 1) % 20 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')epoch_loss = running_loss / len(train_dataset)epoch_acc = 100 * correct_train / total_trainprint(f'Epoch [{epoch+1}/{num_epochs}] completed. Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_acc:.2f}%')print('Finished Training')# 7. (可选) 评估模型
model.eval()
correct_test = 0
total_test = 0
with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total_test += labels.size(0)correct_test += (predicted == labels).sum().item()accuracy = 100 * correct_test / total_test
print(f'\nAccuracy of the model on the {len(test_dataset)} test images: {accuracy:.2f}%')# 8. (可选) 保存微调后的模型
# torch.save(model.state_dict(), 'resnet18_fully_finetuned.pth')
# print("Fully finetuned model saved as resnet18_fully_finetuned.pth")

代码解释:

  1. 设置设备: 自动选择 GPU (如果可用) 或 CPU。

  2. 加载预训练模型:

    • torch.hub.load('pytorch/vision:stable', 'resnet18', pretrained=True) 从 PyTorch Hub 加载 ResNet18 模型,并下载预训练在 ImageNet 上的权重。stable 通常指向最新的稳定版本。如果遇到问题,可以指定一个如 'v0.13.0' 的具体版本。

  3. 修改模型:

    • 替换最后一层: 这是最常见的迁移学习策略。我们获取原始 ResNet18 fc 层的输入特征数 num_ftrs,然后用一个新的 nn.Linear 层替换它,这个新层的输出维度是我们新任务的类别数 num_classes_new。

    • 修改参数值 (示例): 虽然在这个场景下我们替换了层,但如果你想修改现有层(比如 model.layer4[0].conv1.weight)的参数值,可以直接访问 param.data 并进行修改,例如使用 nn.init 中的函数重新初始化。

    • 冻结参数:

      • for param in model.parameters(): param.requires_grad = False:遍历模型的所有参数,并将它们的 requires_grad 属性设置为 False。这样,在反向传播时,这些参数的梯度就不会被计算,它们的权重也不会在优化过程中更新。

      • for param in model.fc.parameters(): param.requires_grad = True:只解冻我们新添加的 fc 层的参数,使其可训练。

      • 更细致的解冻: 如果你想微调更多层,比如 ResNet 的 layer4,你可以类似地解冻它:

        # for param in model.layer4.parameters():
        #     param.requires_grad = True

        content_copydownload

        Use code with caution.Python
  4. 准备数据:

    • 使用 torchvision.transforms 来预处理图像,使其符合 ResNet18 的输入要求(尺寸、归一化)。

    • 使用 torchvision.datasets.CIFAR10 加载数据集(如果下载失败,则使用虚拟数据)。

    • Subset 用于从完整数据集中取一小部分进行快速演示。

    • DataLoader 用于创建数据批次。

  5. 定义损失函数和优化器:

    • nn.CrossEntropyLoss 是分类任务常用的损失函数。

    • optim.Adam 是一个常用的优化器。

    • 关键: filter(lambda p: p.requires_grad, model.parameters()) 确保优化器只更新那些 requires_grad=True 的参数。如果你确定只有 model.fc 是可训练的,也可以直接用 optim.Adam(model.fc.parameters(), lr=0.001)。

  6. 再次训练 (微调):

    • 标准的 PyTorch 训练循环。

    • model.train() 将模型设置为训练模式(这会影响像 Dropout 和 BatchNorm 这样的层)。

    • 前向传播、计算损失、反向传播、更新权重。

  7. 评估模型:

    • model.eval() 将模型设置为评估模式。

    • with torch.no_grad(): 禁用梯度计算,以节省内存和计算。

  8. 保存模型:

    • torch.save(model.state_dict(), 'path/to/your/model.pth') 保存模型的状态字典(推荐方式)。


关键点说明

  • 所有层都参与反向传播
    默认从 torch.hub.load(..., pretrained=True) 得到的模型,其所有参数 requires_grad=True。只要不手动将某些层设为 False,优化器就会更新它们。

  • 优化器的参数列表

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    

    这里传入的是 model.parameters(),它包含了整张网络的可训练参数;如果只想训练最后一层,才会写成 model.fc.parameters()

  • 学习率选择
    微调(Fine‑tuning)全模型时,通常需使用比训练头部更小的学习率(如 1e-4 或更低),以免预训练权重被大步修改。

  • 冻结与否的对比

    • 冻结:param.requires_grad = False

    • 微调全模型:不设置或显式 param.requires_grad = True

http://www.dtcms.com/a/417247.html

相关文章:

  • 自己的网站怎么开微站是什么东西
  • 网上书城网站建设目的自己服务器做网站如何备案
  • 北京建设教育协会的网站wordpress修改插件路径
  • 网站做后台产品内页设计
  • 广州市建设工程造价站网站小浣熊做单网站
  • 网站建设选择题题库制作一个网站的基本步骤
  • 网站建设最难的部分电影vip网站建设步骤
  • 宁波网络建站海拉尔建网站
  • 复兴网站制作文章页模板wordpress
  • 怎么通过做网站来赚钱设计网站教程
  • 学校网站建设运行简介手机网站源码 html5
  • 南京建行网站专业的网站开发公司电话
  • 杭州设计 公司 网站建设报名系统
  • php网站开发开题报告如果做二手车网站
  • 深圳网站建设的客户在哪里上海网站群建设
  • 广州协会网站建设网页游戏排行榜逃
  • 公共资源交易中心网站建设汇报做网站维护的人叫啥
  • 数码电子产品网站建设策划书怎样做电商卖货
  • html企业网站系统建设网站经验
  • 牛股大转盘网站建设珠海做快照网站电话
  • 引导企业做网站鹿泉微信网站建设
  • 广东备案网站文化网站前置审批
  • 长沙网站优化外包服务下载中国移动商旅100最新版本
  • logo在线设计生成器下载宁波seo关键词优化教程
  • 影视自助建站php ajax网站开发典型实例 pdf
  • 网站建设公司石家庄微信小程序制作过程
  • 国外做兼职网站用层还是表格做网站快
  • 网站制作价格东莞无需本金十分钟赚800
  • 通用网址通用网站查询asp绿色简洁通用型企业网站源码
  • 外贸网站建设渠道网站ui设计例子