当前位置: 首页 > news >正文

超分辨率重建(Super-Resolution, SR)

1. 超分辨率的任务定义

  • 输入:低分辨率图像 LR

  • 输出:高分辨率图像 HR

  • 本质:让模糊图变清晰,恢复纹理和细节

超分辨率(SR)就是把低分辨率图像恢复成高分辨率图像,让图像更清晰,同时尽可能恢复真实、合理的纹理细节。

2. SR 的评价指标

  • PSNR:像素误差低(值高)

  • SSIM:结构相似度高(越接近 1 越好)

2.1 PSNR

3. 为什么插值方法(上采样)无法做到高质量超分辨率

首先,插值方法(如双线性、双三次)只是根据邻近像素做数学推断,它不会“理解”图像内容,因此只能生成平滑的过渡,无法真正恢复丢失的细节。

其次,低分辨率图像中本来就缺失高频纹理(如纹理、细线、毛发等),插值无法凭空生成,只能把现有像素“拉伸”,导致图像变得模糊。

最后,不同区域的纹理结构非常复杂,插值方法无法区分边缘、平坦区、纹理区,因此容易产生伪影或边缘模糊,而深度学习才能通过大量样本学习规律化的细节恢复能力。

https://cloud.tencent.com/developer/article/2095690https://cloud.tencent.com/developer/article/2095690

4. SRCNN

4.1 概念

  • SRCNN 是 2014 年提出的最简单的深度学习超分模型,思想直观:
    先用传统插值把 LR 放大到目标尺寸(bicubic)→ 再用 3 层卷积网络修复细节。

  • 网络结构(最常见配置):

    1. Conv1: 大核(9×9),64 个通道 —— 特征提取(抓大的局部结构)

    2. Conv2: 1×1,32 个通道 —— 非线性映射(把特征从低维映射到高维)

    3. Conv3: 5×5,输出通道 = 图像通道(1 或 3) —— 重建图像

  • 激活:前两层后面接 ReLU,最后一层不接激活(直接回归像素值)

  • 训练目标:最简单常用的是 L2(MSE)或 L1 损失,衡量输出与 HR 的像素差

  • 优点:结构极其简单、便于理解和复现;

  • 缺点:需要先插值放大(计算浪费),恢复能力相对现代方法较弱

4.2 实践

4.2.1 带 TODO 的 PyTorch 代码框架

# srcnn_skeleton.pyimport os
import random
from glob import glob
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm# ============================
# 1. 数据集模块
# ============================
class SRDataset(Dataset):"""HR -> LR -> 上采样后的网络输入"""def __init__(self, root, scale=2, patch_size=48):self.root = rootself.files = glob(os.path.join(root, "*.png")) + glob(os.path.join(root, "*.jpg"))self.scale = scaleself.patch_size = patch_sizeself.to_tensor = transforms.ToTensor()def __len__(self):return len(self.files)def _random_crop(self, img):# TODO: 随机裁剪 patchraise NotImplementedErrordef __getitem__(self, idx):# TODO: 打开 HR 图像# TODO: 随机裁剪# TODO: 下采样生成 LR# TODO: 上采样回 HR 大小# TODO: 转 tensor 返回raise NotImplementedError# ============================
# 2. 模型模块
# ============================
class SRCNN(nn.Module):"""SRCNN 网络"""def __init__(self, in_channels=3):super().__init__()# TODO: 定义三层卷积raise NotImplementedErrordef forward(self, x):# TODO: 实现 forward# conv1 -> relu -> conv2 -> relu -> conv3 -> 输出 clampraise NotImplementedError# ============================
# 3. 工具函数模块
# ============================
def calc_psnr(sr, hr, shave_border=0):# TODO: 实现 PSNR 计算raise NotImplementedError# ============================
# 4. 训练循环模块
# ============================
def train_one_epoch(model, loader, criterion, optimizer, device):# TODO: 实现训练一轮raise NotImplementedErrordef validate(model, loader, device):# TODO: 实现验证并计算平均 PSNRraise NotImplementedError# ============================
# 5. 主函数
# ============================
def main():data_dir = "./BSD100"  # TODO: 修改为你的 BSD100 数据集路径scale = 2batch_size = 8epochs = 30lr = 1e-4patch_size = 48device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# TODO: 创建 dataset 和 dataloaderraise NotImplementedError# TODO: 初始化 SRCNN 模型raise NotImplementedError# TODO: 定义损失函数和优化器raise NotImplementedError# TODO: 训练循环 + 验证 + 保存最优模型raise NotImplementedErrorif __name__ == "__main__":main()

4.2.2 数据集说明

使用的是BSD100数据集

只有100张,主要是学习SRCNN,于是选择用比较小的数据https://aistudio.baidu.com/datasetdetail/99299?login_type=weixinhttps://aistudio.baidu.com/datasetdetail/99299?login_type=weixin

4.2.3 实现

# srcnn_train.pyimport os
import random
from glob import glob
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np# ============================
# 1. 数据集模块
# ============================
class SRDataset(Dataset):"""HR -> LR -> 上采样后的网络输入"""def __init__(self, root, scale=2, patch_size=48):self.root = rootself.files = glob(os.path.join(root, "*.png")) + glob(os.path.join(root, "*.jpg"))self.scale = scaleself.patch_size = patch_sizeself.to_tensor = transforms.ToTensor()def __len__(self):return len(self.files)def _random_crop(self, img):# TODO: 随机裁剪 patchw,h= img.sizetop= random.randint(0,h - self.patch_size)left= random.randint(0,w - self.patch_size)return img.crop((left,top,left + self.patch_size,top + self.patch_size))def __getitem__(self, idx):# TODO: 打开 HR 图像# TODO: 随机裁剪# TODO: 下采样生成 LR# TODO: 上采样回 HR 大小# TODO: 转 tensor 返回img=Image.open(self.files[idx]).convert("RGB")img= self._random_crop(img)lr_img=img.resize((img.width // self.scale, img.height // self.scale), Image.NEAREST)lr_up=lr_img.resize((img.width, img.height), Image.NEAREST)return self.to_tensor(lr_up), self.to_tensor(img)raise NotImplementedError# ============================
# 2. 模型模块
# ============================
class SRCNN(nn.Module):"""SRCNN 网络"""def __init__(self, in_channels=3):super().__init__()# TODO: 定义三层卷积self.conv1=nn.Conv2d(in_channels, 64, kernel_size=9, padding=4)self.conv2=nn.Conv2d(64, 32, kernel_size=1, padding=0)self.conv3=nn.Conv2d(32, in_channels, kernel_size=5, padding=2)self.relu=nn.ReLU()def forward(self, x):# TODO: 实现 forward# conv1 -> relu -> conv2 -> relu -> conv3 -> 输出 clampx=self.relu(self.conv1(x))x=self.relu(self.conv2(x))x=self.conv3(x)return torch.clamp(x, 0.0, 1.0)#确保输出的张量 x 中的所有值都在 0.0 到 1.0 之间#torch.clamp(input, min, max) 会将 input 中小于 min 的值设置为 min,将大于 max 的值设置为 max。raise NotImplementedError# ============================
# 3. 工具函数模块
# ============================
def calc_psnr(sr, hr, shave_border=0):sr = sr.transpose(0,2,3,1)[0] if sr.ndim==4 else sr.transpose(1,2,0)hr = hr.transpose(0,2,3,1)[0] if hr.ndim==4 else hr.transpose(1,2,0)if shave_border > 0:sr = sr[shave_border:-shave_border, shave_border:-shave_border, :]hr = hr[shave_border:-shave_border, shave_border:-shave_border, :]mse = np.mean((sr - hr) ** 2)if mse == 0:return 100return 20 * np.log10(1.0 / np.sqrt(mse))# ============================
# 4. 训练循环模块
# ============================
def train_one_epoch(model, loader, criterion, optimizer, device):# TODO: 实现训练一轮model.train()running_loss=0.0for lr, hr in tqdm(loader):lr, hr = lr.to(device), hr.to(device)optimizer.zero_grad()pred = model(lr)loss = criterion(pred, hr)loss.backward()optimizer.step()running_loss += loss.item()return running_loss / len(loader)raise NotImplementedErrordef validate(model, loader, device):# TODO: 实现验证并计算平均 PSNRmodel.eval()psnr_total=0.0with torch.no_grad():   for lr, hr in tqdm(loader):lr, hr = lr.to(device), hr.to(device)sr = model(lr)psnr_total += calc_psnr(sr.cpu().numpy(), hr.cpu().numpy())return psnr_total / len(loader) raise NotImplementedError# ============================
# 5. 主函数
# ============================
def main():data_dir = "./BSDS100/HR"  # TODO: 修改为你的 BSD100 数据集路径scale = 4batch_size = 8epochs = 100lr = 1e-4patch_size = 48device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# TODO: 创建 dataset 和 dataloaderdataset = SRDataset(data_dir, scale=scale, patch_size=patch_size)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)# TODO: 初始化 SRCNN 模型model = SRCNN().to(device)# TODO: 定义损失函数和优化器criterion = nn.MSELoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)# TODO: 训练循环 + 验证 + 保存最优模型best_psnr = 0.0for epoch in range(epochs):train_loss = train_one_epoch(model, dataloader, criterion, optimizer, device)val_psnr = validate(model, dataloader, device)print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val PSNR: {val_psnr:.2f} dB")if val_psnr > best_psnr:best_psnr = val_psnrtorch.save(model.state_dict(), "best_srcnn.pth")print("Saved Best Model")       if __name__ == "__main__":main()

进行了可视化看了看结果,效果不是很好

数据集只有100张,并且网络也比较简单

4.2.4 可视化结果

import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)model = SRCNN().to(device)
model.load_state_dict(torch.load("best_srcnn.pth"))
model.eval()from PIL import Image
import torchvision.transforms.functional as TFimg_path = "./BSDS100/HR/14037.png"
hr = Image.open(img_path).convert("RGB")# 下采样 + 上采样
scale = 2
lr = hr.resize((hr.width//scale, hr.height//scale), Image.BICUBIC)
lr_up = lr.resize((hr.width, hr.height), Image.BICUBIC)# 转 tensor
lr_tensor = TF.to_tensor(lr_up).unsqueeze(0).to(device)  # [1,C,H,W]
sr_tensor = model(lr_tensor)
sr_img = sr_tensor.squeeze(0).cpu()
sr_img = TF.to_pil_image(sr_img)plt.figure(figsize=(12,4))
plt.subplot(1,3,1)
plt.title("LR Up")
plt.imshow(lr_up)
plt.axis('off')plt.subplot(1,3,2)
plt.title("SRCNN Output")
plt.imshow(sr_img)
plt.axis('off')plt.subplot(1,3,3)
plt.title("HR Ground Truth")
plt.imshow(hr)
plt.axis('off')plt.show()

---

更新中

http://www.dtcms.com/a/618306.html

相关文章:

  • 高端品牌网站建设注意事项制作ppt的基本做法
  • 2025 年 Redis 面试天花板
  • component-富文本实现(WangEditor)
  • 烟台城乡住房建设厅网站网站alt标签
  • win11上使用Workbench备份mysql数据库
  • B站评论数据采集:基于Requests的智能爬虫实战
  • 信息学与容斥
  • 网易云音乐评论数据采集:基于Requests的智能爬虫实战
  • 网站空间登录网站建设模式有哪些内容
  • VSCode 中快捷键的使用:(大小写转换快捷键、自动补全函数注释快捷键、代码和注释自动缩进快捷键)
  • 使用 Python 语言 从 0 到 1 搭建完整 Web UI自动化测试学习系列 25--数据驱动--参数化处理 Excel 文件 2
  • SpringCloud微服务笔记
  • 广告公司网站官网安徽网站建设流程
  • 华为OD机试真题2025双机位A卷 --【压缩日志查询】(Python C++ JAVA JS GO)
  • 网站编辑怎么做内容分类手机网站 程序
  • 瑞安建设网站成都vr 网站开发
  • C++多线程【数据共享】之互斥锁
  • Java漏洞集合工具
  • JavaScript 正则表达式详解
  • 【CS创世SD NAND征文】高可靠性数控设备:技术方案与行业展望
  • 深入理解Go语言Slice的append操作:从内存分配到扩容机制
  • Linux---文件控制<fcntl.h> (file control, fcntl)
  • 网站放到服务器珠海市 网站建设
  • 农林科技公司网站模板seo研究中心官网
  • 东莞响应式网站哪家好架设网站开发环境
  • 类似淘宝网站建设有哪些模板wordpress文章图片全屏浏览
  • 技术演进中的开发沉思-194 JavaScript: Prototype 框架
  • Windows MongoDB 安装与配置指南
  • Kafka客户端整合
  • 购物网站建设方案手机建立网站的软件