超分辨率重建(Super-Resolution, SR)
1. 超分辨率的任务定义
-
输入:低分辨率图像 LR
-
输出:高分辨率图像 HR
-
本质:让模糊图变清晰,恢复纹理和细节
超分辨率(SR)就是把低分辨率图像恢复成高分辨率图像,让图像更清晰,同时尽可能恢复真实、合理的纹理细节。
2. SR 的评价指标
-
PSNR:像素误差低(值高)
-
SSIM:结构相似度高(越接近 1 越好)
2.1 PSNR

3. 为什么插值方法(上采样)无法做到高质量超分辨率
首先,插值方法(如双线性、双三次)只是根据邻近像素做数学推断,它不会“理解”图像内容,因此只能生成平滑的过渡,无法真正恢复丢失的细节。
其次,低分辨率图像中本来就缺失高频纹理(如纹理、细线、毛发等),插值无法凭空生成,只能把现有像素“拉伸”,导致图像变得模糊。
最后,不同区域的纹理结构非常复杂,插值方法无法区分边缘、平坦区、纹理区,因此容易产生伪影或边缘模糊,而深度学习才能通过大量样本学习规律化的细节恢复能力。
https://cloud.tencent.com/developer/article/2095690
https://cloud.tencent.com/developer/article/2095690
4. SRCNN
4.1 概念
-
SRCNN 是 2014 年提出的最简单的深度学习超分模型,思想直观:
先用传统插值把 LR 放大到目标尺寸(bicubic)→ 再用 3 层卷积网络修复细节。 -
网络结构(最常见配置):
-
Conv1: 大核(9×9),64 个通道 —— 特征提取(抓大的局部结构)
-
Conv2: 1×1,32 个通道 —— 非线性映射(把特征从低维映射到高维)
-
Conv3: 5×5,输出通道 = 图像通道(1 或 3) —— 重建图像
-
-
激活:前两层后面接 ReLU,最后一层不接激活(直接回归像素值)
-
训练目标:最简单常用的是 L2(MSE)或 L1 损失,衡量输出与 HR 的像素差
-
优点:结构极其简单、便于理解和复现;
-
缺点:需要先插值放大(计算浪费),恢复能力相对现代方法较弱
4.2 实践
4.2.1 带 TODO 的 PyTorch 代码框架
# srcnn_skeleton.pyimport os
import random
from glob import glob
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm# ============================
# 1. 数据集模块
# ============================
class SRDataset(Dataset):"""HR -> LR -> 上采样后的网络输入"""def __init__(self, root, scale=2, patch_size=48):self.root = rootself.files = glob(os.path.join(root, "*.png")) + glob(os.path.join(root, "*.jpg"))self.scale = scaleself.patch_size = patch_sizeself.to_tensor = transforms.ToTensor()def __len__(self):return len(self.files)def _random_crop(self, img):# TODO: 随机裁剪 patchraise NotImplementedErrordef __getitem__(self, idx):# TODO: 打开 HR 图像# TODO: 随机裁剪# TODO: 下采样生成 LR# TODO: 上采样回 HR 大小# TODO: 转 tensor 返回raise NotImplementedError# ============================
# 2. 模型模块
# ============================
class SRCNN(nn.Module):"""SRCNN 网络"""def __init__(self, in_channels=3):super().__init__()# TODO: 定义三层卷积raise NotImplementedErrordef forward(self, x):# TODO: 实现 forward# conv1 -> relu -> conv2 -> relu -> conv3 -> 输出 clampraise NotImplementedError# ============================
# 3. 工具函数模块
# ============================
def calc_psnr(sr, hr, shave_border=0):# TODO: 实现 PSNR 计算raise NotImplementedError# ============================
# 4. 训练循环模块
# ============================
def train_one_epoch(model, loader, criterion, optimizer, device):# TODO: 实现训练一轮raise NotImplementedErrordef validate(model, loader, device):# TODO: 实现验证并计算平均 PSNRraise NotImplementedError# ============================
# 5. 主函数
# ============================
def main():data_dir = "./BSD100" # TODO: 修改为你的 BSD100 数据集路径scale = 2batch_size = 8epochs = 30lr = 1e-4patch_size = 48device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# TODO: 创建 dataset 和 dataloaderraise NotImplementedError# TODO: 初始化 SRCNN 模型raise NotImplementedError# TODO: 定义损失函数和优化器raise NotImplementedError# TODO: 训练循环 + 验证 + 保存最优模型raise NotImplementedErrorif __name__ == "__main__":main()
4.2.2 数据集说明
使用的是BSD100数据集
只有100张,主要是学习SRCNN,于是选择用比较小的数据https://aistudio.baidu.com/datasetdetail/99299?login_type=weixin
https://aistudio.baidu.com/datasetdetail/99299?login_type=weixin
4.2.3 实现
# srcnn_train.pyimport os
import random
from glob import glob
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np# ============================
# 1. 数据集模块
# ============================
class SRDataset(Dataset):"""HR -> LR -> 上采样后的网络输入"""def __init__(self, root, scale=2, patch_size=48):self.root = rootself.files = glob(os.path.join(root, "*.png")) + glob(os.path.join(root, "*.jpg"))self.scale = scaleself.patch_size = patch_sizeself.to_tensor = transforms.ToTensor()def __len__(self):return len(self.files)def _random_crop(self, img):# TODO: 随机裁剪 patchw,h= img.sizetop= random.randint(0,h - self.patch_size)left= random.randint(0,w - self.patch_size)return img.crop((left,top,left + self.patch_size,top + self.patch_size))def __getitem__(self, idx):# TODO: 打开 HR 图像# TODO: 随机裁剪# TODO: 下采样生成 LR# TODO: 上采样回 HR 大小# TODO: 转 tensor 返回img=Image.open(self.files[idx]).convert("RGB")img= self._random_crop(img)lr_img=img.resize((img.width // self.scale, img.height // self.scale), Image.NEAREST)lr_up=lr_img.resize((img.width, img.height), Image.NEAREST)return self.to_tensor(lr_up), self.to_tensor(img)raise NotImplementedError# ============================
# 2. 模型模块
# ============================
class SRCNN(nn.Module):"""SRCNN 网络"""def __init__(self, in_channels=3):super().__init__()# TODO: 定义三层卷积self.conv1=nn.Conv2d(in_channels, 64, kernel_size=9, padding=4)self.conv2=nn.Conv2d(64, 32, kernel_size=1, padding=0)self.conv3=nn.Conv2d(32, in_channels, kernel_size=5, padding=2)self.relu=nn.ReLU()def forward(self, x):# TODO: 实现 forward# conv1 -> relu -> conv2 -> relu -> conv3 -> 输出 clampx=self.relu(self.conv1(x))x=self.relu(self.conv2(x))x=self.conv3(x)return torch.clamp(x, 0.0, 1.0)#确保输出的张量 x 中的所有值都在 0.0 到 1.0 之间#torch.clamp(input, min, max) 会将 input 中小于 min 的值设置为 min,将大于 max 的值设置为 max。raise NotImplementedError# ============================
# 3. 工具函数模块
# ============================
def calc_psnr(sr, hr, shave_border=0):sr = sr.transpose(0,2,3,1)[0] if sr.ndim==4 else sr.transpose(1,2,0)hr = hr.transpose(0,2,3,1)[0] if hr.ndim==4 else hr.transpose(1,2,0)if shave_border > 0:sr = sr[shave_border:-shave_border, shave_border:-shave_border, :]hr = hr[shave_border:-shave_border, shave_border:-shave_border, :]mse = np.mean((sr - hr) ** 2)if mse == 0:return 100return 20 * np.log10(1.0 / np.sqrt(mse))# ============================
# 4. 训练循环模块
# ============================
def train_one_epoch(model, loader, criterion, optimizer, device):# TODO: 实现训练一轮model.train()running_loss=0.0for lr, hr in tqdm(loader):lr, hr = lr.to(device), hr.to(device)optimizer.zero_grad()pred = model(lr)loss = criterion(pred, hr)loss.backward()optimizer.step()running_loss += loss.item()return running_loss / len(loader)raise NotImplementedErrordef validate(model, loader, device):# TODO: 实现验证并计算平均 PSNRmodel.eval()psnr_total=0.0with torch.no_grad(): for lr, hr in tqdm(loader):lr, hr = lr.to(device), hr.to(device)sr = model(lr)psnr_total += calc_psnr(sr.cpu().numpy(), hr.cpu().numpy())return psnr_total / len(loader) raise NotImplementedError# ============================
# 5. 主函数
# ============================
def main():data_dir = "./BSDS100/HR" # TODO: 修改为你的 BSD100 数据集路径scale = 4batch_size = 8epochs = 100lr = 1e-4patch_size = 48device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# TODO: 创建 dataset 和 dataloaderdataset = SRDataset(data_dir, scale=scale, patch_size=patch_size)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)# TODO: 初始化 SRCNN 模型model = SRCNN().to(device)# TODO: 定义损失函数和优化器criterion = nn.MSELoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)# TODO: 训练循环 + 验证 + 保存最优模型best_psnr = 0.0for epoch in range(epochs):train_loss = train_one_epoch(model, dataloader, criterion, optimizer, device)val_psnr = validate(model, dataloader, device)print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val PSNR: {val_psnr:.2f} dB")if val_psnr > best_psnr:best_psnr = val_psnrtorch.save(model.state_dict(), "best_srcnn.pth")print("Saved Best Model") if __name__ == "__main__":main()
进行了可视化看了看结果,效果不是很好

数据集只有100张,并且网络也比较简单
4.2.4 可视化结果
import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)model = SRCNN().to(device)
model.load_state_dict(torch.load("best_srcnn.pth"))
model.eval()from PIL import Image
import torchvision.transforms.functional as TFimg_path = "./BSDS100/HR/14037.png"
hr = Image.open(img_path).convert("RGB")# 下采样 + 上采样
scale = 2
lr = hr.resize((hr.width//scale, hr.height//scale), Image.BICUBIC)
lr_up = lr.resize((hr.width, hr.height), Image.BICUBIC)# 转 tensor
lr_tensor = TF.to_tensor(lr_up).unsqueeze(0).to(device) # [1,C,H,W]
sr_tensor = model(lr_tensor)
sr_img = sr_tensor.squeeze(0).cpu()
sr_img = TF.to_pil_image(sr_img)plt.figure(figsize=(12,4))
plt.subplot(1,3,1)
plt.title("LR Up")
plt.imshow(lr_up)
plt.axis('off')plt.subplot(1,3,2)
plt.title("SRCNN Output")
plt.imshow(sr_img)
plt.axis('off')plt.subplot(1,3,3)
plt.title("HR Ground Truth")
plt.imshow(hr)
plt.axis('off')plt.show()
---
更新中
