当前位置: 首页 > news >正文

Pytorch超分辨率模型实现与详细解释

下面我提供一个完整Pytorch超分辨率模型实现每一行代码进行详细解释包括所有引用头文件

import torch #导入pytorch用于构建训练神经网络主要框架

import torch.nn as nn #导入pytorch神经网络模块-包含各种神经网络函数

import torch.nn.functional as F #导入Pytorch神经网络函数模块--包含激活函数损失函数

import torch.utils.data import DataLoader #导入pytorch数据加载工具用于创建管理数据加载

from torchvision import datasets, transforms #导入torchvision数据集和变换模块--提供常用数据图像预处理方法

import matplotlib.pyplot as plt #导入matplotlib pyplot模块-用于绘制图表可视化结果

import numpy as np #导入numpy用于数值计算特别是在处理图像数据

import os #导入操作系统接口模块用于处理文件目录路径

import time #导入时间模块--用于测量训练时间

设置设备GPU如果可用否则CPU

torch.cuda.is_avaiable() 检查当前系统是否可用CUDA GPU

#如果使用GPU加速计算否则使用CPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

1 定义ESPCN模型

class ESPCN(nn.Module):

#初始化方法定义模型结构

def __init__(self, upscale_factor, num_channels=1):

初始化ESPCN模型

参数

upscale_factor 放大倍数

num_channels 输入图像通道数默认为1(灰度图)

#调用nn.Module 初始化方法

super(ESPCN, self).__init__()

#第一个卷积层提取特征

nn.Conv2d: 2D卷积层用于处理图像数据

参数输入通道数输出通道数卷积核大小填充大小

这里使用5x5卷积核填充2保持空间尺寸不变

self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=5, padding=2)

#第二卷积进一步处理特征

输入64通道输出32通道3x3卷积核填充1保持尺寸

self.conv2 = nn.Conv2d(64,32,kernel_size=3,padding=1)

#最后一个卷积层生成放大特征图

输出通道数num_channels (upscale_factor *2)

这是因为我们像素pixel_shuffle 来提升分辨率

self.conv3 = nn.Conv2d(32, num_channels (upscale_factor *2), kernel_szie=3, padding=1)

#像素操作子像素卷积层

pixelshuffle 形状 C * r^2, H, W张量

重新排列

self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

#定义向前传播过程描述数据如何通过网络各层

def forward(self, x):

向前传播

参数

x 输入第分辨率图像形状为(batch_size, num_channels, height, width)

返回 分辨率图像形状为(batch_szie, num_channels, height upscale_factor, width upscale_factor)

#第一层卷积使用tanh激活函数

torch.tanh 正切激活函数压缩(-1,1)范围

x = torch.tanh(self.conv1(x))

#第二层卷积后使用tanh激活函数

x = torch.tanh(self.conv2(x))

#第三层卷积

x = self.conv3(x)

#应用像素操作通道维度转换为空间维度

x = self.pixel_shuffle(x)

#使用sigmoid激活函数压缩到(0,1)范围

#这是因为图像像素值通常0-1之间

x = torch.sigmoid(x)

return x

2 准备数据

def prepare_data(batch_szie, upscale_factor, dataset_name='MNIST'):

准备训练预测数据

参数

batch_szie 批处理大小

upscale_factor 放大倍数

dataset_name 数据集名称默认为MNIST

返回

训练测试数据加载

数据转换管道

transforms.Compose 多个变换组合在一起

transform = transforms.Compose([

#transforms.ToTensor PIL图像或者numpy数组转换Pytorch张量

#同时 像素值[0.255]缩放[0,1]范围

transforms.ToTensor(),

#transforms.Normalize 张量进行标准化

#参数均值标准差这里标准化[-1,1]范围

transforms.Normalize(0.5, (0.5,))

])

#根据数据集名称选择不同数据集

if dataset_name == 'MNIST':

#下载加载MNIST训练数据集

#MNIST是一个手写数字数据集包含60 000训练样本10 000测试样本

train_dataset = datasets.MNIST(

root='./data' #数据存储路径

train=True, #加载训练集

download=True,#如果数据不存在下载

transform = transform #应用上面定义数据转换

)

#下载并加载MNIST测试数据集

test_dataset = datasets.MNIST(

root='./data',

train=False, #加载测试集

download=True,

transform=transform

)

else:

#可以在这里添加其他数据集支持

raise ValueError("不支持的数据集:{dataset_name}")

#创建训练数据加载起

DataLoader 包装数据集提供批量加载shuffling 功能

train_loader = DataLoader(

train_dataset,

batch_size = batch_size, #每个批次样本数量

shuffle=True, 每个epoch 开始打乱数据顺序

num_works = 2, 使用2子进程加载数据

pin_memory=True #数据固定内存中加速GPU传输

)

#创建测试数据加载

test_loader = DataLoader(

test_dataset,

batch_size=batch_size,

shuffle=False, #测试不需要打乱顺序

num_works=2,

pin_memory=True

)

#返回训练测试数据加载

return train_laoder, test_loader

3 训练函数

def train_model(model, train_loader, criterion, optimizer, num_epochs, upscale_factor):

训练模型

参数

model:要训练模型

train_loader: 训练数据加载

criterion 损失函数

optimizer: 优化

num_epochs: 训练

upscale_factor: 放大倍数

#设置模型训练模式

#启用dropoutbatch normalization 训练特定行为

model.train()

#记录训练过程损失值

losses = []

#记录训练开始时间

start_time = time.time()

#循环遍历每个epoch

for epoch in range(num_epochs):

#初始化当前epoch损失

epoch_loss = 0

#遍历训练数据加载起每个批次

for batch_idx, (data, target) in enumerate(train_loader):

#数据移动到相应设备 GPU或者CPU

data = data.to(device)

#创建分辨率输入

#首先图像下采样然后上采样原始大小模拟分辨率图像

#F.interpolate: 图像进行采样或者下采样

#scale_factor= 1/upscale_factor 下采样比例

mode = 'bicubic' 使用双三次循环算法

align_corners= False: 差值算法参数

lr_data = F.interpolate(

data,

scale_factor=1/upscale_factor,

mode='bicubic',

align_corners = False

)

#下采样图像采样原始尺寸

lr_data = F.interpolate(

lr_data,

scale_factor = upscale_factor,

mode = 'bicubic',

align_corners=False

)

#清零梯度

pytorch 梯度累加所以在每个批次开始需要清零

optimizer.zero_grad()

#前向传播降低分辨率图像输入模型得到分辨率输出

output = model(lr_data)

#计算损失比较模型输出原始分辨率图像

loss = criterion(output, data)

#反向传播计算梯度

loss.backward()

#更新权重根据梯度调整模型参数

optimizer.step()

#累加当前批次损失

epoch_loss += loss.item()

#计算当前epoch平均损失

losses.append(avg_loss)

#打印训练进度

if (epoch + 1) % 5 == 0:

#计算已用时间

elapsed_time = time.time() - start_time

#打印当前epoch,epoch损失值已用时间

#训练绘制损失曲线

plt.figure(figsize=(10,5))

plt.plot(losses)

plt.title('Training loss over epochs')

plt.xlabel('Epoch')

plt.ylabel('Loss')

plt.grid(True)

#保存损失曲线图像

plt.savefig('training_loss.png')

plt.show()

#打印训练时间

total_time = time.time() - start_time

4 测试函数

#定义模型测试函数

def test_model(model, test_loader, upscale_factor, num_examples=5):

测试模型显示结果

参数

model: 要测试模型

test_loader: 测试数据加载器

upscale_factor: 放大倍数

num_examples: 显示示例数量

#设置模型评估模式

#禁用dropoutbatch normalization训练特定行为

model.eval()

#初始化示例计数器

examples_shown = 0

#不计算梯度节省内存计算资源

with torch.no_grad():

#遍历测试数据加载器

for i, (data, target) in enumerate(test_loader):

#如果已经显示了足够示例退出循环

if examples_shown >= num_examples:

break

#数据移动到相应设备

data = data.to(device)

#创建分辨率输入(与训练时间相同的方法)

lr_data = F.interpolate(

data,

scale_factor = 1/upscale_factor,

model='bicubic',

align_corners = False

)

lr_data = F.interpolate(

lr_data,

scale_factor = upscale_factor,

mode = 'bicubic',

align_corners = False

)

#生成分辨率图像

hr_output = model(lr_data)

#图像CPU转换为numpy数组以便显示

lr_image = lr_data[0].cpu().sequeeze().numpy()

hr_image = hr_output[0.cpu().sequeeze().numpy()

original_image = data[0].cpu().squeeze().numpy()

#显示结果

plt.figure(figsize=(12,4))

#显示分辨率输入图像

plt.subplot(1,3,1)

plt.imshow(lr_image, cmap='gray')

plt.title('Low Resolution Input')

plt.axis('off')

#显示分辨率输出图像

plt.subplot(1,3,2)

plt.imshow(hr_iamge, cmap='gray')

plt.title('Super Resolution Output')

plt.axis('off')

#显示原始高分辨率图像

plt.subplot(1,3,3)

plt.imshow(original_image, cmap='gray')

plt.title('Original high Resolution')

plt.axis('off')

#保存对比图像

plt.savefig(f'comparsion_example_{examples_shown+1}.png')

plt.show()

#增加示例计数器

examples_shown += 1

5 计算PSNR指标函数

def calculate_psnr(model, test_loader, upscale_factor):

计算模型峰值信噪比PSNR

参数model评估模型

test_loader :测试数据加载器

upscale_factor 放大倍数

返回平均PSNR

#设置模型评估模式

moel.eval()

#初始化PSNR总和样本计数

total_psnr=0.0

total_samples=0

#不计算梯度

with torch.no_grad():

#遍历测试数据加载器

for data, _ in test_loader:

#数据移动到相应设备

data = data.to(device)

#创建分辨率输入

lr_data = F.interpolate(

data,

scale_factor=1/upscale_factor,

mode = 'bicubic',

align_corners = False

)

lr_data = F.interpolate(

lr_data,

scale_factor = upscale_factor,

mode='bicubic',

align_corners=False

)

#生成分辨率图像

hr_output=model(lr_data)

#计算每个样本PSNR

for i in range(data.size(0)):

#张量转换numpy数组

original=data[i].cpu().numpy()

reconstructed=hr_output[i].cpu().numpy()

#计算均方差误差(MSE)

mse = np.mean((original - reconstructed) ** 2)

#避免除以

if mse == 0:

psnr = 100 #无穷PSNR这里100

else :

#计算PSNR 20 log10(MAX) - 10 log10(MSE)

#对于[0,1]范围图像 MAX = 1

psnr=20 np.log10(1.0) - 10 np.log10(mse)

#累加PSNR

total_psnr += psnr

total_samples += 1

#计算平均PSNR

avg_psnr = total_psnr / total_samples

return avg_psnr

6 函数

#定义函数组织整个训练和测试流程

def main():

#参数设置

upscale_factor=2 #放大倍数

num_epochs = 20 #训练

batch_szie = 64 #批处理大小

learning_rate = 0.001 学习率

#创建输出目录如果不存在

if not os.path.exists('results'):

os.makedirs('results')

#准备数据

train_loader, test_loader = prepare_data(batch_size, upscale_factor)

#初始化模型

model = ESPCN(upscale_factor=upscale_factor).to(device)

#打印模型结构

print(model)

#计算模型参数数量

total_params = sum(p.numel() for p in model.parameters())

#定义损失函数 - 均方误差损失

criterion = nn.MSELoss()

#定义优化器Adam优化器

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

#训练模型

train_model(model, train_loader, criterion, optimizer, num_epochs, upscale_factor)

#测试模型

test_model(model, test_loader, upscale_factor)

#计算PSNR

calculate_psnr(model, test_loader, upscale_factor)

#保存模型

model_path='results/espcn_model.pth'

torch.save(model.state_dcit(), model_path)

if __name__=="__main__":

main()

头文件解释总结

1 torch:pytorch提供张量操作自动求导功能

2 torch.nn:pytorch神经网络模块包含各种损失函数

3 torch.nn.functional:pytorch函数接口包含激活函数损失函数

4 torch.utils.data pytorch视觉提供常用数据集图像变换

5 matplotlib.pyplot会图库用于可视化结果

6 torchvision pytorch视觉提供常用数据集图像变换

7 numpy 数据计算用于处理数组数据

8 os 操作系统接口用于处理文件目录

9 time时间模块用于测量运动时间

http://www.dtcms.com/a/357629.html

相关文章:

  • Linux内核进程管理子系统有什么第三十八回 —— 进程主结构详解(34)
  • 叠叠问题解决
  • iPaaS实施的前提是先进行集成关系的梳理
  • 从自定义日期类角度解析运算符重载,友元函数(friend)
  • AI助力PPT创作:秒出PPT与豆包AI谁更高效?
  • 实现动态数组
  • 【NJU-OS-JYY笔记】操作系统:设计与实现
  • 【开题答辩全过程】以 基于Vue Spring Boot的教师资格证考试助力系统设计与实现为例,包含答辩的问题和答案
  • 黑客之都CSP-J模拟赛题解
  • C6.6:交流参量、电压增益、电流增益的学习
  • 企业级-搭建CICD(持续集成持续交付)实验手册
  • 【面试场景题】三阶段事务提交比两阶段事务提交的优势是什么
  • TypeScript: Symbol.iterator属性
  • 蓝蜂蓝牙模组:破解仪器仪表开发困境
  • 打通安卓、苹果后,小米澎湃OS 3又盯上了Windows
  • 【系列05】端侧AI:构建与部署高效的本地化AI模型 第4章:模型量化(Quantization)
  • AntSK知识库多格式导入技术深度解析:从文档到智能,一站式知识管理的技术奇迹
  • 第十二节 Spring 注入集合
  • 零知识证明的刑事证据困境:隐私权与侦查权的数字博弈
  • Windows 11 跳过 OOBE 的方法和步骤
  • 企业级数据库管理实战(二):数据库权限最小化原则的落地方法
  • 现状摸底:如何快速诊断企业的“数字化健康度”?
  • 嵌入式Linux驱动开发 - 蜂鸣器驱动
  • 25.8.29_NSSCTF——[BJDCTF 2020]Easy_WP
  • VeOmni 全模态训练框架技术详解
  • 深入理解Go 与 PHP 在参数传递上的核心区别
  • 变量声明方式
  • 嵌入式第四十一天(数据库)
  • 海量小文件问题综述和解决攻略(二)
  • C++ DDS框架学习