
Generative AI in Practice | Self-Attention Generative Adversarial Network (SAGAN)


    • 0. Introduction
    • 1. Core Principles of SAGAN
      • 1.1 Self-Attention Mechanism
      • 1.2 Spectral Normalization
    • 2. Implementing SAGAN
      • 2.1 Generator
      • 2.2 Discriminator
    • 3. Model Training
      • 3.1 Data Loading
      • 3.2 Training Procedure

0. Introduction

The Self-Attention Generative Adversarial Network (SAGAN) embeds a self-attention mechanism into the traditional deep convolutional GAN, allowing it to capture long-range dependencies within an image and thereby generate images with better global consistency and richer detail. SAGAN introduces self-attention modules into both the generator and the discriminator, and combines them with Spectral Normalization, Conditional Batch Normalization, a Projection Discriminator, and the Hinge Loss, which together significantly improve training stability and sample quality. This section introduces the core principles of SAGAN and implements the model in PyTorch.

1. Core Principles of SAGAN

1.1 Self-Attention Mechanism

Traditional convolutional neural networks rely mainly on local receptive fields and struggle to capture global, cross-region structure in an image; stacking many convolutional layers has the theoretical capacity to do so, but is hard to optimize and statistically brittle.
SAGAN computes self-attention over the feature maps, so that the output at each position depends not only on its local neighborhood but also on cues from every other position, improving the global consistency of the generated results.
For images, the self-attention module first maps the input feature map through three 1×1 convolutions into query, key, and value features, and then computes the attention output in the following steps (summarized as formulas after the list):

  • Multiply the query and key tensors to obtain the attention weight matrix, and normalize it with a softmax
  • Multiply the attention weights with the value tensor to obtain the weighted feature representation
  • Scale the result by a learnable factor γ and add it to the original features as a residual connection; γ is initialized to 0, so the network can rely on local information first and gradually learn non-local dependencies
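
In matrix form (a compact summary of the steps above, following the tensor shapes used in the implementation below rather than any particular paper's notation), with flattened features Q ∈ R^{N×C/8}, K ∈ R^{C/8×N}, V ∈ R^{C×N} and N = H×W:

A = softmax(QK),  O = V·Aᵀ,  Y = γ·O + X

where the softmax normalizes each row of the N×N matrix QK, and Y is the module's output.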

The following figure shows the attention module in SAGAN, where θ, φ, and g correspond to the key, query, and value, respectively:

[Figure: self-attention mechanism]

Next, we implement the self-attention mechanism: the input is first projected into the query/key/value spaces, the attention matrix is computed, and the weighted values are fused with the original input:

import torch
import torch.nn as nn

class Self_Attn(nn.Module):
    """Self-attention layer"""
    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        # 1x1 convolutions project the input into query/key/value spaces
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # learnable scale of the attention branch, initialized to 0 (start from purely local features)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x : input feature maps (B x C x W x H)
        returns:
            out : self-attention value + input feature
            attention : B x N x N (N = W*H)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C//8
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)                       # B x C//8 x N
        energy = torch.bmm(proj_query, proj_key)                                                # B x N x N
        attention = self.softmax(energy)                                                        # softmax over the last dim
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)                   # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        # residual connection scaled by the learnable gamma
        out = self.gamma * out + x
        return out, attention
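
As a quick sanity check, the module can be applied to a random feature map; the output keeps the input's shape while the attention matrix is N×N (a hedged sketch, tensor sizes are illustrative):

attn = Self_Attn(in_dim=128, activation='relu')
x = torch.randn(4, 128, 16, 16)           # batch of 16x16 feature maps with 128 channels
out, attention = attn(x)
print(out.shape)                          # torch.Size([4, 128, 16, 16]) -- same as the input
print(attention.shape)                    # torch.Size([4, 256, 256])    -- N x N with N = 16*16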

1.2 Spectral Normalization

To stabilize adversarial training, SAGAN applies spectral normalization to all convolutional weights of both the generator and the discriminator, forcing the largest singular value of each weight matrix to 1. This enforces a Lipschitz constraint and suppresses exploding or vanishing gradients. The steps for performing spectral normalization are as follows:

  • The weight of a convolutional layer is a 4-D tensor, so the first step is to reshape it into a 2-D matrix; here we keep the first dimension (the output channels) and flatten the rest, giving a weight W of shape (C_out, C_in·kH·kW)
  • Initialize a vector U from N(0, 1)
  • In a for loop, compute the following:
    • Compute V = WᵀU using a matrix transpose and matrix multiplication
    • Normalize V by its L2 norm, i.e. V = V / ‖V‖₂
    • Compute U = WV
    • Normalize U by its L2 norm, i.e. U = U / ‖U‖₂
  • Compute the spectral norm as σ = UᵀWV
  • Finally, divide the weight by the spectral norm σ
from torch.nn import Parameter

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)

class SpectralNorm(nn.Module):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        # power iteration to estimate the leading singular vectors u and v
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # spectral norm sigma = u^T W v; divide the weight by it
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # replace the module's original weight with u, v and the unnormalized weight w_bar
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
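
A minimal usage sketch: wrapping a convolution so that its weight is re-normalized by the estimated spectral norm on every forward pass (recent PyTorch versions also ship a built-in torch.nn.utils.spectral_norm that could be used instead):

conv = SpectralNorm(nn.Conv2d(64, 128, kernel_size=3, padding=1))
x = torch.randn(2, 64, 16, 16)
y = conv(x)        # one power-iteration step runs, then the normalized weight is used
print(y.shape)     # torch.Size([2, 128, 16, 16])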

2. Implementing SAGAN

2.1 Generator

The generator takes a noise vector as input and passes it through a series of upsampling convolution blocks, inserting self-attention in the intermediate layers to produce details that are globally consistent:

import numpy as np

class Generator(nn.Module):
    """Generator."""
    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(Generator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        repeat_num = int(np.log2(self.imsize)) - 3
        mult = 2 ** repeat_num  # 8
        layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
        layer1.append(nn.BatchNorm2d(conv_dim * mult))
        layer1.append(nn.ReLU())

        curr_dim = conv_dim * mult
        layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer2.append(nn.ReLU())

        curr_dim = int(curr_dim / 2)
        layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer3.append(nn.ReLU())

        if self.imsize == 64:
            layer4 = []
            curr_dim = int(curr_dim / 2)
            layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
            layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
            layer4.append(nn.ReLU())
            self.l4 = nn.Sequential(*layer4)
            curr_dim = int(curr_dim / 2)

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

        # self-attention at the 16x16 (128-channel) and 32x32 (64-channel) stages
        self.attn1 = Self_Attn(128, 'relu')
        self.attn2 = Self_Attn(64, 'relu')

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        out = self.l1(z)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)
        return out, p1, p2
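
A quick shape check on the CPU (a sketch with illustrative values; z_dim=100 matches the default above):

G = Generator(batch_size=4, image_size=64, z_dim=100, conv_dim=64)
z = torch.randn(4, 100)
imgs, p1, p2 = G(z)
print(imgs.shape)  # torch.Size([4, 3, 64, 64])
print(p1.shape)    # torch.Size([4, 256, 256])    attention at the 16x16 stage
print(p2.shape)    # torch.Size([4, 1024, 1024])  attention at the 32x32 stage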

2.2 Discriminator

The discriminator likewise incorporates self-attention to capture global dependencies:

class Discriminator(nn.Module):
    """Discriminator, Auxiliary Classifier."""
    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
        layer1.append(nn.LeakyReLU(0.1))

        curr_dim = conv_dim
        layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer2.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer3.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        if self.imsize == 64:
            layer4 = []
            layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
            layer4.append(nn.LeakyReLU(0.1))
            self.l4 = nn.Sequential(*layer4)
            curr_dim = curr_dim * 2

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.Conv2d(curr_dim, 1, 4))
        self.last = nn.Sequential(*last)

        self.attn1 = Self_Attn(256, 'relu')
        self.attn2 = Self_Attn(512, 'relu')

    def forward(self, x):
        out = self.l1(x)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)
        return out.squeeze(), p1, p2
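
And a matching check for the discriminator on a batch of 64×64 RGB images (again a sketch with illustrative sizes):

D = Discriminator(batch_size=4, image_size=64, conv_dim=64)
x = torch.randn(4, 3, 64, 64)
score, a1, a2 = D(x)
print(score.shape)  # torch.Size([4])          one realness score per image
print(a1.shape)     # torch.Size([4, 64, 64])  attention at the 8x8 stage
print(a2.shape)     # torch.Size([4, 16, 16])  attention at the 4x4 stage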

3. Model Training

3.1 Data Loading

In this section, we continue to use the CelebA face image dataset to train the SAGAN model:

from torchvision import transforms
import torchvision.utils as vutils
import cv2, numpy as np
import torch
import os
from glob import glob
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"transform=transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])class Faces(Dataset):def __init__(self, folder):super().__init__()self.folder = folderself.images = sorted(glob(folder))def __len__(self):return len(self.images)def __getitem__(self, ix):image_path = self.images[ix]image = Image.open(image_path)image = transform(image)return imageds = Faces(folder='cropped_faces/*.jpg')
dataloader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=8)
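
A quick check that the pipeline is wired correctly (assumes cropped_faces/ contains the pre-cropped CelebA JPEG files):

batch = next(iter(dataloader))
print(len(ds), batch.shape)  # <number of images> torch.Size([64, 3, 64, 64])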

3.2 Training Procedure

Training follows the standard GAN procedure. The Trainer supports two adversarial losses, the hinge loss used in the SAGAN paper and WGAN-GP with a gradient penalty, selected through adv_loss (the configuration below uses 'wgan-gp'). Both networks are optimized with Adam, with different initial learning rates for the generator (1e-4) and the discriminator (4e-4):
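
For reference, here is a compact sketch of the two adversarial losses (standalone functions with illustrative names, not part of the Trainer class below):

import torch.nn.functional as F

def d_hinge_loss(d_real, d_fake):
    # hinge loss: push scores on real images above +1 and on fakes below -1
    return F.relu(1.0 - d_real).mean() + F.relu(1.0 + d_fake).mean()

def g_hinge_loss(d_fake):
    # the generator maximizes the discriminator's score on generated images
    return -d_fake.mean()

def d_wgan_loss(d_real, d_fake):
    # WGAN critic loss; the gradient penalty term is added separately inside the Trainer
    return d_fake.mean() - d_real.mean()

The Trainer class below wires these pieces together: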

import time
import datetime
from torch.autograd import Variable
from torchvision.utils import save_image

# small helpers used by the Trainer (adapted from the reference implementation's utils)
def tensor2var(x, grad=False):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=grad)

def denorm(x):
    # map images from [-1, 1] back to [0, 1] for saving/plotting
    out = (x + 1) / 2
    return out.clamp_(0, 1)

class Trainer(object):
    def __init__(self, data_loader):
        # Data loader
        self.data_loader = data_loader

        # adversarial loss: 'wgan-gp' or 'hinge'
        self.adv_loss = 'wgan-gp'

        # Model hyper-parameters
        self.imsize = 64
        self.g_num = 5
        self.z_dim = 128
        self.g_conv_dim = 64
        self.d_conv_dim = 64
        self.parallel = False

        self.lambda_gp = 10
        self.total_step = 50000
        self.d_iters = 5
        self.batch_size = 32
        self.num_workers = 2
        self.g_lr = 0.0001
        self.d_lr = 0.0004
        self.lr_decay = 0.95
        self.beta1 = 0.0
        self.beta2 = 0.9

        self.dataset = data_loader
        self.sample_path = 'sagan_samples'
        self.sample_step = 100
        self.log_step = 10

        os.makedirs(self.sample_path, exist_ok=True)
        self.build_model()

    def reset_grad(self):
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def train(self):
        # Data iterator
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)

        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))
        start = 0

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):
            # ================== Train D ================== #
            self.D.train()
            self.G.train()

            try:
                real_images = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                real_images = next(data_iter)

            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention maps
            real_images = tensor2var(real_images)
            d_out_real, dr1, dr2 = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = - torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # Compute loss with fake images
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, gf1, gf2 = self.G(z)
            d_out_fake, df1, df2 = self.D(fake_images)

            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty on interpolations between real and fake images
                alpha = torch.rand(real_images.size(0), 1, 1, 1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data + (1 - alpha) * fake_images.data, requires_grad=True)
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G ================== #
            # Create random noise
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, _, _ = self.G(z)

            # Compute loss with fake images
            g_out_fake, _, _ = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = - g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = - g_out_fake.mean()

            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                      "ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
                      format(elapsed, step + 1, self.total_step, (step + 1),
                             self.total_step, d_loss_real.item(),
                             self.G.attn1.gamma.mean().item(),
                             self.G.attn2.gamma.mean().item()))

            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _, _ = self.G(fixed_z)
                save_image(denorm(fake_images.data),
                           os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))

    def build_model(self):
        self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Optimizers with different learning rates for G and D
        self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()),
                                            self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()),
                                            self.d_lr, [self.beta1, self.beta2])
        self.c_loss = torch.nn.CrossEntropyLoss()

        # print networks
        print(self.G)
        print(self.D)
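
With the data loader from the previous section, training can then be launched in two lines (a CUDA device is assumed, since build_model moves both networks to the GPU):

trainer = Trainer(dataloader)   # builds G, D and the two Adam optimizers
trainer.train()                 # runs the 50,000-step loop, saving samples to sagan_samples/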

After training, the generator can be used to produce face images; a minimal sampling sketch follows, with example outputs shown below.
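
A sampling sketch (reusing the trained generator held by the trainer object and the tensor2var/denorm helpers defined above):

trainer.G.eval()
with torch.no_grad():
    z = tensor2var(torch.randn(16, trainer.z_dim))
    samples, _, _ = trainer.G(z)
grid = vutils.make_grid(denorm(samples.cpu()), nrow=4)
plt.imshow(grid.permute(1, 2, 0))
plt.axis('off')
plt.show()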

[Figure: generated face samples]
