Deep Learning Week G3: Introduction to CGAN (Generating Hand-Gesture Images)

  • 🍨 This post is a study log from the 🔗365天深度学习训练营 (365-Day Deep Learning Training Camp)
  • 🍖 Original author: K同学啊

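Basic tasks:

1. Understand the basic principle of the conditional generative adversarial network (CGAN)

2. Understand how a CGAN implements conditional control

3. Study the CGAN code in this post and get it running

Advanced task:

Generate images of a specified hand gesture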

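I. Theory

A conditional generative adversarial network (CGAN) is an extension of the generative adversarial network (GAN). The generator of an original GAN produces images from random noise alone, so its output cannot be predicted or steered. There is no way to ask the network for a particular kind of image, which limits its usefulness in practice.

To address this inability to generate images with specific attributes, Mehdi Mirza et al. proposed the conditional GAN in 2014. It adds an extra condition to both the generator G and the discriminator D. For example, if we want G to generate an image without shadows, D must then judge whether the generated image is in fact a shadow-free image.

In essence, a CGAN feeds additional information into both the generator and the discriminator. This information can be an image class, a facial expression, or other auxiliary data. The aim is to turn the unsupervised GAN into a supervised CGAN, so that training can be steered toward the output we want.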
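For reference, the conditioning can be written directly into the usual GAN minimax objective. The form below follows the objective stated in the original CGAN paper, with y as the condition; it is included as a reminder and is not derived in this post.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]
```

Both D and G receive the same y, so the discriminator can penalize a sample that looks realistic but does not match the requested condition.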

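CGAN network architecture diagram:

As the diagram shows, the condition y enters the adversarial network as an extra input. In the generator it is combined with the noise z to form a joint hidden representation; in the discriminator D it is combined with the real data x as the input to the discriminant function. This modification has proven very effective in much subsequent research and has guided later work.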
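The channel counts used later in this post (513 in the generator, 6 in the discriminator) come from exactly this merging step. The sketch below is a minimal illustration with random tensors; the batch size of 2 and the variable names are assumptions chosen for the example, while the shapes mirror the ones used by the models in this post.

```python
import torch

batch = 2  # illustrative batch size (assumption)

# Generator side: noise features (512 channels) + one label channel, both 4x4.
latent_map = torch.randn(batch, 512, 4, 4)
label_map_g = torch.randn(batch, 1, 4, 4)
g_input = torch.cat((latent_map, label_map_g), dim=1)

# Discriminator side: RGB image + a 3-channel label map, both 128x128.
image = torch.randn(batch, 3, 128, 128)
label_map_d = torch.randn(batch, 3, 128, 128)
d_input = torch.cat((image, label_map_d), dim=1)

print(g_input.shape)
print(d_input.shape)
```

Concatenating along dim=1 stacks channels, so the condition travels through every convolution together with the noise or the image.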

II. Preparation

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import matplotlib.pyplot as plt
import datetime

torch.manual_seed(1)  # fix the random seed for reproducibility
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128

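1. Loading the data

# Resize, convert to a tensor, and normalize each channel from [0, 1] to [-1, 1]
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_dataset = datasets.ImageFolder(root="D:/study/data/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

def show_images(images):
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([]); ax.set_yticks([])
    # make_grid returns (C, H, W) while imshow expects (H, W, C);
    # normalize=True rescales the [-1, 1] tensors to [0, 1] for display
    ax.imshow(make_grid(images.detach(), nrow=22, normalize=True).permute(1, 2, 0))

def show_batch(dl):
    # Display only the first batch
    for images, _ in dl:
        show_images(images)
        break

show_batch(train_loader)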
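Normalize with mean 0.5 and standard deviation 0.5 maps each pixel from [0, 1] to [-1, 1], which matches the Tanh output range of the generator; the plotting code at the end of the post undoes this with (x + 1) * 127.5. A minimal pure-Python sketch of that round trip follows (the helper names are illustrative, not torchvision functions).

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Map a pixel in [0, 1] to [-1, 1], as Normalize(0.5, 0.5) does per channel."""
    return (pixel - mean) / std

def to_uint8(value):
    """Map a value in [-1, 1] back to the 0..255 range used for display."""
    return int((value + 1) * 127.5)

for p in (0.0, 0.5, 1.0):
    n = normalize(p)
    print(p, "->", n, "->", to_uint8(n))
```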

Problem encountered:

The DataLoader failed because of multi-process data loading; removing the num_workers argument resolved it.
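A common cause of this kind of failure is running a DataLoader with num_workers > 0 on Windows without a main guard, since worker processes re-import the script there. The post does not show the exact error, so the following is only a sketch of the two usual remedies under that assumption; the random TensorDataset and the build_loader helper are stand-ins for the real image folder.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def build_loader(num_workers=0):
    # Stand-in dataset (assumption): 8 random "images" with labels 0..2.
    images = torch.randn(8, 3, 16, 16)
    labels = torch.randint(0, 3, (8,))
    dataset = TensorDataset(images, labels)
    # num_workers=0 loads batches in the main process (the DataLoader default).
    return DataLoader(dataset, batch_size=4, shuffle=True, num_workers=num_workers)

if __name__ == "__main__":
    # Remedy 1: keep num_workers at 0, as the post ended up doing.
    # Remedy 2: if workers are needed, create and iterate the loader
    # only under this guard so child processes do not re-run it.
    for batch_images, batch_labels in build_loader(num_workers=0):
        print(batch_images.shape, batch_labels.shape)
```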

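image_shape = (3, 128, 128)            # (channels, height, width)
image_dim = int(np.prod(image_shape))  # 3 * 128 * 128 = 49152
latent_dim = 100                       # length of the noise vector z

n_classes = 3                          # number of gesture classes
embedding_dim = 100                    # size of the label embedding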

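III. Building the Models

def weights_init(m):
    # Conv layers: weights ~ N(0, 0.02)
    # BatchNorm layers: weights ~ N(1, 0.02), biases set to zero
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)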

1. Building the generator

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # Label branch: class index -> embedding -> 16 values (reshaped to 1x4x4 in forward)
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16)
        )

        # Noise branch: latent vector -> 4*4*512 values (reshaped to 512x4x4 in forward)
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Upsampling stack; 513 input channels = 512 (noise) + 1 (label).
        # eps=0.8 is kept from the original code; PyTorch's default is 1e-5.
        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()  # output in [-1, 1]
        )

    def forward(self, inputs):
        noise_vector, label = inputs
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)
        # Concatenate along the channel dimension: (N, 513, 4, 4)
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

from torchinfo import summary
summary(generator)

# Dummy inputs for a forward-pass check: one noise vector and one class label
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)
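To see why five transposed convolutions take the 4x4 seed to a 128x128 image, the output size of each layer can be computed by hand. The sketch below uses the standard ConvTranspose2d size formula with dilation 1 and no output padding; the helper name is illustrative.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel

size = 4  # the generator starts from a 4x4 feature map
sizes = [size]
for _ in range(5):  # five ConvTranspose2d(kernel=4, stride=2, padding=1) layers
    size = conv_transpose_out(size)
    sizes.append(size)

print(sizes)  # [4, 8, 16, 32, 64, 128]
```

With kernel 4, stride 2, and padding 1, each layer exactly doubles the spatial size, so the number of layers fixes the output resolution.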


2. Building the discriminator

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Label branch: class index -> embedding -> 3*128*128 values (reshaped to an image-sized map)
        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3*128*128)
        )

        # 6 input channels = 3 (image) + 3 (label map)
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),   # 4608 = 512 channels * 3 * 3
            nn.Sigmoid()          # probability that the input is real
        )

    def forward(self, inputs):
        img, label = inputs
        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)
        # Concatenate along the channel dimension: (N, 6, 128, 128)
        concat = torch.cat((img, label_output), dim=1)
        output = self.model(concat)
        return output

discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)
summary(discriminator)

# Forward-pass check with a dummy batch of two images and two labels
a = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)
c = discriminator((a,b))
c.size()
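The 4608 input features of the final Linear layer follow from the convolution strides. A minimal sketch of that arithmetic, using the standard Conv2d output-size formula with dilation 1 (the helper name is illustrative):

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of Conv2d (dilation=1), using floor division."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of the four Conv2d layers in the discriminator
layers = [(4, 2, 1), (4, 3, 2), (4, 3, 2), (4, 3, 2)]

size = 128
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

flat_features = 64 * 8 * size * size  # channels of the last conv * height * width
print(sizes)          # [128, 64, 22, 8, 3]
print(flat_features)  # 4608
```

If the input resolution or any stride changes, this number changes too, and the Linear layer has to be updated to match.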


IV. Training the Model

1. Defining the loss functions

adversarial_loss = nn.BCELoss()  # binary cross-entropy on the discriminator's sigmoid output

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss
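Both helpers are thin wrappers around binary cross-entropy; what differs is the target they are called with. As a hand-worked illustration (the probability 0.9 is an arbitrary example value), the per-sample BCE can be computed directly:

```python
import math

def bce(prob, target):
    """Binary cross-entropy for one prediction, matching the BCELoss formula."""
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))

p = 0.9  # discriminator's output for some image (arbitrary example value)

loss_if_real = bce(p, 1.0)  # target 1: image labelled real
loss_if_fake = bce(p, 0.0)  # target 0: image labelled fake

print(round(loss_if_real, 4))  # 0.1054
print(round(loss_if_fake, 4))  # 2.3026
```

When the discriminator outputs 0.9 for a generated image, the discriminator (target 0) is penalized heavily while the generator (target 1) is penalized only lightly, which is the adversarial tension that training exploits.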

2. Defining the optimizers

learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(),     lr=learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

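3. Training the model

Code logic:

1. First, set the total number of epochs and create the lists that store the discriminator and generator losses for each epoch.

2. Then train the GAN. In each iteration, load a batch of real images and labels from the data loader and compute the discriminator's loss on the real images. Next, generate fake images from noise vectors and compute the discriminator's loss on them. Average the two into the total discriminator loss, backpropagate, and update the discriminator. Then compute the generator's loss, backpropagate, and update the generator.

3. At the end of each epoch, print the average discriminator and generator losses and append them to the lists.

4. Every 10 epochs, save the generated fake images to an image file and save the current generator and discriminator weights to disk.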
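The order of operations in step 2 is easier to follow on a toy problem. The sketch below runs a single discriminator update followed by a single generator update; the tiny linear networks, the 8-dimensional "images", and feeding the label as one raw feature are all simplifying assumptions for illustration and are not the models used in this post.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy stand-ins (assumption): 8-dim "images", 4-dim noise, label fed as one extra feature.
G = nn.Sequential(nn.Linear(4 + 1, 8), nn.Tanh())
D = nn.Sequential(nn.Linear(8 + 1, 1), nn.Sigmoid())
bce = nn.BCELoss()
G_opt = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
D_opt = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

real = torch.rand(16, 8) * 2 - 1               # stand-in "real" batch in [-1, 1]
labels = torch.randint(0, 3, (16, 1)).float()  # class index used as a raw feature
real_target = torch.ones(16, 1)
fake_target = torch.zeros(16, 1)

# Discriminator step: real batch, then a detached fake batch.
D_opt.zero_grad()
d_real = bce(D(torch.cat((real, labels), dim=1)), real_target)
noise = torch.randn(16, 4)
fake = G(torch.cat((noise, labels), dim=1))
d_fake = bce(D(torch.cat((fake.detach(), labels), dim=1)), fake_target)
d_loss = (d_real + d_fake) / 2
d_loss.backward()
D_opt.step()

# Generator step: same fake batch, but scored against the "real" target.
G_opt.zero_grad()
g_loss = bce(D(torch.cat((fake, labels), dim=1)), real_target)
g_loss.backward()
G_opt.step()

print(d_loss.item(), g_loss.item())
```

The detach() call in the discriminator step keeps its gradients out of the generator, while the generator step reuses the same fake batch without detaching so that gradients flow back through D into G.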

import os

num_epochs = 100
D_loss_plot, G_loss_plot = [], []

# save_image and torch.save do not create missing directories
os.makedirs('./images', exist_ok=True)
os.makedirs('./training_weights', exist_ok=True)

for epoch in range(1, num_epochs + 1):
    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        # ---- Train the discriminator ----
        D_optimizer.zero_grad()
        real_images = real_images.to(device)
        labels = labels.to(device)
        labels = labels.unsqueeze(1).long()

        real_target = torch.ones(real_images.size(0), 1, device=device)
        fake_target = torch.zeros(real_images.size(0), 1, device=device)

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        generated_image = generator((noise_vector, labels))
        # detach() keeps the discriminator update from backpropagating into the generator
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss.item())  # .item() stores a float, not the graph
        D_total_loss.backward()
        D_optimizer.step()

        # ---- Train the generator ----
        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss.item())
        G_loss.backward()
        G_optimizer.step()

    # Per-epoch averages
    D_mean, G_mean = np.mean(D_loss_list), np.mean(G_loss_list)
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (epoch, num_epochs, D_mean, G_mean))
    D_loss_plot.append(D_mean)
    G_loss_plot.append(G_mean)

    # Every 10 epochs, save sample images and model weights
    if epoch % 10 == 0:
        save_image(generated_image.detach()[:50], './images/sample_%d.png' % epoch, nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % epoch)
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % epoch)


V. Model Analysis

1. Loading the model

generator.load_state_dict(torch.load('./training_weights/generator_epoch_100.pth', map_location=device), strict=False)
generator.eval()

from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from numpy import linspace
from matplotlib import pyplot
from matplotlib import gridspec

def generate_latent_points(latent_dim, n_samples, n_classes=3):
    # Sample n_samples points from a standard normal latent space
    x_input = randn(latent_dim * n_samples)
    z_input = x_input.reshape(n_samples, latent_dim)
    return z_input

def interpolate_points(p1, p2, n_steps=10):
    # Linear interpolation between two latent points
    ratios = linspace(0, 1, num=n_steps)
    vectors = list()
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return asarray(vectors)

pts = generate_latent_points(100, 2)
interpolated = interpolate_points(pts[0], pts[1])
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

output = None
for label in range(3):
    # Use the same 10 interpolated noise vectors for every class label
    labels = torch.ones(10) * label
    labels = labels.to(device)
    labels = labels.unsqueeze(1).long()
    print(labels.size())
    predictions = generator((interpolated, labels))
    predictions = predictions.permute(0,2,3,1)   # (N, C, H, W) -> (N, H, W, C)
    pred = predictions.detach().cpu().numpy()
    if output is None:
        output = pred
    else:
        output = np.concatenate((output,pred))

output.shape

nrow = 3
ncol = 10

fig = plt.figure(figsize=(15,4))
gs = gridspec.GridSpec(nrow, ncol)

k = 0
for i in range(nrow):
    for j in range(ncol):
        pred = (output[k, :, :, :] + 1 ) * 127.5   # [-1, 1] -> [0, 255]
        pred = np.array(pred)
        ax= plt.subplot(gs[i,j])
        ax.imshow(pred.astype(np.uint8))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis('off')
        k += 1

plt.show()
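The interpolate_points helper above is plain linear interpolation: each intermediate vector is (1 - t) * p1 + t * p2 for t evenly spaced in [0, 1]. A dependency-free sketch on two short vectors (values chosen only for illustration, helper name hypothetical) makes the behaviour easy to check by hand:

```python
def lerp_points(p1, p2, n_steps=5):
    """Linearly interpolate between two equal-length vectors."""
    steps = []
    for i in range(n_steps):
        t = i / (n_steps - 1)  # 0.0, 0.25, 0.5, 0.75, 1.0 for n_steps=5
        steps.append([(1.0 - t) * a + t * b for a, b in zip(p1, p2)])
    return steps

path = lerp_points([0.0, 2.0], [4.0, -2.0])
for point in path:
    print(point)
```

Because the same interpolated noise is paired with each of the three labels, each row of the final grid shows the same latent path under a different gesture class, which is how a specific gesture is requested from the trained generator.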

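VI. Summary

I studied the basic principles and code of the conditional generative adversarial network and now understand how a CGAN implements conditional control. I ran into the same problem as last time and had forgotten the fix, so I still need to review more.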
