当前位置: 首页 > news >正文

网站 模板下载wordpress 短代码失效

网站 模板下载,wordpress 短代码失效,站点提交,网站短期就业培训班PyTorch生成式人工智能——PatchGAN详解与实现0. 前言1. PatchGAN 核心原理2. 网络架构3. 实现 PatchGAN3.1 数据集加载3.2 模型构建3.3 模型训练相关链接0. 前言 在生成对抗网络 (Generative Adversarial Network, GAN) 的发展历程中,如何提升生成图像的细节质量一…

PyTorch生成式人工智能——PatchGAN详解与实现

    • 0. 前言
    • 1. PatchGAN 核心原理
    • 2. 网络架构
    • 3. 实现 PatchGAN
      • 3.1 数据集加载
      • 3.2 模型构建
      • 3.3 模型训练
    • 相关链接

0. 前言

在生成对抗网络 (Generative Adversarial Network, GAN) 的发展历程中,如何提升生成图像的细节质量一直是研究者关注的核心问题。传统的全局判别器往往过于关注图像的整体结构,而忽略了局部细节的真实性,导致生成的面部图像缺乏细腻的纹理特征。正是在这样的技术背景下,PatchGAN 作为一种创新的判别器架构脱颖而出。它通过将图像分割为多个局部区域并进行独立判别,极大地增强了对高频细节的感知能力。本节将深入探讨 PatchGAN 相关原理,并使用 PyTorch 从零开始实现 PatchGAN,并使用 huggan/selfie2anime 数据集训练模型生成动漫人脸图像。

1. PatchGAN 核心原理

PatchGAN 的设计是为了解决图像生成任务中一个常见问题:传统的判别器 (Discriminator) 需要将一整张图片输入,然后输出单一标量值用于预测输入是真实/虚假。这可能导致生成图片在整体结构上看起来合理,但局部细节模糊或缺乏高频信息。
PatchGAN 判别器不是将整张图像分类为真或假,而是将图像分割成重叠的 N x N 图像块 (Patche),并尝试对每个图像块的真假进行分类,如下图所示。最终,判别器的输出不是一个值,而是一个矩阵(或特征图),其中的每个元素代表对应图像块为真实图像的概率。

PatchGAN

PatchGAN 提供了全局图像块级约束,为了进一步强化关键细节可以采用局部判别 (Local Discriminator)。全局判别容易忽略微小纹理(如眼睛高光、睫毛、皮肤纹理),局部判别补强这些高频信息。选取图像中的局部裁剪(例如 128×128 )(随机或其它感兴趣区域),专门训练一个判别器区分该局部图像的真实/虚假。

2. 网络架构

模型采用一个生成器与两个判别器,两个分别提供全局图像块级约束(判别器 D_patch )与局部关键细节约束(判别器 D_local),判别器 D_patch 接受生成器产生的图像与真实图像作为输入,并输出 一个 M x M 的矩阵(例如 30x30, 70x70 等,取决于架构深度和输入大小),矩阵中的每个元素 (i, j) 对应于输入图像中的一个感受野 (Receptive Field),即一个图像块 Patch。计算这个 M x M 输出矩阵与一个全为 1 (真实)或 0 (虚假)的同样大小矩阵之间的损失,使用 BCEWithLogitsLoss (或 hinge loss)。
判别器 D_local 与原始生成对抗网络 (Generative Adversarial Network, GAN) 类似,不同之处在于接收图像中的局部裁剪区域作为输入,用于强化细节纹理与局部结构的真实感。

3. 实现 PatchGAN

接下来,我们使用 Hugging Facehuggan/selfie2anime 数据集(仅动漫脸部分 imageB )训练 PatchGAN,并采用 PatchGAN + 局部判别器的复合判别策略提升局部细节质量。

3.1 数据集加载

(1) 首先,导入所需库,并定义超参数与设备:

import os
import math
import random
from pathlib import Path
from tqdm import tqdmimport torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from PIL import Image
from datasets import load_datasetIMG_SIZE = 256
BATCH_SIZE = 16
Z_DIM = 256
LR = 2e-4
BETA1, BETA2 = 0.5, 0.999
EPOCHS = 200
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAMPLES_DIR = Path("samples")
CHECKPOINT_DIR = Path("checkpoints")
SAMPLES_DIR.mkdir(exist_ok=True, parents=True)
CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)

(2) 下载huggan/selfie2anime 数据集并定义数据集类 HFImageDataset()

def prepare_hf_dataset(image_field="imageB", split="train"):print("Loading dataset from Hugging Face: huggan/selfie2anime ...")hf = load_dataset("huggan/selfie2anime", split=split)  # split='train'print("Dataset loaded. Rows:", len(hf))return hfclass HFImageDataset(Dataset):def __init__(self, hf_dataset, image_field="imageB", transform=None):self.ds = hf_datasetself.image_field = image_fieldself.transform = transformdef __len__(self):return len(self.ds)def __getitem__(self, idx):item = self.ds[idx]img = item[self.image_field]if isinstance(img, (list, tuple)):img = img[0]if not isinstance(img, Image.Image):img = Image.fromarray(img)if self.transform:img = self.transform(img)return img

(3) 预处理使用 RandomCrop 缩放后裁剪,并归一化到 [-1,1],便于在生成器中使用 tanh 输出:

train_transform = transforms.Compose([transforms.Resize(int(IMG_SIZE * 1.12), interpolation=Image.BICUBIC),transforms.RandomCrop(IMG_SIZE),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5]*3, [0.5]*3),
])

(4) 加载数据集,且仅使用动漫脸图像 (imageB):

hf = prepare_hf_dataset(image_field="imageB", split="train")
dataset = HFImageDataset(hf, image_field="imageB", transform=train_transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)

3.2 模型构建

(1) 定义生成器,先将长度为 Z_DIM 的噪声通过全连接映射到 ngf x 4 x 4,再通过一系列 ConvTranspose2d 逐步放大到 256×256,输出用 tanh 映射到 [-1, 1],与数据归一化一致:

class Generator(nn.Module):def __init__(self, z_dim=Z_DIM, ngf=1024, out_channels=3):super().__init__()# 从 z 映射到 4x4 特征图self.fc = nn.Linear(z_dim, ngf * 4 * 4)self.net = nn.Sequential(# 4x4 -> 8x8nn.ConvTranspose2d(ngf, ngf//2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(ngf//2),nn.ReLU(True),# 8x8 -> 16x16nn.ConvTranspose2d(ngf//2, ngf//4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(ngf//4),nn.ReLU(True),# 16x16 -> 32x32nn.ConvTranspose2d(ngf//4, ngf//8, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(ngf//8),nn.ReLU(True),# 32x32 -> 64x64nn.ConvTranspose2d(ngf//8, ngf//16, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(ngf//16),nn.ReLU(True),# 64x64 -> 128x128nn.ConvTranspose2d(ngf//16, ngf//32, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(ngf//32),nn.ReLU(True),# 128x128 -> 256x256nn.ConvTranspose2d(ngf//32, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 最后一层输出 3 通道nn.Conv2d(64, out_channels, kernel_size=3, padding=1),nn.Tanh())def forward(self, z):x = self.fc(z).view(z.size(0), -1, 4, 4)  # B x ngf x 4 x 4return self.net(x)

(2) 定义 PatchGAN 判别器,输出为特征图(每个位置对应一个 patch 的判别),在训练过程中,我们可以对这个输出使用 BCEWithLogitsLoss 损失,标签 1 对应真实 patch0 对应虚假 patch

class PatchDiscriminator(nn.Module):def __init__(self, in_channels=3, ndf=64):super().__init__()# 多层卷积下采样,最后输出单通道特征图(判别 map)def conv_block(in_c, out_c, stride, use_bn=True):layers = [nn.Conv2d(in_c, out_c, kernel_size=4, stride=stride, padding=1, bias=not use_bn)]if use_bn:layers.append(nn.BatchNorm2d(out_c))layers.append(nn.LeakyReLU(0.2, inplace=True))return layerslayers = []layers += [nn.Conv2d(in_channels, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True)]layers += conv_block(ndf, ndf*2, stride=2)layers += conv_block(ndf*2, ndf*4, stride=2)layers += conv_block(ndf*4, ndf*8, stride=1)  # keep stride=1 to keep patch granularitylayers += [nn.Conv2d(ndf*8, 1, kernel_size=4, stride=1, padding=1)]  # 单通道 logits mapself.model = nn.Sequential(*layers)def forward(self, x):return self.model(x)  # 输出 shape: (B,1,H_p,W_p)

(3) 定义局部判别器,最终输出一个标量 logit (每个局部 patch 一个判别),训练过程中我们从真实图像与生成图像中随机裁剪局部 patch (例如 128×128) 传递给局部判别器:

class LocalDiscriminator(nn.Module):def __init__(self, in_channels=3, ndf=64):super().__init__()self.net = nn.Sequential(nn.Conv2d(in_channels, ndf, 4, 2, 1), nn.LeakyReLU(0.2, True),nn.Conv2d(ndf, ndf*2, 4, 2, 1), nn.BatchNorm2d(ndf*2), nn.LeakyReLU(0.2, True),nn.Conv2d(ndf*2, ndf*4, 4, 2, 1), nn.BatchNorm2d(ndf*4), nn.LeakyReLU(0.2, True),nn.Conv2d(ndf*4, ndf*8, 4, 2, 1), nn.BatchNorm2d(ndf*8), nn.LeakyReLU(0.2, True),nn.AdaptiveAvgPool2d(1),nn.Conv2d(ndf*8, 1, kernel_size=1)  # 输出 1x1 logit)def forward(self, x):# x 形状 (B,3,H_loc,W_loc)(例如 128x128)out = self.net(x)  # B x 1 x 1 x 1return out.view(x.size(0), 1)  # (B,1)

3.3 模型训练

(1) 初始化模型、优化器以及损失:

G = Generator(z_dim=Z_DIM).to(DEVICE)
D_patch = PatchDiscriminator().to(DEVICE)
D_local = LocalDiscriminator().to(DEVICE)# 初始化权重
def init_weights(m):if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):nn.init.normal_(m.weight.data, 0.0, 0.02)if hasattr(m, "bias") and m.bias is not None:nn.init.constant_(m.bias.data, 0)G.apply(init_weights)
D_patch.apply(init_weights)
D_local.apply(init_weights)# 优化器
opt_G = optim.Adam(G.parameters(), lr=LR, betas=(BETA1, BETA2))
opt_D = optim.Adam(list(D_patch.parameters()) + list(D_local.parameters()), lr=LR, betas=(BETA1, BETA2))# 损失
bce_loss = nn.BCEWithLogitsLoss().to(DEVICE)

(2) 定义辅助函数,random_local_patches() 用于在每张图像上随机裁剪同一大小的局部 patch,训练局部判别器,save_samples() 把生成图像从 [-1,1] 映射回 [0,1] 并保存:

def random_local_patches(imgs, loc_size=128):B, C, H, W = imgs.shapepatches = []for i in range(B):top = random.randint(0, H - loc_size)left = random.randint(0, W - loc_size)patches.append(imgs[i:i+1, :, top:top+loc_size, left:left+loc_size])return torch.cat(patches, dim=0).to(imgs.device)  # (B,3,loc_size,loc_size)def save_samples(fake_imgs, step, nrow=4):grid = utils.make_grid((fake_imgs + 1) / 2.0, nrow=nrow, padding=2, normalize=False)save_path = SAMPLES_DIR / f"sample_{step:06d}.png"utils.save_image(grid, save_path)print("Saved samples to", save_path)

(3) 定义训练流程,判别器训练阶段同时训练 Patch 判别器和局部判别器;Patch 判别器的损失与局部判别器按比例合并,生成器阶段以欺骗两个判别器为目标:

global_step = 0
sample_z = torch.randn(16, Z_DIM, device=DEVICE)  # 固定噪声用于可视化for epoch in range(EPOCHS):loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")for real in loop:real = real.to(DEVICE)  # (B,3,256,256)B = real.size(0)# Train Discriminators (patch + local)opt_D.zero_grad()# Patch discriminator on reallogits_real_patch = D_patch(real)# create labels (same shape as logits)labels_real_patch = torch.ones_like(logits_real_patch, device=DEVICE) * 0.9  # label smoothingloss_D_patch_real = bce_loss(logits_real_patch, labels_real_patch)# Generate fake imagesz = torch.randn(B, Z_DIM, device=DEVICE)fake = G(z)logits_fake_patch = D_patch(fake.detach())labels_fake_patch = torch.zeros_like(logits_fake_patch, device=DEVICE)loss_D_patch_fake = bce_loss(logits_fake_patch, labels_fake_patch)loss_D_patch = (loss_D_patch_real + loss_D_patch_fake) * 0.5# Local discriminator on random patchesreal_local = random_local_patches(real, loc_size=128)  # (B,3,128,128)fake_local = random_local_patches(fake.detach(), loc_size=128)logits_real_local = D_local(real_local)  # (B,1)logits_fake_local = D_local(fake_local)  # (B,1)labels_real_local = torch.ones_like(logits_real_local, device=DEVICE) * 0.9labels_fake_local = torch.zeros_like(logits_fake_local, device=DEVICE)loss_D_local = 0.5 * (bce_loss(logits_real_local, labels_real_local) +bce_loss(logits_fake_local, labels_fake_local))# Total discriminator lossloss_D = loss_D_patch + 0.7 * loss_D_localloss_D.backward()opt_D.step()# Train Generatoropt_G.zero_grad()z2 = torch.randn(B, Z_DIM, device=DEVICE)fake2 = G(z2)# Try to fool both discriminators# Patch GAN objective (we want D_patch(fake) -> 1)logits_fake_patch_forG = D_patch(fake2)labels_G_patch = torch.ones_like(logits_fake_patch_forG, device=DEVICE)loss_G_patch = bce_loss(logits_fake_patch_forG, labels_G_patch)# Local objectivefake_local2 = random_local_patches(fake2, loc_size=128)logits_fake_local_forG = D_local(fake_local2)labels_G_local = torch.ones_like(logits_fake_local_forG, device=DEVICE)loss_G_local = bce_loss(logits_fake_local_forG, labels_G_local)loss_G = loss_G_patch + 0.7 * loss_G_localloss_G.backward()opt_G.step()global_step += 1if global_step % 10 == 0:loop.set_postfix({"loss_D": f"{loss_D.item():.4f}","loss_G": f"{loss_G.item():.4f}"})if global_step % 500 == 0:with torch.no_grad():samples = G(sample_z).cpu()save_samples(samples, global_step, nrow=4)torch.save({'epoch': epoch,'G_state_dict': G.state_dict(),'D_patch_state_dict': D_patch.state_dict(),'D_local_state_dict': D_local.state_dict(),'opt_G': opt_G.state_dict(),'opt_D': opt_D.state_dict()}, CHECKPOINT_DIR / f"ckpt_epoch_{epoch:03d}.pth")print(f"Saved checkpoint for epoch {epoch}")

生成结果如下所示,可以看到随着训练生成的图像越来越逼真:

生成结果

相关链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(11)——神经风格迁移
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现
PyTorch生成式人工智能(18)——循环神经网络详解与实现
PyTorch生成式人工智能(19)——自回归模型详解与实现
PyTorch生成式人工智能(20)——像素卷积神经网络(PixelCNN)
PyTorch生成式人工智能(24)——使用PyTorch构建Transformer模型
PyTorch生成式人工智能(25)——基于Transformer实现机器翻译
PyTorch生成式人工智能(26)——使用PyTorch构建GPT模型
PyTorch生成式人工智能(27)——从零开始训练GPT模型
PyTorch生成式人工智能(28)——MuseGAN详解与实现

http://www.dtcms.com/a/401460.html

相关文章:

  • 支持手机网站的空间用dw做的网站怎么上传
  • 浙江网站建设与维护书parallax wordpress
  • 网站设计评级安庆建设银行网站
  • 网站建设报告总结品牌建设部
  • 济南网站建设李尚荣模板网站seo
  • 东莞网站建设服务有什么用开发网站做图文水印逻辑
  • 网站建设怎样上传程序手机代码网站有哪些问题吗
  • 用wordpress制作网站图案设计网
  • 网站编程学习网页版微信登陆
  • 网站建设 cn歌曲网站源码
  • 无锡网站推广公司排名哈尔滨工程招投标信息网
  • 1有免费建网站站酷做网站
  • 地产网站互动营销新昌建设局网站
  • 住房和城乡建设部执法网站湛江做网站说
  • 网站域名实名认证临汾做网站
  • 潍坊网站建设解决方案wordpress插件无法安装插件
  • 柯桥区建设集团网站建阳网站建设wzjseo
  • 网站开发+自动生成缩略图保定建设局网站
  • 网站右侧二维码代码网站页面做成自适应优缺点
  • 怎么做电商卖东西东莞市网络seo推广服务机构
  • 四川网站建设的公司排名磁力搜索器 磁力猫在线
  • 湖州网站开发版面设计绘画
  • 网站建设开场白怎么说福建建设厅网站工程履约保险
  • pv3d 优秀网站建网站卖产品怎么样
  • 东道设计公司难进吗快速排名生客seo
  • 南昌做网站哪家专业如何搭建o2o网站
  • 石桥铺网站建设阿里云域名注册万网
  • 世界著名设计公司郑州网站建设搜索优化
  • python微信网站开发北京西站附近的景点有哪些
  • 佛山做网站制作公司一级a做爰网站免费