【Super-Resolution】Implementing Super-Resolution with DDIM + SwinUnet
The full code and the trained 8x super-resolution model are available on GitHub.
GitHub: SuperResolution-DDIM-SwinUnet
Introduction
- An 8x super-resolution model was trained on the DIV2K dataset (800 2K images). Following SR3, the low-resolution image is concatenated with the noise input and fed to the model. However, instead of SR3's approach of conditioning directly on the noise level, the model is still conditioned on the denoising timestep t, and the number of DDPM steps is increased to 1000 (with only 100 steps the output contains noticeably more noise). A minimal sketch of this conditioning and training objective follows this list.
- Result images are in the result directory on GitHub. DDIM sampling is also included (another benefit of using t as the time condition); judging from the results, about 40 DDIM steps are on par with 1000 DDPM steps. Even 1 or 2 DDIM steps roughly reconstruct the image, though the quality is low.
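A minimal sketch of this conditioning and training objective (names are illustrative; it mirrors Scheduler.training_losses in scheduler.py below, where x is the LR image up-sampled to the target resolution and y0 is the HR target):

import torch
import torch.nn.functional as F

def sr3_style_loss(model, scheduler, x, y0, denoise_steps=1000):
    # noise the HR target: y_t = sqrt(alpha_bar_t) * y0 + sqrt(1 - alpha_bar_t) * eps
    t = torch.randint(0, denoise_steps, (y0.shape[0],), device=y0.device)
    eps = torch.randn_like(y0)
    y_t = scheduler.q_sample(y0, t, eps)
    # the model sees the up-sampled LR image concatenated with y_t (6 channels in total)
    # together with the timestep t, and is trained to predict the added noise
    eps_pred = model(torch.cat([x, y_t], dim=1), t)
    return F.mse_loss(eps_pred, eps)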
Limitations:
1. Possibly due to the use of SwinUnet, faint grid-like "small boxes" are always visible in the super-resolved images; also, the image width and height must be divisible by 256 (easy to work around: just resize).
2. Only a single 8x super-resolution model was trained (the factor is so large that the results show considerable distortion). Training lower-factor models instead, e.g. 2x and 4x, and chaining them to reach 8x might reduce the distortion; a sketch of this cascade idea follows below.
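A sketch of this cascade idea, reusing the components listed below. Note that only the 8x checkpoint exists in this repo: SwinUNet-SR2.pth and SwinUNet-SR4.pth are placeholder names for hypothetical 2x/4x checkpoints, and the input path is a placeholder as well.

import numpy as np
import torch
from PIL import Image
from SwinUnet import SwinUnet
from scheduler import Scheduler

def upscale(model_path, img, ratio, device="cpu", ddim_stride=25):
    # load one cascade stage and run DDIM sampling, following the same steps as run.py
    model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
                     depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24]).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    scheduler = Scheduler(model, 1000)
    # the condition is the input image resized up to the target resolution, scaled to [-1, 1]
    cond = img.convert("RGB").resize((img.size[0] * ratio, img.size[1] * ratio))
    arr = np.array(cond).transpose(2, 0, 1) / 255.
    arr = torch.from_numpy(2 * (arr - 0.5)).float().unsqueeze(0).to(device)
    y = scheduler.ddim(arr, device, sub_sequence_step=ddim_stride)[-1]
    return Image.fromarray(((y.transpose(1, 2, 0) + 1.) / 2 * 255.).astype(np.uint8))

lr = Image.open("lr_input.png")            # placeholder path; keep each side a multiple of 256
mid = upscale("SwinUNet-SR2.pth", lr, 2)   # hypothetical 2x stage
out = upscale("SwinUNet-SR4.pth", mid, 4)  # hypothetical 4x stage
out.save("cascade_x8.png")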
Code: (run.py, scheduler.py, SwinUnet.py, load_data.py, training.py)
"run.py"
import numpy as np
import torch
from SwinUnet import SwinUnet
from scheduler import Scheduler
from PIL import Image
import argparse
import datetime
import os
def main(args):
device = torch.device(args.device)
model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24]).to(device)
sr_ratio = args.sr_ratio
model.load_state_dict(torch.load(args.model_path, map_location=device))
model.eval()
scheduler = Scheduler(model, args.denoise_steps)
image_path = args.image_path
img = Image.open(image_path)
img_size = img.size
    assert img_size[0] >= 256 and img_size[1] >= 256, "the minimum image size is 256 x 256"
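    # floor the LR size to a multiple of 256, then scale by sr_ratio; the model requires H and W divisible by 256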
img_size = (
(img_size[0] // 256) * 256 * sr_ratio,
(img_size[1] // 256) * 256 * sr_ratio
)
img = img.resize(img_size)
img_arr = np.array(img)
if img_arr.shape[-1] == 4: img_arr = img_arr[..., :3]
img_arr = img_arr.transpose(2, 0, 1) / 255.
img_arr = 2 * (img_arr - 0.5)
img_arr = torch.from_numpy(img_arr).float().to(device)
img_arr = img_arr.unsqueeze(0)
if args.use_ddim:
y = scheduler.ddim(img_arr, device, sub_sequence_step=args.ddim_sub_sequence_steps)[-1]
else:
y = scheduler.ddpm(img_arr, device)[-1]
y = y.transpose(1, 2, 0)
y = (y + 1.) / 2
y *= 255.0
new_img = Image.fromarray(y.astype(np.uint8))
    os.makedirs(args.results_dir, exist_ok=True)
    new_img.save(os.path.join(args.results_dir, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".png"))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--image_path", type=str)
parser.add_argument("--sr_ratio", type=int, default=8)
parser.add_argument("--results_dir", type=str, default="./results")
parser.add_argument("--denoise_steps", type=int, default=1000)
parser.add_argument("--model_path", type=str, default="SwinUNet-SR8.pth")
parser.add_argument("--use_ddim", type=int, default=1)
parser.add_argument("--ddim_sub_sequence_steps", type=int, default=25)
args = parser.parse_args()
main(args)
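For reference, a typical invocation of run.py (the image path is a placeholder; the other flags keep the defaults above): python run.py --device cpu --image_path ./lr_input.png --use_ddim 1 --ddim_sub_sequence_steps 25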
"scheduler.py"
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
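# gather values from a 1-D numpy schedule array at the given timesteps and broadcast them to broadcast_shape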
def extract_into_tensor(arr, timesteps, broadcast_shape):
res = torch.from_numpy(arr).to(torch.float32).to(device=timesteps.device)[timesteps]
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + torch.zeros(broadcast_shape, device=timesteps.device)
class Scheduler:
def __init__(self, denoise_model, denoise_steps, beta_start=1e-4, beta_end=0.005):
self.model = denoise_model
betas = np.array(
np.linspace(beta_start, beta_end, denoise_steps),
dtype=np.float64
)
self.denoise_steps = denoise_steps
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
alphas = 1.0 - betas
self.sqrt_alphas = np.sqrt(alphas)
self.one_minus_alphas = 1.0 - alphas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
def q_sample(self, y0, t, noise):
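        # forward diffusion: y_t = sqrt(alpha_bar_t) * y0 + sqrt(1 - alpha_bar_t) * noise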
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, y0.shape) * y0
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y0.shape) * noise
)
def training_losses(self, x, y, t):
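        # x: up-sampled LR condition, y: HR target; the model is trained to predict the added noise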
noise = torch.randn_like(y)
y_t = self.q_sample(y, t, noise)
predict_noise = self.model(torch.cat([x, y_t], dim=1), t)
return F.mse_loss(predict_noise, noise)
@torch.no_grad()
def ddpm(self, x, device):
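        # x is the up-sampled LR condition; start y from pure Gaussian noise and denoise step by step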
y = torch.randn(*x.shape, device=device)
for t in tqdm(reversed(range(0, self.denoise_steps)), total=self.denoise_steps):
t = torch.tensor([t], device=device).repeat(x.shape[0])
t_mask = (t != 0).float().view(-1, *([1] * (len(y.shape) - 1)))
eps = self.model(torch.cat([x, y], dim=1), t)
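            # DDPM mean update: remove the predicted noise contribution and rescale by 1 / sqrt(alpha_t)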
y = y - (
extract_into_tensor(self.one_minus_alphas, t, y.shape) * eps
/ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y.shape)
)
y = y / extract_into_tensor(self.sqrt_alphas, t, y.shape)
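            # posterior standard deviation; t_mask disables the added noise at the final step (t == 0)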
sigma = torch.sqrt(
extract_into_tensor(self.one_minus_alphas, t, y.shape)
* (1.0 - extract_into_tensor(self.alphas_cumprod_prev, t, y.shape))
/ (1.0 - extract_into_tensor(self.alphas_cumprod, t, y.shape))
)
y = y + sigma * torch.randn_like(y) * t_mask
y = y.clip(-1, 1)
return y.detach().cpu().numpy()
@torch.no_grad()
def ddim(self, x, device, eta=0.0, sub_sequence_step=25):
        # initialize y with Gaussian noise
y = torch.randn(*x.shape, device=device)
        # build the skip-sampling time sequence: from denoise_steps-1 down to 0, one entry every sub_sequence_step
t_seq = list(range(self.denoise_steps - 1, -1, -sub_sequence_step))
for i in tqdm(range(len(t_seq)), total=len(t_seq)):
            # current timestep t and the next sampled timestep s (s is set to 0 on the last step)
t = t_seq[i]
s = 0 if i == len(t_seq) - 1 else t_seq[i + 1]
            # build timestep tensors with one entry per batch element
t_tensor = torch.tensor([t], device=device).repeat(x.shape[0])
s_tensor = torch.tensor([s], device=device).repeat(x.shape[0])
            # predict the noise with the model
eps = self.model(torch.cat([x, y], dim=1), t_tensor)
            # extract the cumulative alpha products for the current and next timesteps
alpha_bar_t = extract_into_tensor(self.alphas_cumprod, t_tensor, y.shape)
alpha_bar_s = extract_into_tensor(self.alphas_cumprod, s_tensor, y.shape)
            # DDIM estimate of the original sample x0
y0_pred = (y - torch.sqrt(1 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)
            # sigma controls the amount of stochasticity
sigma = 0.0
if eta > 0.0 and s > 0:
sigma = eta * torch.sqrt(
(1 - alpha_bar_s) / (1 - alpha_bar_t) *
(1 - alpha_bar_t / alpha_bar_s)
)
            # move to the next timestep using the predicted x0 and the current noise direction
y = torch.sqrt(alpha_bar_s) * y0_pred + torch.sqrt(1 - alpha_bar_s - sigma ** 2) * eps
            # if eta > 0, add noise after the update (none is added on the last step)
if eta > 0.0 and s > 0:
y = y + sigma * torch.randn_like(y)
y = y.clip(-1, 1)
return y.detach().cpu().numpy()
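As a quick check of the sampling-step count mentioned in the introduction, the skip sequence built inside ddim() with the run.py defaults has 40 entries:

denoise_steps, stride = 1000, 25
t_seq = list(range(denoise_steps - 1, -1, -stride))
print(len(t_seq))  # 40 DDIM sampling steps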
"SwinUnet.py"
import numpy as np
import torch as th
from torch import nn, einsum
import math
from einops import rearrange
#############################################
# Sinusoidal timestep embedding
#############################################
class SinusoidalTimeEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
device = t.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = th.exp(th.arange(half_dim, device=device) * -emb)
emb = t.float().unsqueeze(1) * emb.unsqueeze(0)
emb = th.cat([emb.sin(), emb.cos()], dim=-1)
return emb # [B, dim]
#############################################
# Downsampling module: Patch Merging
#############################################
class PatchMerging(nn.Module):
def __init__(self, in_channels, out_channels, downscaling_factor):
super().__init__()
self.downscaling_factor = downscaling_factor
self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)
def forward(self, x):
b, h, w, c = x.shape
new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
x = x.permute(0, 3, 1, 2)
x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
x = self.linear(x)
return x
#############################################
# Upsampling module: Patch Expanding (linear projection followed by pixel rearrangement)
#############################################
class PatchExpanding(nn.Module):
def __init__(self, in_channels, out_channels, upscaling_factor):
super().__init__()
self.upscaling_factor = upscaling_factor
self.out_channels = out_channels
self.linear = nn.Linear(in_channels, out_channels * self.upscaling_factor ** 2)
def forward(self, x):
B, H, W, _ = x.shape
x = self.linear(x)
x = x.view(B, H, W, self.upscaling_factor, self.upscaling_factor, self.out_channels)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
x = x.view(B, H * self.upscaling_factor, W * self.upscaling_factor, self.out_channels)
return x
#############################################
# Window-based self-attention
#############################################
class WindowAttention(nn.Module):
def __init__(self, dim, nheads, window_size, shifted, relative_pos_embedding):
super().__init__()
head_dim = dim // nheads
self.nheads = nheads
self.scale = head_dim ** -0.5
self.window_size = window_size
self.relative_pos_embedding = relative_pos_embedding
self.shifted = shifted
if self.shifted:
displacement = window_size // 2
self.cyclic_shift = CyclicShift(-displacement)
self.cyclic_back_shift = CyclicShift(displacement)
self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
upper_lower=True, left_right=False), requires_grad=False)
self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
upper_lower=False, left_right=True), requires_grad=False)
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
if self.relative_pos_embedding:
self.relative_indices = get_relative_distances(window_size) + window_size - 1
self.pos_embedding = nn.Parameter(th.randn(2 * window_size - 1, 2 * window_size - 1))
else:
self.pos_embedding = nn.Parameter(th.randn(window_size ** 2, window_size ** 2))
self.to_out = nn.Linear(dim, dim)
def forward(self, x):
if self.shifted:
x = self.cyclic_shift(x)
b, n_h, n_w, _, h = *x.shape, self.nheads
qkv = self.to_qkv(x).chunk(3, dim=-1)
nw_h = n_h // self.window_size
nw_w = n_w // self.window_size
q, k, v = map(
lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
h=h, w_h=self.window_size, w_w=self.window_size), qkv)
dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale
if self.relative_pos_embedding:
dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
else:
dots += self.pos_embedding
if self.shifted:
dots[:, :, -nw_w:] += self.upper_lower_mask
dots[:, :, nw_w - 1::nw_w] += self.left_right_mask
attn = dots.softmax(dim=-1)
out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
out = self.to_out(out)
if self.shifted:
out = self.cyclic_back_shift(out)
return out
class CyclicShift(nn.Module):
def __init__(self, displacement):
super().__init__()
self.displacement = displacement
def forward(self, x):
return th.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))
def create_mask(window_size, displacement, upper_lower, left_right):
mask = th.zeros(window_size ** 2, window_size ** 2)
if upper_lower:
mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')
if left_right:
mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
mask[:, -displacement:, :, :-displacement] = float('-inf')
mask[:, :-displacement, :, -displacement:] = float('-inf')
mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')
return mask
def get_relative_distances(window_size):
indices = th.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
distances = indices[None, :, :] - indices[:, None, :]
return distances
#############################################
# SwinTransformerBlock: uses the same adaptive layer normalization (adaLN) conditioning as DiT
#############################################
def modulate(x, shift, scale):
return x * (1 + scale[:, None, None, :]) + shift[:, None, None, :]
class SwinTransformerAdaLnBlock(nn.Module):
def __init__(self, dim, mlp_ratio, nheads, window_size, shifted, relative_pos_embedding):
super().__init__()
self.dim = dim
self.attn = WindowAttention(dim, nheads, window_size, shifted, relative_pos_embedding)
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_ratio * dim),
nn.GELU(),
nn.Linear(mlp_ratio * dim, dim)
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim)
)
def forward(self, x, t):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(t).chunk(6, dim=1)
x = x + gate_msa[:, None, None, :] * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp[:, None, None, :] * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
#############################################
# SwinUnet building blocks
#############################################
def block_forward(block, x, t):
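    # each block consumes only the first dim entries of the shared time embedding (time_embed outputs 8 * dim features)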
for b in block:
x = b(x, t[:, :b.dim])
return x
class SwinUnetEncoder(nn.Module):
def __init__(self, channels, dim, patch_size, depth, mlp_ratio, nheads, window_size, relative_pos_embedding):
super().__init__()
self.patch_embed = PatchMerging(channels, dim, patch_size)
self.block0 = nn.ModuleList([
SwinTransformerAdaLnBlock(
dim=dim * 1,
mlp_ratio=mlp_ratio,
nheads=nheads[0],
window_size=window_size,
shifted=True if i // 2 == 0 else False,
relative_pos_embedding=relative_pos_embedding
) for i in range(1, depth[0] + 1)
])
self.patch_merge0 = PatchMerging(dim * 1, dim * 2, downscaling_factor=2)
self.block1 = nn.ModuleList([
SwinTransformerAdaLnBlock(
dim=dim * 2,
mlp_ratio=mlp_ratio,
nheads=nheads[1],
window_size=window_size,
shifted=True if i // 2 == 0 else False,
relative_pos_embedding=relative_pos_embedding
) for i in range(1, depth[1] + 1)
])
self.patch_merge1 = PatchMerging(dim * 2, dim * 4, downscaling_factor=2)
self.block2 = nn.ModuleList([
SwinTransformerAdaLnBlock(
dim=dim * 4,
mlp_ratio=mlp_ratio,
nheads=nheads[2],
window_size=window_size,
shifted=True if i // 2 == 0 else False,
relative_pos_embedding=relative_pos_embedding
) for i in range(1, depth[2] + 1)
])
self.patch_merge2 = PatchMerging(dim * 4, dim * 8, downscaling_factor=2)
def forward(self, x, t):
x = x.permute(0, 2, 3, 1)
skip_connections = []
x = self.patch_embed(x)
x = block_forward(self.block0, x, t)
skip_connections.append(x)
x = self.patch_merge0(x)
x = block_forward(self.block1, x, t)
skip_connections.append(x)
x = self.patch_merge1(x)
x = block_forward(self.block2, x, t)
skip_connections.append(x)
x = self.patch_merge2(x)
return x, skip_connections
class SwinUnetDecoder(nn.Module):
def __init__(self, channels, dim, patch_size, depth, mlp_ratio, nheads, window_size, relative_pos_embedding):
super().__init__()
self.patch_expand0 = PatchExpanding(dim * 8, dim * 4, upscaling_factor=2)
self.block0 = nn.ModuleList([
SwinTransformerAdaLnBlock(
dim=dim * 4,
mlp_ratio=mlp_ratio,
nheads=nheads[2],
window_size=window_size,
shifted=True if i // 2 == 0 else False,
relative_pos_embedding=relative_pos_embedding
) for i in range(1, depth[2] + 1)
])
self.skip0 = nn.Linear(dim * 4 * 2, dim * 4, bias=False)
self.patch_expand1 = PatchExpanding(dim * 4, dim * 2, upscaling_factor=2)
self.block1 = nn.ModuleList([
SwinTransformerAdaLnBlock(
dim=dim * 2,
mlp_ratio=mlp_ratio,
nheads=nheads[1],
window_size=window_size,
shifted=True if i // 2 == 0 else False,
relative_pos_embedding=relative_pos_embedding
) for i in range(1, depth[1] + 1)
])
self.skip1 = nn.Linear(dim * 2 * 2, dim * 2, bias=False)
self.patch_expand2 = PatchExpanding(dim * 2, dim * 1, upscaling_factor=2)
self.block2 = nn.ModuleList([
SwinTransformerAdaLnBlock(
dim=dim * 1,
mlp_ratio=mlp_ratio,
nheads=nheads[0],
window_size=window_size,
shifted=True if i // 2 == 0 else False,
relative_pos_embedding=relative_pos_embedding
) for i in range(1, depth[0] + 1)
])
self.skip2 = nn.Linear(dim * 1 * 2, dim * 1, bias=False)
self.patch_to_image = PatchExpanding(dim, channels, patch_size)
def forward(self, x, skip_connect, t):
x = self.patch_expand0(x)
x = th.cat((x, skip_connect[2]), dim=-1)
x = self.skip0(x)
x = block_forward(self.block0, x, t)
x = self.patch_expand1(x)
x = th.cat((x, skip_connect[1]), dim=-1)
x = self.skip1(x)
x = block_forward(self.block1, x, t)
x = self.patch_expand2(x)
x = th.cat((x, skip_connect[0]), dim=-1)
x = self.skip2(x)
x = block_forward(self.block2, x, t)
x = self.patch_to_image(x)
return x.permute(0, 3, 1, 2)
#############################################
# SwinUnet: the condition is handled by direct channel-wise concatenation
#############################################
class SwinUnet(nn.Module):
def __init__(self, channels, dim, mlp_ratio, patch_size, window_size, depth, nheads,
relative_pos_embedding=True, use_condition=True):
super().__init__()
self.time_embed = SinusoidalTimeEmb(8 * dim)
self.encoder = SwinUnetEncoder(channels=2 * channels if use_condition else channels, dim=dim, patch_size=patch_size,
depth=depth[:3], mlp_ratio=mlp_ratio, nheads=nheads[:3],
window_size=window_size, relative_pos_embedding=relative_pos_embedding)
self.bottleneck = nn.ModuleList([
SwinTransformerAdaLnBlock(
dim=8 * dim,
mlp_ratio=mlp_ratio,
nheads=nheads[-1],
window_size=window_size,
shifted=True if i // 2 == 0 else False,
relative_pos_embedding=relative_pos_embedding
) for i in range(1, depth[-1] + 1)
])
self.decoder = SwinUnetDecoder(channels=channels, dim=dim, patch_size=patch_size,
depth=depth[:3], mlp_ratio=mlp_ratio, nheads=nheads[:3],
window_size=window_size, relative_pos_embedding=relative_pos_embedding)
def forward(self, x, t):
t = self.time_embed(t)
x, skip_connection = self.encoder(x, t)
x = block_forward(self.bottleneck, x, t)
return self.decoder(x, skip_connection, t)
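A small shape sanity check for the configuration used in run.py and training.py; with use_condition=True the encoder expects 6 input channels, and the spatial size must be a multiple of 256:

import torch
from SwinUnet import SwinUnet

model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
                 depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24])
x = torch.randn(1, 6, 256, 256)   # up-sampled LR condition and noisy HR image, concatenated
t = torch.randint(0, 1000, (1,))
with torch.no_grad():
    out = model(x, t)
print(out.shape)                  # torch.Size([1, 3, 256, 256])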
"load_data.py"
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
def is_image_file(file_path):
    # common image file extensions
image_extensions = {'.jpg', '.jpeg', '.png'}
    # get the file extension and check whether it belongs to the image extension set
file_extension = os.path.splitext(file_path)[1].lower()
return file_extension in image_extensions
class CustomDataset(Dataset):
def __init__(self, path, img_size=None, sr_ratio=8):
super().__init__()
files = os.listdir(path)
self.img_size = img_size
self.files = []
for file in files:
self.files.append(os.path.join(path, file))
self.ratio = sr_ratio
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
        hr_img = Image.open(self.files[idx]).convert("RGB")
if self.img_size is not None:
hr_img = hr_img.resize(self.img_size)
hr_size = hr_img.size
else:
hr_size = hr_img.size
hr_size = ((hr_size[0] // 256 + 1) * 256, (hr_size[1] // 256 + 1) * 256)
hr_img = hr_img.resize(hr_size)
hr_arr = np.array(hr_img).transpose(2, 0, 1) / 255.
lr_img = hr_img.resize((hr_size[0] // self.ratio, hr_size[1] // self.ratio))
lr_img = lr_img.resize(hr_size)
lr_arr = np.array(lr_img).transpose(2, 0, 1) / 255.
lr_arr = 2 * (lr_arr - 0.5)
hr_arr = 2 * (hr_arr - 0.5)
return lr_arr, hr_arr
class ImageNetDataset(Dataset):
def __init__(self, path, img_size=(256, 256), sr_ratio=8):
super().__init__()
self.img_size = img_size
class_dirs = os.listdir(path)
self.files = []
for class_dir in class_dirs:
files = os.listdir(os.path.join(path, class_dir))
for file in files:
if is_image_file(os.path.join(path, class_dir, file)):
self.files.append(os.path.join(path, class_dir, file))
self.ratio = sr_ratio
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
hr_img = Image.open(self.files[idx])
hr_img = hr_img.convert("RGB")
if self.img_size is not None:
hr_img = hr_img.resize(self.img_size)
hr_size = hr_img.size
hr_arr = np.array(hr_img).transpose(2, 0, 1) / 255.
lr_img = hr_img.resize((hr_size[0] // self.ratio, hr_size[1] // self.ratio))
lr_img = lr_img.resize(hr_size)
lr_arr = np.array(lr_img).transpose(2, 0, 1) / 255.
lr_arr = 2 * (lr_arr - 0.5)
hr_arr = 2 * (hr_arr - 0.5)
return lr_arr, hr_arr
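A quick look at one training pair from CustomDataset, using the same settings as training.py (the DIV2K directory is assumed to exist locally):

from load_data import CustomDataset

ds = CustomDataset("./DIV2K_train_HR", img_size=(512, 512), sr_ratio=8)
lr, hr = ds[0]
print(lr.shape, hr.shape)  # (3, 512, 512) each, values scaled to [-1, 1]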
"training.py"
import torch
import numpy as np
from torch import optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from load_data import CustomDataset
from scheduler import Scheduler
from SwinUnet import SwinUnet
if __name__ == '__main__':
device = torch.device("mps")
batch_size = 16
lr = 1e-4
epochs = 200
denoise_steps = 1000
sr_ratio = 8
train_dataset = CustomDataset(
"./DIV2K_train_HR", img_size=(512, 512), sr_ratio=sr_ratio,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24]).to(device)
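    # resume from an existing checkpoint; comment this line out when training from scratch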
model.load_state_dict(torch.load("SwinUNet-SR8.pth", map_location=device))
optimizer = optim.AdamW(model.parameters(), lr=lr)
scheduler = Scheduler(model, denoise_steps)
model.train()
for epoch in range(epochs):
print('*' * 40)
train_loss = []
for i, data in tqdm(enumerate(train_loader, 1), total=len(train_loader)):
x, y = data
            x = x.to(torch.float32).to(device)
            y = y.to(torch.float32).to(device)
t = torch.randint(low=0, high=denoise_steps, size=(x.shape[0],)).to(device)
training_loss = scheduler.training_losses(x, y, t)
optimizer.zero_grad()
training_loss.backward()
optimizer.step()
train_loss.append(training_loss.item())
torch.save(model.state_dict(), f"unet-sr{sr_ratio}.pth")
print('Finish {} Loss: {:.6f}'.format(epoch + 1, np.mean(train_loss)))