当前位置: 首页 > news >正文

网站 模板下载陕西富通建设有限公司网站

网站 模板下载,陕西富通建设有限公司网站,上海 微信网站 建站,官网服务器一般一年多少钱PyTorch生成式人工智能——PatchGAN详解与实现0. 前言1. PatchGAN 核心原理2. 网络架构3. 实现 PatchGAN3.1 数据集加载3.2 模型构建3.3 模型训练相关链接0. 前言 在生成对抗网络 (Generative Adversarial Network, GAN) 的发展历程中,如何提升生成图像的细节质量一…

PyTorch生成式人工智能——PatchGAN详解与实现

    • 0. 前言
    • 1. PatchGAN 核心原理
    • 2. 网络架构
    • 3. 实现 PatchGAN
      • 3.1 数据集加载
      • 3.2 模型构建
      • 3.3 模型训练
    • 相关链接

0. 前言

在生成对抗网络 (Generative Adversarial Network, GAN) 的发展历程中,如何提升生成图像的细节质量一直是研究者关注的核心问题。传统的全局判别器往往过于关注图像的整体结构,而忽略了局部细节的真实性,导致生成的面部图像缺乏细腻的纹理特征。正是在这样的技术背景下,PatchGAN 作为一种创新的判别器架构脱颖而出。它通过将图像分割为多个局部区域并进行独立判别,极大地增强了对高频细节的感知能力。本节将深入探讨 PatchGAN 相关原理,并使用 PyTorch 从零开始实现 PatchGAN,并使用 huggan/selfie2anime 数据集训练模型生成动漫人脸图像。

1. PatchGAN 核心原理

PatchGAN 的设计是为了解决图像生成任务中一个常见问题:传统的判别器 (Discriminator) 需要将一整张图片输入,然后输出单一标量值用于预测输入是真实/虚假。这可能导致生成图片在整体结构上看起来合理,但局部细节模糊或缺乏高频信息。
PatchGAN 判别器不是将整张图像分类为真或假,而是将图像分割成重叠的 N x N 图像块 (Patche),并尝试对每个图像块的真假进行分类,如下图所示。最终,判别器的输出不是一个值,而是一个矩阵(或特征图),其中的每个元素代表对应图像块为真实图像的概率。

PatchGAN

PatchGAN 提供了全局图像块级约束,为了进一步强化关键细节可以采用局部判别 (Local Discriminator)。全局判别容易忽略微小纹理(如眼睛高光、睫毛、皮肤纹理),局部判别补强这些高频信息。选取图像中的局部裁剪(例如 128×128 )(随机或其它感兴趣区域),专门训练一个判别器区分该局部图像的真实/虚假。

2. 网络架构

模型采用一个生成器与两个判别器,两个分别提供全局图像块级约束(判别器 D_patch )与局部关键细节约束(判别器 D_local),判别器 D_patch 接受生成器产生的图像与真实图像作为输入,并输出 一个 M x M 的矩阵(例如 30x30, 70x70 等,取决于架构深度和输入大小),矩阵中的每个元素 (i, j) 对应于输入图像中的一个感受野 (Receptive Field),即一个图像块 Patch。计算这个 M x M 输出矩阵与一个全为 1 (真实)或 0 (虚假)的同样大小矩阵之间的损失,使用 BCEWithLogitsLoss (或 hinge loss)。
判别器 D_local 与原始生成对抗网络 (Generative Adversarial Network, GAN) 类似,不同之处在于接收图像中的局部裁剪区域作为输入,用于强化细节纹理与局部结构的真实感。

3. 实现 PatchGAN

接下来,我们使用 Hugging Facehuggan/selfie2anime 数据集(仅动漫脸部分 imageB )训练 PatchGAN,并采用 PatchGAN + 局部判别器的复合判别策略提升局部细节质量。

3.1 数据集加载

(1) 首先,导入所需库,并定义超参数与设备:

import os
import math
import random
from pathlib import Path
from tqdm import tqdmimport torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from PIL import Image
from datasets import load_datasetIMG_SIZE = 256
BATCH_SIZE = 16
Z_DIM = 256
LR = 2e-4
BETA1, BETA2 = 0.5, 0.999
EPOCHS = 200
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAMPLES_DIR = Path("samples")
CHECKPOINT_DIR = Path("checkpoints")
SAMPLES_DIR.mkdir(exist_ok=True, parents=True)
CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)

(2) 下载huggan/selfie2anime 数据集并定义数据集类 HFImageDataset()

def prepare_hf_dataset(image_field="imageB", split="train"):print("Loading dataset from Hugging Face: huggan/selfie2anime ...")hf = load_dataset("huggan/selfie2anime", split=split)  # split='train'print("Dataset loaded. Rows:", len(hf))return hfclass HFImageDataset(Dataset):def __init__(self, hf_dataset, image_field="imageB", transform=None):self.ds = hf_datasetself.image_field = image_fieldself.transform = transformdef __len__(self):return len(self.ds)def __getitem__(self, idx):item = self.ds[idx]img = item[self.image_field]if isinstance(img, (list, tuple)):img = img[0]if not isinstance(img, Image.Image):img = Image.fromarray(img)if self.transform:img = self.transform(img)return img

(3) 预处理使用 RandomCrop 缩放后裁剪,并归一化到 [-1,1],便于在生成器中使用 tanh 输出:

train_transform = transforms.Compose([transforms.Resize(int(IMG_SIZE * 1.12), interpolation=Image.BICUBIC),transforms.RandomCrop(IMG_SIZE),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5]*3, [0.5]*3),
])

(4) 加载数据集,且仅使用动漫脸图像 (imageB):

hf = prepare_hf_dataset(image_field="imageB", split="train")
dataset = HFImageDataset(hf, image_field="imageB", transform=train_transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)

3.2 模型构建

(1) 定义生成器,先将长度为 Z_DIM 的噪声通过全连接映射到 ngf x 4 x 4,再通过一系列 ConvTranspose2d 逐步放大到 256×256,输出用 tanh 映射到 [-1, 1],与数据归一化一致:

class Generator(nn.Module):def __init__(self, z_dim=Z_DIM, ngf=1024, out_channels=3):super().__init__()# 从 z 映射到 4x4 特征图self.fc = nn.Linear(z_dim, ngf * 4 * 4)self.net = nn.Sequential(# 4x4 -> 8x8nn.ConvTranspose2d(ngf, ngf//2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(ngf//2),nn.ReLU(True),# 8x8 -> 16x16nn.ConvTranspose2d(ngf//2, ngf//4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(ngf//4),nn.ReLU(True),# 16x16 -> 32x32nn.ConvTranspose2d(ngf//4, ngf//8, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(ngf//8),nn.ReLU(True),# 32x32 -> 64x64nn.ConvTranspose2d(ngf//8, ngf//16, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(ngf//16),nn.ReLU(True),# 64x64 -> 128x128nn.ConvTranspose2d(ngf//16, ngf//32, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(ngf//32),nn.ReLU(True),# 128x128 -> 256x256nn.ConvTranspose2d(ngf//32, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 最后一层输出 3 通道nn.Conv2d(64, out_channels, kernel_size=3, padding=1),nn.Tanh())def forward(self, z):x = self.fc(z).view(z.size(0), -1, 4, 4)  # B x ngf x 4 x 4return self.net(x)

(2) 定义 PatchGAN 判别器,输出为特征图(每个位置对应一个 patch 的判别),在训练过程中,我们可以对这个输出使用 BCEWithLogitsLoss 损失,标签 1 对应真实 patch0 对应虚假 patch

class PatchDiscriminator(nn.Module):def __init__(self, in_channels=3, ndf=64):super().__init__()# 多层卷积下采样,最后输出单通道特征图(判别 map)def conv_block(in_c, out_c, stride, use_bn=True):layers = [nn.Conv2d(in_c, out_c, kernel_size=4, stride=stride, padding=1, bias=not use_bn)]if use_bn:layers.append(nn.BatchNorm2d(out_c))layers.append(nn.LeakyReLU(0.2, inplace=True))return layerslayers = []layers += [nn.Conv2d(in_channels, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True)]layers += conv_block(ndf, ndf*2, stride=2)layers += conv_block(ndf*2, ndf*4, stride=2)layers += conv_block(ndf*4, ndf*8, stride=1)  # keep stride=1 to keep patch granularitylayers += [nn.Conv2d(ndf*8, 1, kernel_size=4, stride=1, padding=1)]  # 单通道 logits mapself.model = nn.Sequential(*layers)def forward(self, x):return self.model(x)  # 输出 shape: (B,1,H_p,W_p)

(3) 定义局部判别器,最终输出一个标量 logit (每个局部 patch 一个判别),训练过程中我们从真实图像与生成图像中随机裁剪局部 patch (例如 128×128) 传递给局部判别器:

class LocalDiscriminator(nn.Module):def __init__(self, in_channels=3, ndf=64):super().__init__()self.net = nn.Sequential(nn.Conv2d(in_channels, ndf, 4, 2, 1), nn.LeakyReLU(0.2, True),nn.Conv2d(ndf, ndf*2, 4, 2, 1), nn.BatchNorm2d(ndf*2), nn.LeakyReLU(0.2, True),nn.Conv2d(ndf*2, ndf*4, 4, 2, 1), nn.BatchNorm2d(ndf*4), nn.LeakyReLU(0.2, True),nn.Conv2d(ndf*4, ndf*8, 4, 2, 1), nn.BatchNorm2d(ndf*8), nn.LeakyReLU(0.2, True),nn.AdaptiveAvgPool2d(1),nn.Conv2d(ndf*8, 1, kernel_size=1)  # 输出 1x1 logit)def forward(self, x):# x 形状 (B,3,H_loc,W_loc)(例如 128x128)out = self.net(x)  # B x 1 x 1 x 1return out.view(x.size(0), 1)  # (B,1)

3.3 模型训练

(1) 初始化模型、优化器以及损失:

G = Generator(z_dim=Z_DIM).to(DEVICE)
D_patch = PatchDiscriminator().to(DEVICE)
D_local = LocalDiscriminator().to(DEVICE)# 初始化权重
def init_weights(m):if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):nn.init.normal_(m.weight.data, 0.0, 0.02)if hasattr(m, "bias") and m.bias is not None:nn.init.constant_(m.bias.data, 0)G.apply(init_weights)
D_patch.apply(init_weights)
D_local.apply(init_weights)# 优化器
opt_G = optim.Adam(G.parameters(), lr=LR, betas=(BETA1, BETA2))
opt_D = optim.Adam(list(D_patch.parameters()) + list(D_local.parameters()), lr=LR, betas=(BETA1, BETA2))# 损失
bce_loss = nn.BCEWithLogitsLoss().to(DEVICE)

(2) 定义辅助函数,random_local_patches() 用于在每张图像上随机裁剪同一大小的局部 patch,训练局部判别器,save_samples() 把生成图像从 [-1,1] 映射回 [0,1] 并保存:

def random_local_patches(imgs, loc_size=128):B, C, H, W = imgs.shapepatches = []for i in range(B):top = random.randint(0, H - loc_size)left = random.randint(0, W - loc_size)patches.append(imgs[i:i+1, :, top:top+loc_size, left:left+loc_size])return torch.cat(patches, dim=0).to(imgs.device)  # (B,3,loc_size,loc_size)def save_samples(fake_imgs, step, nrow=4):grid = utils.make_grid((fake_imgs + 1) / 2.0, nrow=nrow, padding=2, normalize=False)save_path = SAMPLES_DIR / f"sample_{step:06d}.png"utils.save_image(grid, save_path)print("Saved samples to", save_path)

(3) 定义训练流程,判别器训练阶段同时训练 Patch 判别器和局部判别器;Patch 判别器的损失与局部判别器按比例合并,生成器阶段以欺骗两个判别器为目标:

global_step = 0
sample_z = torch.randn(16, Z_DIM, device=DEVICE)  # 固定噪声用于可视化for epoch in range(EPOCHS):loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")for real in loop:real = real.to(DEVICE)  # (B,3,256,256)B = real.size(0)# Train Discriminators (patch + local)opt_D.zero_grad()# Patch discriminator on reallogits_real_patch = D_patch(real)# create labels (same shape as logits)labels_real_patch = torch.ones_like(logits_real_patch, device=DEVICE) * 0.9  # label smoothingloss_D_patch_real = bce_loss(logits_real_patch, labels_real_patch)# Generate fake imagesz = torch.randn(B, Z_DIM, device=DEVICE)fake = G(z)logits_fake_patch = D_patch(fake.detach())labels_fake_patch = torch.zeros_like(logits_fake_patch, device=DEVICE)loss_D_patch_fake = bce_loss(logits_fake_patch, labels_fake_patch)loss_D_patch = (loss_D_patch_real + loss_D_patch_fake) * 0.5# Local discriminator on random patchesreal_local = random_local_patches(real, loc_size=128)  # (B,3,128,128)fake_local = random_local_patches(fake.detach(), loc_size=128)logits_real_local = D_local(real_local)  # (B,1)logits_fake_local = D_local(fake_local)  # (B,1)labels_real_local = torch.ones_like(logits_real_local, device=DEVICE) * 0.9labels_fake_local = torch.zeros_like(logits_fake_local, device=DEVICE)loss_D_local = 0.5 * (bce_loss(logits_real_local, labels_real_local) +bce_loss(logits_fake_local, labels_fake_local))# Total discriminator lossloss_D = loss_D_patch + 0.7 * loss_D_localloss_D.backward()opt_D.step()# Train Generatoropt_G.zero_grad()z2 = torch.randn(B, Z_DIM, device=DEVICE)fake2 = G(z2)# Try to fool both discriminators# Patch GAN objective (we want D_patch(fake) -> 1)logits_fake_patch_forG = D_patch(fake2)labels_G_patch = torch.ones_like(logits_fake_patch_forG, device=DEVICE)loss_G_patch = bce_loss(logits_fake_patch_forG, labels_G_patch)# Local objectivefake_local2 = random_local_patches(fake2, loc_size=128)logits_fake_local_forG = D_local(fake_local2)labels_G_local = torch.ones_like(logits_fake_local_forG, device=DEVICE)loss_G_local = bce_loss(logits_fake_local_forG, labels_G_local)loss_G = loss_G_patch + 0.7 * loss_G_localloss_G.backward()opt_G.step()global_step += 1if global_step % 10 == 0:loop.set_postfix({"loss_D": f"{loss_D.item():.4f}","loss_G": f"{loss_G.item():.4f}"})if global_step % 500 == 0:with torch.no_grad():samples = G(sample_z).cpu()save_samples(samples, global_step, nrow=4)torch.save({'epoch': epoch,'G_state_dict': G.state_dict(),'D_patch_state_dict': D_patch.state_dict(),'D_local_state_dict': D_local.state_dict(),'opt_G': opt_G.state_dict(),'opt_D': opt_D.state_dict()}, CHECKPOINT_DIR / f"ckpt_epoch_{epoch:03d}.pth")print(f"Saved checkpoint for epoch {epoch}")

生成结果如下所示,可以看到随着训练生成的图像越来越逼真:

生成结果

相关链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(11)——神经风格迁移
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现
PyTorch生成式人工智能(18)——循环神经网络详解与实现
PyTorch生成式人工智能(19)——自回归模型详解与实现
PyTorch生成式人工智能(20)——像素卷积神经网络(PixelCNN)
PyTorch生成式人工智能(24)——使用PyTorch构建Transformer模型
PyTorch生成式人工智能(25)——基于Transformer实现机器翻译
PyTorch生成式人工智能(26)——使用PyTorch构建GPT模型
PyTorch生成式人工智能(27)——从零开始训练GPT模型
PyTorch生成式人工智能(28)——MuseGAN详解与实现

http://www.dtcms.com/a/541791.html

相关文章:

  • 淄博高效网站建设免费网站建站模板
  • Bootstrap4 Jumbotron详解与使用指南
  • IoT技术在产线实践中的应用
  • 合格VR大空间企业:核心要素有哪些?
  • 06.OpenStack网络管理
  • C++学习记录(23)智能指针
  • 网站内容策划方案wordpress底部版权信息修改
  • python 在class中几种函数的定义和用法
  • 电商数据中台基石:通过 API 构建淘宝商品实时数据源
  • 川崎机器人焊接电源气体省气
  • 理想汽车基于 Hologres + Flink 构建万亿级车联网信号实时分析平台
  • php教育视频网站开发如何做古诗词网站
  • 自发购卡网站在吗做手机建立网站软件
  • Git Tag 理解和使用
  • 如何写一个WebRTC ACE音频应用处理模块
  • 当机器拥有感觉:从电子皮肤到视频神经系统的具身智能革命
  • 快速搭建网站服务器网站推广策划方案
  • 【Linux基础知识系列:第一百六十三篇】创建虚拟网络:Linux网络桥接
  • 东方财经报道|深兰科技落户张江,AI医疗与情感陪伴并进,拓展智能未来版图
  • 跨区域多院区如何破局?浙江三甲医院实现核心医疗系统国产化重构
  • 做网站的怎么挣钱wordpress设计漂亮的页面
  • 【前端】圆角和非圆角实现渐变边框的区别(border)
  • 模板网站免费淘宝网页版官网
  • 苏州建设工程招标在哪个网站电子商务网站建设需要什么
  • 网站建设丿金手指花总9志愿北京网站注册
  • Linux魔法设备:/dev/null、/dev/zero、/dev/full详解
  • 【系统分析师】预测试卷一:综合知识题目及答案详解
  • 引领未来网络新体验——全方位解析外置WiFi模块的魅力与应用
  • 湖南的商城网站建设怎么注册公司公众号
  • vue3封装table组件及属性介绍