当前位置: 首页 > news >正文

第41周——人脸图像生成

目录

目录

目录

前言

一、设置超参数并导入数据

二、模型定义

三、训练数据

四、可视化

总结


前言

  •  🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

一、设置超参数并导入数据

import torch, random, random, os
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTMLmanualSeed = 999  # 随机种子
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible resultsdataroot = "data/GAN"  # 数据路径
batch_size = 128  # 训练过程中的批次大小
image_size = 64   # 图像的尺寸(宽度和高度)
nz  = 100         # z潜在向量的大小(生成器输入的尺寸)
ngf = 64          # 生成器中的特征图大小
ndf = 64          # 判别器中的特征图大小
num_epochs = 50   # 训练的总轮数,如果你显卡不太行,可调小,但是生成效果会随之降低
lr    = 0.0002    # 学习率
beta1 = 0.5       # Adam优化器的Beta1超参数# 创建数据集
dataset = dset.ImageFolder(root=dataroot,transform=transforms.Compose([transforms.Resize(image_size),        # 调整图像大小transforms.CenterCrop(image_size),    # 中心裁剪图像transforms.ToTensor(),                # 将图像转换为张量transforms.Normalize((0.5, 0.5, 0.5), # 标准化图像张量(0.5, 0.5, 0.5)),]))# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,  # 批量大小shuffle=True,           # 是否打乱数据集num_workers=5 # 使用多个线程加载数据的工作进程数)# 选择要在哪个设备上运行代码
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
print("使用的设备是:",device)# 绘制一些训练图像
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:24], padding=2, normalize=True).cpu(),(1,2,0)))


二、模型定义

# 自定义权重初始化函数,作用于netG和netD
def weights_init(m):# 获取当前层的类名classname = m.__class__.__name__# 如果类名中包含'Conv',即当前层是卷积层if classname.find('Conv') != -1:# 使用正态分布初始化权重数据,均值为0,标准差为0.02nn.init.normal_(m.weight.data, 0.0, 0.02)# 如果类名中包含'BatchNorm',即当前层是批归一化层elif classname.find('BatchNorm') != -1:# 使用正态分布初始化权重数据,均值为1,标准差为0.02nn.init.normal_(m.weight.data, 1.0, 0.02)# 使用常数初始化偏置项数据,值为0nn.init.constant_(m.bias.data, 0)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(# 输入为Z,经过一个转置卷积层nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),  # 批归一化层,用于加速收敛和稳定训练过程nn.ReLU(True),  # ReLU激活函数# 输出尺寸:(ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 输出尺寸:(ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 输出尺寸:(ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# 输出尺寸:(ngf) x 32 x 32nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),nn.Tanh()  # Tanh激活函数# 输出尺寸:3 x 64 x 64)def forward(self, input):return self.main(input)# 创建生成器
netG = Generator().to(device)
# 使用 "weights_init" 函数对所有权重进行随机初始化,
# 平均值(mean)设置为0,标准差(stdev)设置为0.02。
netG.apply(weights_init)
# 打印生成器模型
print(netG)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义判别器的主要结构,使用Sequential容器将多个层按顺序组合在一起self.main = nn.Sequential(# 输入大小为3 x 64 x 64nn.Conv2d(3, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):# 将输入通过判别器的主要结构进行前向传播return self.main(input)# 创建判别器模型
netD = Discriminator().to(device)# 应用 "weights_init" 函数来随机初始化所有权重
# 使用 mean=0, stdev=0.2 的方式进行初始化
netD.apply(weights_init)# 打印模型
print(netD)

三、训练数据

img_list = []  # 用于存储生成的图像列表
G_losses = []  # 用于存储生成器的损失列表
D_losses = []  # 用于存储判别器的损失列表
iters = 0  # 迭代次数print("Starting Training Loop...")  # 输出训练开始的提示信息
# 对于每个epoch(训练周期)
for epoch in range(num_epochs):# 对于dataloader中的每个batchfor i, data in enumerate(dataloader, 0):############################# (1) 更新判别器网络:最大化 log(D(x)) + log(1 - D(G(z)))############################# 使用真实图像样本训练netD.zero_grad()  # 清除判别器网络的梯度# 准备真实图像的数据real_cpu = data[0].to(device)b_size = real_cpu.size(0)label = torch.full((b_size,), real_label, dtype=torch.float, device=device)  # 创建一个全是真实标签的张量# 将真实图像样本输入判别器,进行前向传播output = netD(real_cpu).view(-1)# 计算真实图像样本的损失errD_real = criterion(output, label)# 通过反向传播计算判别器的梯度errD_real.backward()D_x = output.mean().item()  # 计算判别器对真实图像样本的输出的平均值## 使用生成图像样本训练# 生成一批潜在向量noise = torch.randn(b_size, nz, 1, 1, device=device)# 使用生成器生成一批假图像样本fake = netG(noise)label.fill_(fake_label)  # 创建一个全是假标签的张量# 将所有生成的图像样本输入判别器,进行前向传播output = netD(fake.detach()).view(-1)# 计算判别器对生成图像样本的损失errD_fake = criterion(output, label)# 通过反向传播计算判别器的梯度errD_fake.backward()D_G_z1 = output.mean().item()  # 计算判别器对生成图像样本的输出的平均值# 计算判别器的总损失,包括真实图像样本和生成图像样本的损失之和errD = errD_real + errD_fake# 更新判别器的参数optimizerD.step()############################# (2) 更新生成器网络:最大化 log(D(G(z)))###########################netG.zero_grad()  # 清除生成器网络的梯度label.fill_(real_label)  # 对于生成器成本而言,将假标签视为真实标签# 由于刚刚更新了判别器,再次将所有生成的图像样本输入判别器,进行前向传播output = netD(fake).view(-1)# 根据判别器的输出计算生成器的损失errG = criterion(output, label)# 通过反向传播计算生成器的梯度errG.backward()D_G_z2 = output.mean().item()  # 计算判别器对生成器输出的平均值# 更新生成器的参数optimizerG.step()# 输出训练统计信息if i % 400 == 0:print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'% (epoch, num_epochs, i, len(dataloader),errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))# 保存损失值以便后续绘图G_losses.append(errG.item())D_losses.append(errD.item())# 通过保存生成器在固定噪声上的输出来检查生成器的性能if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):with torch.no_grad():fake = netG(fixed_noise).detach().cpu()img_list.append(vutils.make_grid(fake, padding=2, normalize=True))iters += 1

四、可视化

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()# 创建一个大小为8x8的图形对象
fig = plt.figure(figsize=(8, 8))# 不显示坐标轴
plt.axis("off")# 将图像列表img_list中的图像转置并创建一个包含每个图像的单个列表ims
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]# 使用图形对象、图像列表ims以及其他参数创建一个动画对象ani
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)# 将动画以HTML形式呈现
HTML(ani.to_jshtml())# 从数据加载器中获取一批真实图像
real_batch = next(iter(dataloader))# 绘制真实图像
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))# 绘制上一个时期生成的假图像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()


总结

在本次实验中,我们基于 PyTorch 框架实现了一个典型的 GAN 训练流程,主要包含超参数配置、模型构建、训练循环以及结果可视化四个部分。整体思路清晰,结构紧凑,具体如下:

首先,在数据准备与超参数设置阶段,我们通过 manualSeed 固定随机数种子以保证结果可复现。接着设定了训练所需的关键参数,包括批量大小(batch_size)、图像尺寸(image_size)以及潜在向量维度(nz)等。数据集使用 torchvision.datasets.ImageFolder 加载,并在进入模型前完成了一系列标准化和变换处理。随后构建了 dataloader 来实现高效的批量数据读取,并自动选择在 CPU 还是 GPU 上运行。

接下来是模型设计部分。我们实现了生成器(Generator)和判别器(Discriminator)两个核心网络。判别器由多层卷积结构组成,而生成器则采用转置卷积逐步上采样生成图像。为了保证训练的稳定性,我们还自定义了 weights_init 函数,对模型的卷积层参数进行特定分布的初始化,并在定义后打印模型结构,方便检查网络搭建是否正确。

训练流程方面,每个 epoch 内会依次更新判别器和生成器。判别器的更新分为两步:先利用真实样本,再结合由生成器产生的伪造样本,来提升其区分真假图像的能力。生成器的优化则基于判别器的反馈,力图生成更以假乱真的图像。在训练过程中,我们记录了损失值,并周期性输出训练信息。同时,还利用固定噪声输入生成图像,便于直观观察生成效果随迭代的演变。

最后是结果可视化。一方面绘制了判别器和生成器的损失曲线,用于分析模型收敛情况;另一方面生成了一个动画,将不同训练阶段的样本输出串联起来,直观展示了生成器性能随时间的提升。此外,还将最终生成的图像与真实数据进行对比,可以清晰看到模型在最后阶段的表现

http://www.dtcms.com/a/343975.html

相关文章:

  • Java 性能优化实战(三):并发编程的 4 个优化维度
  • 第3课:Flutter基础组件
  • 上海人工智能实验室开源基于Intern-S1同等技术的轻量化开源多模态推理模型
  • WPF MVVM入门系列教程(TabControl绑定到列表并单独指定每一页内容)
  • 【nl2sql综述】2025最新综述解读
  • RAG学习(五)——查询构建、Text2SQL、查询重构与分发
  • Docker 部署 Microsoft SQL Server 指南
  • 第10课:性能优化
  • 如何将照片从iPhone传输到Mac?
  • 如何将文件从 iPad 转移到 iPhone 16/15
  • Node.js 开发 JavaScript SDK 包的完整指南(AI)
  • Cloudflare + nginx 限制ip访问的几种方式(白嫖cloudflare的ip数据库)
  • 数据分类分级的概念、标准解读及实现路径
  • 新零售“实—虚—合”逻辑下定制开发开源AI智能名片S2B2C商城小程序的机遇与演进
  • TCP/UDP详解(一)
  • 高并发的 Spring Boot Web 项目注意点
  • HTTP代理与SOCKS代理的区别、应用场景与选择指南
  • Figma 开源替代品 Penpot 安装与使用
  • 要区分一张图片中的网状图(如网格结构或规则纹理)和噪点(随机分布的干扰像素),比如电路的方法 计算机视觉
  • Unreal Engine ClassName Rule
  • HTTP接口鉴权方式
  • Java面试实战系列【并发篇】- CompletableFuture异步编程实战
  • Node.js中Express框架入门教程
  • vue/react使用h5player对接海康ws视频流实时播放,监控回放
  • 快速入门Vue3——初体验
  • CS创世SD NAND在北京君正平台和瑞芯微RK平台的应用
  • 高压、高功率时代,飞机电气系统如何保障安全?
  • 安全运维过程文档体系规范
  • 2025软件供应链安全技术路线未来趋势预测
  • Docker的安装