当前位置: 首页 > news >正文

【超分辨率】基于DDIM+SwinUnet实现超分辨率

详细代码及训练得到的8倍超分辨率模型已放在GitHub

Github: SuperResolution-DDIM-SwinUnet

简介

  • 在DIV2K数据集(800张2K图像)上训练了一个8倍超分辨率模型,采用了和SR3一样的条件输入方式:将低分辨率图像和噪声拼接后输入模型。不过没有采用SR3直接输入噪声强度的做法,而是继续沿用输入去噪步骤t的方法,并把DDPM的步数增加到1000(如果仅是100步的话,输出结果的噪点会比较多)。

  • 效果图放在了Github的result目录里,引入了DDIM采样(这也是使用t作为时间条件的好处),从结果看DDIM仅需采样40步效果就和DDPM采样1000步相当了。而DDIM采样1步或2步也能大体还原,不过质量不高。

不足:

1.可能是使用SwinUnet的关系,超分辨率后的图像总是能隐约看到“小框框”;而且图像大小必须能被256整除(这个其实好解决,resize即可)。
2.只做了一个8倍超分辨率的模型(倍数太大,从效果来看失真率很高),可以考虑做倍率较低的比如2倍和4倍,进行拼接从而实现8倍的效果,可能失真率会好一点。

代码:(run.py、scheduler.py、SwinUnet.py、load_data.py、training.py)
"run.py"
import numpy as np
import torch

from SwinUnet import SwinUnet
from scheduler import Scheduler
from PIL import Image

import argparse
import datetime
import os

def main(args):
    """Super-resolve a single image with a trained SwinUnet diffusion model.

    The image is loaded as RGB, each side is rounded down to a multiple of
    256 and multiplied by ``args.sr_ratio``, the enlarged image is used as the
    condition for DDIM or DDPM sampling, and the result is written as a
    timestamped PNG into ``args.results_dir``.

    Args:
        args: Parsed command-line namespace. Reads ``device``, ``model_path``,
            ``sr_ratio``, ``denoise_steps``, ``image_path``, ``use_ddim``,
            ``ddim_sub_sequence_steps`` and ``results_dir``.

    Raises:
        ValueError: If either side of the input image is shorter than 256 px.
    """
    device = torch.device(args.device)

    model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
                     depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24]).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    scheduler = Scheduler(model, args.denoise_steps)

    # Convert to RGB up front so grayscale, palette and RGBA inputs all become
    # an (H, W, 3) array; the context manager releases the file handle.
    with Image.open(args.image_path) as src:
        img = src.convert("RGB")

    width, height = img.size
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if width < 256 or height < 256:
        raise ValueError("图片的最小尺寸为256")

    # Round each side down to a multiple of 256, then enlarge by the SR ratio.
    target_size = (
        (width // 256) * 256 * args.sr_ratio,
        (height // 256) * 256 * args.sr_ratio,
    )
    img = img.resize(target_size)

    # HWC uint8 -> CHW float in [-1, 1] with a leading batch axis.
    img_arr = np.array(img).transpose(2, 0, 1) / 255.
    img_arr = 2 * (img_arr - 0.5)
    img_arr = torch.from_numpy(img_arr).float().to(device).unsqueeze(0)

    if args.use_ddim:
        y = scheduler.ddim(img_arr, device, sub_sequence_step=args.ddim_sub_sequence_steps)[-1]
    else:
        y = scheduler.ddpm(img_arr, device)[-1]

    # CHW in [-1, 1] -> HWC in [0, 255]; round to nearest instead of truncating.
    y = y.transpose(1, 2, 0)
    y = (y + 1.) / 2 * 255.0
    y = np.clip(np.rint(y), 0, 255).astype(np.uint8)

    # Create the output directory on demand and use a timestamp without ':' or
    # spaces so the file name is valid on every common filesystem.
    os.makedirs(args.results_dir, exist_ok=True)
    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S-%f")
    Image.fromarray(y).save(os.path.join(args.results_dir, stamp + ".png"))


if __name__ == '__main__':
    # Command-line entry point: parse the inference options and run `main`.
    parser = argparse.ArgumentParser(
        description="Super-resolve one image with the DDIM/DDPM SwinUnet model."
    )
    parser.add_argument("--device", type=str, default="cpu")
    # Required: without it the script would only fail later inside Image.open(None).
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--sr_ratio", type=int, default=8)
    parser.add_argument("--results_dir", type=str, default="./results")
    parser.add_argument("--denoise_steps", type=int, default=1000)
    parser.add_argument("--model_path", type=str, default="SwinUNet-SR8.pth")
    # Non-zero selects DDIM sampling, zero selects full DDPM sampling.
    parser.add_argument("--use_ddim", type=int, default=1)
    parser.add_argument("--ddim_sub_sequence_steps", type=int, default=25)
    args = parser.parse_args()
    main(args)
"scheduler.py"
import numpy as np

import torch

import torch.nn.functional as F

from tqdm import tqdm

def extract_into_tensor(arr, timesteps, broadcast_shape):
    """Look up per-timestep schedule values and broadcast them to a target shape.

    Args:
        arr: 1-D numpy array of schedule coefficients.
        timesteps: Integer index tensor used to index ``arr``.
        broadcast_shape: Shape the gathered values are expanded to.

    Returns:
        A float32 tensor of shape ``broadcast_shape`` on ``timesteps.device``.
    """
    target_device = timesteps.device
    table = torch.from_numpy(arr).to(torch.float32).to(device=target_device)
    picked = table[timesteps]

    # Append singleton axes until the rank matches the requested shape.
    missing_dims = len(broadcast_shape) - picked.dim()
    if missing_dims > 0:
        picked = picked.reshape(tuple(picked.shape) + (1,) * missing_dims)

    # Adding zeros materialises the broadcast result.
    return picked + torch.zeros(broadcast_shape, device=target_device)

class Scheduler:
    """Linear-beta diffusion schedule with a training loss and two samplers.

    The wrapped model is called as ``model(torch.cat([x, y_t], dim=1), t)``
    where ``x`` is the conditioning image, ``y_t`` the noisy target and ``t``
    a batch of integer timesteps; its output is treated as predicted noise.
    """

    def __init__(self, denoise_model, denoise_steps, beta_start=1e-4, beta_end=0.005):
        """Precompute the schedule coefficients as numpy arrays.

        Args:
            denoise_model: Callable noise predictor (see class docstring).
            denoise_steps: Number of diffusion steps in the schedule.
            beta_start: First value of the linear beta schedule.
            beta_end: Last value of the linear beta schedule.
        """
        self.model = denoise_model
        self.denoise_steps = denoise_steps

        betas = np.array(
            np.linspace(beta_start, beta_end, denoise_steps),
            dtype=np.float64
        )

        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        alphas = 1.0 - betas

        self.sqrt_alphas = np.sqrt(alphas)
        self.one_minus_alphas = 1.0 - alphas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)

        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)

        # Cumulative product shifted right by one, with 1.0 for the first step.
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])

    def q_sample(self, y0, t, noise):
        """Return ``sqrt(alpha_bar_t) * y0 + sqrt(1 - alpha_bar_t) * noise``."""
        signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, y0.shape)
        noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y0.shape)
        return signal_scale * y0 + noise_scale * noise

    def training_losses(self, x, y, t):
        """Return the MSE between the model's prediction and the sampled noise.

        Args:
            x: Conditioning tensor, concatenated with the noisy target on dim 1.
            y: Clean target tensor.
            t: Integer timestep tensor, one entry per batch element.
        """
        noise = torch.randn_like(y)
        y_t = self.q_sample(y, t, noise)

        predict_noise = self.model(torch.cat([x, y_t], dim=1), t)

        return F.mse_loss(predict_noise, noise)

    @torch.no_grad()
    def ddpm(self, x, device):
        """Sample by visiting every timestep from ``denoise_steps - 1`` down to 0.

        Args:
            x: Conditioning tensor; the sample has the same shape.
            device: Device on which noise and timestep tensors are created.

        Returns:
            The final sample as a numpy array with the shape of ``x``.
        """
        batch = x.shape[0]
        y = torch.randn(*x.shape, device=device)

        for step in tqdm(reversed(range(0, self.denoise_steps)), total=self.denoise_steps):

            t = torch.full((batch,), step, device=device, dtype=torch.long)
            # 1 everywhere except at step 0, where no fresh noise is added.
            t_mask = (t != 0).float().view(-1, *([1] * (len(y.shape) - 1)))

            eps = self.model(torch.cat([x, y], dim=1), t)

            # Posterior mean: (y - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
            y = y - (
                extract_into_tensor(self.one_minus_alphas, t, y.shape) * eps
                / extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, y.shape)
            )
            y = y / extract_into_tensor(self.sqrt_alphas, t, y.shape)

            sigma = torch.sqrt(
                extract_into_tensor(self.one_minus_alphas, t, y.shape)
                * (1.0 - extract_into_tensor(self.alphas_cumprod_prev, t, y.shape))
                / (1.0 - extract_into_tensor(self.alphas_cumprod, t, y.shape))
            )

            y = y + sigma * torch.randn_like(y) * t_mask

            # The intermediate sample is clamped to [-1, 1] after every step.
            y = y.clip(-1, 1)

        return y.detach().cpu().numpy()

    @torch.no_grad()
    def ddim(self, x, device, eta=0.0, sub_sequence_step=25):
        """Sample with DDIM on a strided subsequence of the timesteps.

        Args:
            x: Conditioning tensor; the sample has the same shape.
            device: Device on which noise and timestep tensors are created.
            eta: Stochasticity factor; ``0.0`` gives deterministic updates.
            sub_sequence_step: Stride between visited timesteps (>= 1).

        Returns:
            The final sample as a numpy array with the shape of ``x``.

        Raises:
            ValueError: If ``sub_sequence_step`` is smaller than 1. A negative
                stride would otherwise yield an empty schedule and return the
                initial noise unchanged.
        """
        if sub_sequence_step < 1:
            raise ValueError("sub_sequence_step must be a positive integer")

        batch = x.shape[0]
        # Start from pure Gaussian noise.
        y = torch.randn(*x.shape, device=device)

        # Visited timesteps: denoise_steps - 1, then every `sub_sequence_step`
        # below it. Each one is paired with its successor, and the last visited
        # step targets index 0.
        # NOTE(review): the final target uses alphas_cumprod[0] rather than 1.0,
        # so the returned sample keeps a small noise term - confirm this is intended.
        t_seq = list(range(self.denoise_steps - 1, -1, -sub_sequence_step))
        s_seq = t_seq[1:] + [0]

        for t, s in tqdm(zip(t_seq, s_seq), total=len(t_seq)):
            t_tensor = torch.full((batch,), t, device=device, dtype=torch.long)
            s_tensor = torch.full((batch,), s, device=device, dtype=torch.long)

            # Predict the noise at the current timestep.
            eps = self.model(torch.cat([x, y], dim=1), t_tensor)

            alpha_bar_t = extract_into_tensor(self.alphas_cumprod, t_tensor, y.shape)
            alpha_bar_s = extract_into_tensor(self.alphas_cumprod, s_tensor, y.shape)

            # Estimate of the clean sample implied by the predicted noise.
            y0_pred = (y - torch.sqrt(1 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)

            # sigma controls how much fresh noise is injected (none when s == 0).
            stochastic = eta > 0.0 and s > 0
            sigma = 0.0
            if stochastic:
                sigma = eta * torch.sqrt(
                    (1 - alpha_bar_s) / (1 - alpha_bar_t) *
                    (1 - alpha_bar_t / alpha_bar_s)
                )

            # Move to timestep s along the predicted noise direction.
            y = torch.sqrt(alpha_bar_s) * y0_pred + torch.sqrt(1 - alpha_bar_s - sigma ** 2) * eps
            if stochastic:
                y = y + sigma * torch.randn_like(y)

            # The intermediate sample is clamped to [-1, 1] after every step.
            y = y.clip(-1, 1)

        return y.detach().cpu().numpy()
"SwinUnet.py"
import numpy as np
import torch as th
from torch import nn, einsum

import math

from einops import rearrange


#############################################
# Sinusoidal 时间步嵌入
#############################################
class SinusoidalTimeEmb(nn.Module):
    """Map integer timesteps to concatenated sine/cosine features."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Return ``[sin(t * f), cos(t * f)]`` for ``dim // 2`` frequencies ``f``.

        Args:
            t: 1-D tensor of timesteps.

        Returns:
            Tensor of shape ``[len(t), 2 * (dim // 2)]``.
        """
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1 / 10000.
        log_step = math.log(10000) / (half_dim - 1)
        freqs = th.exp(th.arange(half_dim, device=t.device) * -log_step)

        angles = t.float()[:, None] * freqs[None, :]
        return th.cat([angles.sin(), angles.cos()], dim=-1)


#############################################
# 下采样模块:Patch Merging
#############################################
class PatchMerging(nn.Module):
    """Downsample a channels-last feature map by merging square patches.

    Each non-overlapping ``downscaling_factor`` x ``downscaling_factor`` patch
    is flattened with ``nn.Unfold`` and projected to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(
            kernel_size=downscaling_factor,
            stride=downscaling_factor,
            padding=0,
        )
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        """Map ``[B, H, W, C]`` to ``[B, H // f, W // f, out_channels]``."""
        batch, height, width, _ = x.shape
        factor = self.downscaling_factor
        grid_h = height // factor
        grid_w = width // factor

        # Unfold expects channels-first input and returns [B, C * f * f, L].
        patches = self.patch_merge(x.permute(0, 3, 1, 2))
        patches = patches.view(batch, -1, grid_h, grid_w)

        return self.linear(patches.permute(0, 2, 3, 1))


#############################################
# Upsampling module: Patch Expanding (linear projection followed by a pixel-shuffle style rearrangement)
#############################################
class PatchExpanding(nn.Module):
    """Upsample a channels-last feature map with a linear layer and a reshuffle.

    Every input position is projected to ``upscaling_factor ** 2`` output
    vectors, which are laid out as a square block in the enlarged map.
    """

    def __init__(self, in_channels, out_channels, upscaling_factor):
        super().__init__()

        self.upscaling_factor = upscaling_factor
        self.out_channels = out_channels
        self.linear = nn.Linear(in_channels, out_channels * self.upscaling_factor ** 2)

    def forward(self, x):
        """Map ``[B, H, W, C_in]`` to ``[B, H * f, W * f, out_channels]``."""
        batch, height, width, _ = x.shape
        factor = self.upscaling_factor
        channels = self.out_channels

        blocks = self.linear(x).view(batch, height, width, factor, factor, channels)
        # Interleave the block rows/columns with the spatial axes.
        interleaved = blocks.permute(0, 1, 3, 2, 4, 5).contiguous()

        return interleaved.view(batch, height * factor, width * factor, channels)


#############################################
# 窗口自注意力机制
#############################################
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to square windows of a feature map.

    The input is split into non-overlapping ``window_size`` x ``window_size``
    windows and attention is computed inside each window. When ``shifted`` is
    true the map is cyclically rolled before attention and rolled back after,
    and additive ``-inf`` masks are applied to the last row and last column of
    windows.
    """

    def __init__(self, dim, nheads, window_size, shifted, relative_pos_embedding):
        """Build the projections, the position bias and (if shifted) the masks.

        Args:
            dim: Channel width of the input and output.
            nheads: Number of attention heads; ``dim // nheads`` is the head size.
            window_size: Side length of each attention window.
            shifted: Whether to apply the cyclic shift and the window masks.
            relative_pos_embedding: If true, the bias is looked up from a
                ``(2 * window_size - 1)``-square table via relative offsets;
                otherwise a full ``window_size ** 2``-square bias is learned.
        """
        super().__init__()

        head_dim = dim // nheads

        self.nheads = nheads
        # Scaling applied to the query-key dot products.
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            displacement = window_size // 2
            # Roll by -displacement before attention, +displacement afterwards.
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            # The masks are stored as parameters with requires_grad=False, so
            # they are part of the module's parameters but are not trained.
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        # Single bias-free projection producing q, k and v side by side.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        if self.relative_pos_embedding:
            # Offsets are shifted by window_size - 1 so they are valid,
            # non-negative indices into the bias table. This is a plain tensor
            # attribute (neither a Parameter nor a registered buffer).
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(th.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            self.pos_embedding = nn.Parameter(th.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply windowed attention to a 4-D channels-last tensor.

        Args:
            x: Tensor unpacked as ``(b, n_h, n_w, channels)``.

        Returns:
            The output of ``to_out`` after merging heads and windows back into
            a 4-D channels-last tensor (rolled back when ``shifted``).
        """

        if self.shifted:
            x = self.cyclic_shift(x)

        b, n_h, n_w, _, h = *x.shape, self.nheads

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # Number of windows along each spatial axis.
        nw_h = n_h // self.window_size
        nw_w = n_w // self.window_size

        # Split heads and windows: [b, heads, windows, tokens-per-window, head_dim].
        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)

        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale

        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            # Windows are flattened row-major, so the last nw_w entries are the
            # bottom row of windows and every nw_w-th entry starting at
            # nw_w - 1 is the right-most column of windows.
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        # Merge heads and windows back into the spatial layout.
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)
        return out


class CyclicShift(nn.Module):
    """Roll a tensor by the same offset along dimensions 1 and 2."""

    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        """Return ``x`` cyclically shifted by ``displacement`` on dims 1 and 2."""
        offset = self.displacement
        return x.roll(shifts=(offset, offset), dims=(1, 2))


def create_mask(window_size, displacement, upper_lower, left_right):
    """Build an additive attention mask for one shifted window.

    Args:
        window_size: Side length of the window; the mask is
            ``window_size ** 2`` square.
        displacement: Number of rows/columns that were cyclically shifted.
        upper_lower: Block attention between the last ``displacement`` rows of
            the window and the remaining rows.
        left_right: Block attention between the last ``displacement`` columns
            of the window and the remaining columns.

    Returns:
        A float tensor holding 0 where attention is allowed and ``-inf`` where
        it is blocked.
    """
    tokens = window_size ** 2
    blocked = float('-inf')
    mask = th.zeros(tokens, tokens)

    if upper_lower:
        cut = displacement * window_size
        mask[-cut:, :-cut] = blocked
        mask[:-cut, -cut:] = blocked

    if left_right:
        # View the token axes as (row, column) pairs so that columns can be
        # addressed directly; writes go through to `mask`'s storage.
        grid = mask.view(window_size, window_size, window_size, window_size)
        grid[:, -displacement:, :, :-displacement] = blocked
        grid[:, :-displacement, :, -displacement:] = blocked
        mask = grid.view(tokens, tokens)

    return mask


def get_relative_distances(window_size):
    """Return pairwise (row, column) offsets between all positions in a window.

    Positions are enumerated row-major. Entry ``[i, j]`` of the result is
    ``position[j] - position[i]``, so the tensor has shape
    ``[window_size ** 2, window_size ** 2, 2]``.
    """
    axis = range(window_size)
    coords = th.tensor(np.array([[row, col] for row in axis for col in axis]))
    return coords[None, :, :] - coords[:, None, :]


#############################################
# SwinTransformerBlock: 采用和DiT相同的Adaptive Layer Normalization
#############################################
def modulate(x, shift, scale):
    """Apply a per-sample affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are indexed as ``[:, None, None, :]`` so they
    broadcast over the two middle axes of ``x``.
    """
    gain = 1 + scale[:, None, None, :]
    offset = shift[:, None, None, :]
    return x * gain + offset


class SwinTransformerAdaLnBlock(nn.Module):
    """Window-attention transformer block conditioned through adaptive LayerNorm.

    A conditioning vector is projected to six chunks: shift, scale and gate for
    the attention branch, and the same three for the MLP branch. Both branches
    are residual.
    """

    def __init__(self, dim, mlp_ratio, nheads, window_size, shifted, relative_pos_embedding):
        super().__init__()

        hidden_dim = mlp_ratio * dim

        self.dim = dim
        self.attn = WindowAttention(dim, nheads, window_size, shifted, relative_pos_embedding)
        # The norms carry no learned affine; scale and shift come from `t`.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim),
        )

    def forward(self, x, t):
        """Run the gated attention and MLP branches.

        Args:
            x: Channels-last feature map whose last axis has size ``dim``.
            t: Conditioning tensor of shape ``[B, dim]``.
        """
        conditioning = self.adaLN_modulation(t)
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = conditioning.chunk(6, dim=1)

        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa[:, None, None, :] * attn_out

        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = x + gate_mlp[:, None, None, :] * mlp_out

        return x


#############################################
# SwinUnet blocks 各组件
#############################################
def block_forward(block, x, t):
    """Run ``x`` through a sequence of conditioned layers.

    Each layer receives the first ``layer.dim`` columns of ``t`` as its
    conditioning input.
    """
    out = x
    for layer in block:
        conditioning = t[:, :layer.dim]
        out = layer(out, conditioning)
    return out


class SwinUnetEncoder(nn.Module):
    """Three-stage Swin encoder: patch embedding, then blocks + 2x merging per stage.

    Stage widths are ``dim``, ``2 * dim`` and ``4 * dim``; the output after the
    last merge has width ``8 * dim``.
    """

    def __init__(self, channels, dim, patch_size, depth, mlp_ratio, nheads, window_size, relative_pos_embedding):
        super().__init__()

        def build_stage(width, heads, n_blocks):
            # NOTE(review): `idx // 2 == 0` is true only for idx == 1, so only
            # the first block of a stage is shifted - confirm whether an
            # alternating pattern was intended. Changing it would alter the
            # module's parameters and therefore checkpoint compatibility.
            return nn.ModuleList([
                SwinTransformerAdaLnBlock(
                    dim=width,
                    mlp_ratio=mlp_ratio,
                    nheads=heads,
                    window_size=window_size,
                    shifted=(idx // 2 == 0),
                    relative_pos_embedding=relative_pos_embedding,
                )
                for idx in range(1, n_blocks + 1)
            ])

        self.patch_embed = PatchMerging(channels, dim, patch_size)

        self.block0 = build_stage(dim * 1, nheads[0], depth[0])
        self.patch_merge0 = PatchMerging(dim * 1, dim * 2, downscaling_factor=2)

        self.block1 = build_stage(dim * 2, nheads[1], depth[1])
        self.patch_merge1 = PatchMerging(dim * 2, dim * 4, downscaling_factor=2)

        self.block2 = build_stage(dim * 4, nheads[2], depth[2])
        self.patch_merge2 = PatchMerging(dim * 4, dim * 8, downscaling_factor=2)

    def forward(self, x, t):
        """Encode a channels-first image.

        Args:
            x: Input tensor, permuted here from ``[B, C, H, W]`` to channels-last.
            t: Conditioning tensor passed on to every block.

        Returns:
            ``(features, skip_connections)`` where ``skip_connections`` holds
            each stage's output before it is merged, shallowest stage first.
        """
        x = self.patch_embed(x.permute(0, 2, 3, 1))

        stages = (
            (self.block0, self.patch_merge0),
            (self.block1, self.patch_merge1),
            (self.block2, self.patch_merge2),
        )

        skip_connections = []
        for blocks, merge in stages:
            x = block_forward(blocks, x, t)
            skip_connections.append(x)
            x = merge(x)

        return x, skip_connections


class SwinUnetDecoder(nn.Module):
    """Three-stage Swin decoder mirroring the encoder.

    Each stage expands the map 2x, concatenates the matching encoder skip,
    fuses the pair with a bias-free linear layer and runs the stage's blocks.
    A final expansion by ``patch_size`` maps back to image channels.
    """

    def __init__(self, channels, dim, patch_size, depth, mlp_ratio, nheads, window_size, relative_pos_embedding):
        super().__init__()

        def build_stage(width, heads, n_blocks):
            # Same shift rule as written in the original: only idx == 1
            # satisfies `idx // 2 == 0`.
            return nn.ModuleList([
                SwinTransformerAdaLnBlock(
                    dim=width,
                    mlp_ratio=mlp_ratio,
                    nheads=heads,
                    window_size=window_size,
                    shifted=(idx // 2 == 0),
                    relative_pos_embedding=relative_pos_embedding,
                )
                for idx in range(1, n_blocks + 1)
            ])

        self.patch_expand0 = PatchExpanding(dim * 8, dim * 4, upscaling_factor=2)
        self.block0 = build_stage(dim * 4, nheads[2], depth[2])
        self.skip0 = nn.Linear(dim * 4 * 2, dim * 4, bias=False)

        self.patch_expand1 = PatchExpanding(dim * 4, dim * 2, upscaling_factor=2)
        self.block1 = build_stage(dim * 2, nheads[1], depth[1])
        self.skip1 = nn.Linear(dim * 2 * 2, dim * 2, bias=False)

        self.patch_expand2 = PatchExpanding(dim * 2, dim * 1, upscaling_factor=2)
        self.block2 = build_stage(dim * 1, nheads[0], depth[0])
        self.skip2 = nn.Linear(dim * 1 * 2, dim * 1, bias=False)

        self.patch_to_image = PatchExpanding(dim, channels, patch_size)

    def forward(self, x, skip_connect, t):
        """Decode bottleneck features back to a channels-first image.

        Args:
            x: Channels-last bottleneck features.
            skip_connect: Encoder skips, shallowest first; consumed deepest first.
            t: Conditioning tensor passed on to every block.
        """
        stages = (
            (self.patch_expand0, self.skip0, self.block0, skip_connect[2]),
            (self.patch_expand1, self.skip1, self.block1, skip_connect[1]),
            (self.patch_expand2, self.skip2, self.block2, skip_connect[0]),
        )

        for expand, fuse, blocks, skip in stages:
            x = expand(x)
            x = fuse(th.cat((x, skip), dim=-1))
            x = block_forward(blocks, x, t)

        return self.patch_to_image(x).permute(0, 3, 1, 2)


#############################################
# SwinUnet: 条件的处理采用直接拼接
#############################################
class SwinUnet(nn.Module):
    """Time-conditioned Swin U-Net: encoder, bottleneck blocks and decoder.

    With ``use_condition`` the encoder expects ``2 * channels`` input channels
    (the condition is concatenated with the noisy input by the caller); the
    decoder always emits ``channels`` channels.
    """

    def __init__(self, channels, dim, mlp_ratio, patch_size, window_size, depth, nheads,
                 relative_pos_embedding=True, use_condition=True):
        super().__init__()

        in_channels = 2 * channels if use_condition else channels
        # The first three depth/head entries configure the encoder and decoder
        # stages; the last entry configures the bottleneck.
        stage_depth = depth[:3]
        stage_heads = nheads[:3]

        self.time_embed = SinusoidalTimeEmb(8 * dim)

        self.encoder = SwinUnetEncoder(
            channels=in_channels, dim=dim, patch_size=patch_size,
            depth=stage_depth, mlp_ratio=mlp_ratio, nheads=stage_heads,
            window_size=window_size, relative_pos_embedding=relative_pos_embedding,
        )

        self.bottleneck = nn.ModuleList([
            SwinTransformerAdaLnBlock(
                dim=8 * dim,
                mlp_ratio=mlp_ratio,
                nheads=nheads[-1],
                window_size=window_size,
                shifted=(idx // 2 == 0),
                relative_pos_embedding=relative_pos_embedding,
            )
            for idx in range(1, depth[-1] + 1)
        ])

        self.decoder = SwinUnetDecoder(
            channels=channels, dim=dim, patch_size=patch_size,
            depth=stage_depth, mlp_ratio=mlp_ratio, nheads=stage_heads,
            window_size=window_size, relative_pos_embedding=relative_pos_embedding,
        )

    def forward(self, x, t):
        """Run the network on input ``x`` and timesteps ``t``.

        Args:
            x: Channels-first input tensor.
            t: 1-D tensor of timesteps, embedded here before use.
        """
        time_features = self.time_embed(t)

        encoded, skips = self.encoder(x, time_features)
        encoded = block_forward(self.bottleneck, encoded, time_features)

        return self.decoder(encoded, skips, time_features)
"load_data.py"
import os
import numpy as np

from PIL import Image
from torch.utils.data import Dataset


def is_image_file(file_path):
    """Return True when the path's extension is .jpg, .jpeg or .png (any case)."""
    _, extension = os.path.splitext(file_path)
    return extension.lower() in ('.jpg', '.jpeg', '.png')


class CustomDataset(Dataset):
    """Flat-folder dataset yielding (low-res, high-res) image pairs in [-1, 1].

    The low-resolution image is produced by shrinking the high-resolution one
    by ``sr_ratio`` and resizing it back, so both arrays share one shape.
    Only regular files accepted by ``is_image_file`` are indexed; this skips
    subdirectories and stray files such as ``.DS_Store``.
    """

    def __init__(self, path, img_size=None, sr_ratio=8):
        """Index the image files directly inside ``path``.

        Args:
            path: Directory containing the training images.
            img_size: Optional ``(width, height)`` every image is resized to.
                When ``None`` each side is rounded up to a multiple of 256.
            sr_ratio: Downscaling factor used to synthesise the low-res input.
        """
        super().__init__()

        self.img_size = img_size

        # Sorted so the index -> file mapping is reproducible across runs.
        self.files = []
        for name in sorted(os.listdir(path)):
            candidate = os.path.join(path, name)
            if os.path.isfile(candidate) and is_image_file(candidate):
                self.files.append(candidate)

        self.ratio = sr_ratio

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(lr_arr, hr_arr)`` as channels-first float arrays in [-1, 1]."""
        # Convert to RGB so grayscale / RGBA files also yield three channels;
        # the context manager closes the underlying file handle.
        with Image.open(self.files[idx]) as src:
            hr_img = src.convert("RGB")

        if self.img_size is not None:
            hr_img = hr_img.resize(self.img_size)
            hr_size = hr_img.size
        else:
            # Round each side up to the next multiple of 256; sides that are
            # already a multiple are left unchanged.
            hr_size = tuple(-(-side // 256) * 256 for side in hr_img.size)
            hr_img = hr_img.resize(hr_size)

        hr_arr = np.array(hr_img).transpose(2, 0, 1) / 255.

        # Degrade: shrink by the SR ratio, then resize back to the HR size.
        lr_img = hr_img.resize((hr_size[0] // self.ratio, hr_size[1] // self.ratio))
        lr_img = lr_img.resize(hr_size)
        lr_arr = np.array(lr_img).transpose(2, 0, 1) / 255.

        lr_arr = 2 * (lr_arr - 0.5)
        hr_arr = 2 * (hr_arr - 0.5)

        return lr_arr, hr_arr


class ImageNetDataset(Dataset):
    """Class-folder dataset yielding (low-res, high-res) image pairs in [-1, 1].

    ``path`` is expected to contain one subdirectory per class; image files
    accepted by ``is_image_file`` inside those subdirectories are indexed.
    Class labels are not returned.
    """

    def __init__(self, path, img_size=(256, 256), sr_ratio=8):
        """Index the image files one level below ``path``.

        Args:
            path: Root directory holding the class subdirectories.
            img_size: ``(width, height)`` every image is resized to, or
                ``None`` to keep the original size.
            sr_ratio: Downscaling factor used to synthesise the low-res input.
        """
        super().__init__()

        self.img_size = img_size

        self.files = []
        for class_name in sorted(os.listdir(path)):
            class_dir = os.path.join(path, class_name)
            # Skip stray files in the root; listing them would raise
            # NotADirectoryError.
            if not os.path.isdir(class_dir):
                continue
            for name in sorted(os.listdir(class_dir)):
                candidate = os.path.join(class_dir, name)
                if is_image_file(candidate):
                    self.files.append(candidate)

        self.ratio = sr_ratio

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(lr_arr, hr_arr)`` as channels-first float arrays in [-1, 1]."""
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(self.files[idx]) as src:
            hr_img = src.convert("RGB")

        if self.img_size is not None:
            hr_img = hr_img.resize(self.img_size)

        hr_size = hr_img.size
        hr_arr = np.array(hr_img).transpose(2, 0, 1) / 255.

        # Degrade: shrink by the SR ratio, then resize back to the HR size.
        lr_img = hr_img.resize((hr_size[0] // self.ratio, hr_size[1] // self.ratio))
        lr_img = lr_img.resize(hr_size)
        lr_arr = np.array(lr_img).transpose(2, 0, 1) / 255.

        lr_arr = 2 * (lr_arr - 0.5)
        hr_arr = 2 * (hr_arr - 0.5)

        return lr_arr, hr_arr
"training.py"
import torch
import numpy as np

from torch import optim
from tqdm import tqdm
from torch.autograd import Variable
from torch.utils.data import DataLoader

from load_data import CustomDataset
from scheduler import Scheduler
from SwinUnet import SwinUnet

if __name__ == '__main__':
    # Training entry point: trains the SwinUnet noise predictor on DIV2K.
    import os

    # Use the best available backend instead of hard-coding Apple "mps".
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    batch_size = 16
    lr = 1e-4
    epochs = 200
    denoise_steps = 1000
    sr_ratio = 8
    pretrained_path = "SwinUNet-SR8.pth"

    train_dataset = CustomDataset(
        "./DIV2K_train_HR", img_size=(512, 512), sr_ratio=sr_ratio,
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model = SwinUnet(channels=3, dim=96, mlp_ratio=4, patch_size=4, window_size=8,
                     depth=[2, 2, 6, 2], nheads=[3, 6, 12, 24]).to(device)

    # Warm-start from the released weights when present; otherwise train from
    # scratch instead of crashing on a missing file.
    if os.path.exists(pretrained_path):
        model.load_state_dict(torch.load(pretrained_path, map_location=device))

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = Scheduler(model, denoise_steps)

    model.train()
    for epoch in range(epochs):

        print('*' * 40)

        train_loss = []

        for x, y in tqdm(train_loader, total=len(train_loader)):

            # Cast to float32 before moving to the device, as in the original
            # two-step conversion (the dataset yields float64 arrays).
            x = x.to(torch.float32).to(device)
            y = y.to(torch.float32).to(device)

            # One uniformly sampled timestep per batch element.
            t = torch.randint(low=0, high=denoise_steps, size=(x.shape[0],)).to(device)
            training_loss = scheduler.training_losses(x, y, t)

            optimizer.zero_grad()
            training_loss.backward()
            optimizer.step()
            train_loss.append(training_loss.item())

        # Checkpoint after every epoch.
        torch.save(model.state_dict(), f"unet-sr{sr_ratio}.pth")
        print('Finish  {}  Loss: {:.6f}'.format(epoch + 1, np.mean(train_loss)))

相关文章:

  • 深入理解pthread多线程编程:从基础到生产者-消费者模型
  • Android: Handler 的用法详解
  • 【工具】在 Visual Studio 中使用 Dotfuscator 对“C# 类库(DLL)或应用程序(EXE)”进行混淆
  • 关于 Nginx 配置中 proxy_set_header Host $host 的作用及其对 HTTP 请求头影响的详细说明,结合示例展示设置前后的差异
  • 【VSCode SSH 连接远程服务器】:身份验证时,出现 key: invalid format 的问题
  • 服务端向客户端推送数据的实现方案
  • Linux | I.MX6ULL 终结者底板原理图讲解完(第六天)
  • 关于亚马逊TTS的笔记
  • 银行回单识别技术应用与API服务解析
  • 1 分钟掌握 PlantUML,快速绘制 UML 类图!
  • Docker学习--本地镜像管理相关命令--docker history 命令
  • 在Windows下使用Docker部署Nacos注册中心(基于MySQL容器)
  • 初识C++(入门)
  • kubernetes》》k8s》》Deployment》》ClusterIP、LoadBalancer、Ingress 内部访问、外边访问
  • 31天Python入门——第20天:魔法方法详解
  • TruPlasma RF 1002-G2/13 软件 TruPlasma RF 1003-G2/13软件 TRUMPF 调试监控软件
  • SQL Server:用户权限
  • 系统设计:高并发策略与缓存设计
  • 003-JMeter发起请求详解
  • LVS高可用负载均衡
  • 张巍任中共河南省委副书记
  • 陕西省市监局通报5批次不合格食品,涉添加剂超标、微生物污染等问题
  • 六省会共建交通枢纽集群,中部离经济“第五极”有多远?
  • 王伟妻子人民日报撰文:81192,一架永不停航的战机
  • 上海市重大工程一季度开局良好,多项生态类项目按计划实施
  • 刘强东坐镇京东一线:管理层培训1800人次,最注重用户体验