生成式人工智能实战 | 自注意力生成对抗网络(Self-Attention Generative Adversarial Network, SAGAN)
生成式人工智能实战 | 自注意力生成对抗网络
- 0. 前言
- 1. SAGAN 核心原理
- 1.1 自注意力机制
- 1.2 谱归一化
- 2. 实现 SAGAN
- 2.1 生成器
- 2.2 判别器
- 3. 模型训练
- 3.1 数据加载
- 3.2 训练流程
0. 前言
自注意力生成对抗网络 (Self-Attention Generative Adversarial Network
, SAGAN
) 通过在传统深度卷积 GAN 中嵌入自注意力机制,有效捕捉图像中远距离的依赖关系,从而生成更具全局一致性和细节丰富的图像。SAGAN
在生成器和判别器中均引入自注意力模块,并结合谱归一化 (Spectral Normalization
)、条件批归一化 (Conditional Batch Normalization
)、投影鉴别器 (Projection Discriminator
)及铰链损失 (Hinge Loss
),显著提升了训练的稳定性与样本质量。本节将全面介绍 SAGAN
的核心原理与并使用 PyTorch
实现 SAGAN
模型。
1. SAGAN 核心原理
1.1 自注意力机制
传统卷积神经网络主要依赖局部感受野,难以捕捉图像中跨区域的全局结构信息,而深层堆叠卷积层虽具备理论潜力,但优化难度大且统计鲁棒性不足。
SAGAN
通过在特征图上计算自注意力 (Self-Attention
),使得每个位置的输出既依赖其局部邻域信息,又能够利用所有位置的全局线索,从而改善生成结果的全局一致性。
在图像场景下,自注意力模块首先将输入特征图通过三个 1×1
卷积分别映射为查询 (Query
)、键 (Key
)、值 (Value
) 三组特征,然后按下述步骤计算注意力输出:
- 将查询和键张量经矩阵乘法计算注意力权重矩阵,并通过
softmax
归一化 - 将注意力权重与值张量相乘得到加权特征表示
- 乘以可学习缩放因子 γγγ 并与原始特征相加,实现残差连接,初始时 γ=0γ=0γ=0,网络可先依赖局部信息再逐步学习非局部依赖
下图显示了 SAGAN
中的注意力模块,其中 θθθ,φφφ 和 ggg 对应于键,查询和值:
接下来,实现自注意力机制,先将输入映射到查询/键/值空间,计算注意力矩阵,再将加权值与原始输入融合:
class Self_Attn(nn.Module):""" Self attention Layer"""def __init__(self,in_dim,activation):super(Self_Attn,self).__init__()self.chanel_in = in_dimself.activation = activationself.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)self.gamma = nn.Parameter(torch.zeros(1))self.softmax = nn.Softmax(dim=-1) #def forward(self,x):"""inputs :x : input feature maps( B X C X W X H)returns :out : self attention value + input feature attention: B X N X N (N is Width*Height)"""m_batchsize,C,width ,height = x.size()proj_query = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)proj_key = self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)energy = torch.bmm(proj_query,proj_key) # transpose checkattention = self.softmax(energy) # BX (N) X (N) proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X Nout = torch.bmm(proj_value,attention.permute(0,2,1) )out = out.view(m_batchsize,C,width,height)out = self.gamma*out + xreturn out,attention
1.2 谱归一化
为稳定对抗训练,SAGAN
对生成器和判别器的所有卷积权重均施加谱归一化,强制权重矩阵的最大奇异值为 1
,从而满足 Lipschitz
连续性约束,抑制梯度爆炸或消失。以下是执行频谱归一化的步骤:
- 卷积层中的权重是一个
4
维张量,因此第一步是将其重塑为2D
矩阵,在这里我们保留权重的最后一个维度。现在,权重的形状为(H×W, C)
- 用
N(0,1)
初始化向量 uuu - 在
for
循环中,计算以下内容:- 用矩阵转置和矩阵乘法计算 V=(W⊤)UV =(W^\top)UV=(W⊤)U
- 用其 L2 范数归一化 VVV,即 V=V∣∣V∣∣2V = \frac {V}{||V||_2}V=∣∣V∣∣2V
- 计算 U=WVU = WVU=WV
- 用 L2 范数归一化 UUU,即 U=U∣∣U∣∣2U =\frac {U}{||U||_2}U=∣∣U∣∣2U
- 计算谱范数为 U⊤WVU^\top WVU⊤WV
- 最后,将权重除以谱范数
def l2normalize(v, eps=1e-12):return v / (v.norm() + eps)class SpectralNorm(nn.Module):def __init__(self, module, name='weight', power_iterations=1):super(SpectralNorm, self).__init__()self.module = moduleself.name = nameself.power_iterations = power_iterationsif not self._made_params():self._make_params()def _update_u_v(self):u = getattr(self.module, self.name + "_u")v = getattr(self.module, self.name + "_v")w = getattr(self.module, self.name + "_bar")height = w.data.shape[0]for _ in range(self.power_iterations):v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))sigma = u.dot(w.view(height, -1).mv(v))setattr(self.module, self.name, w / sigma.expand_as(w))def _made_params(self):try:u = getattr(self.module, self.name + "_u")v = getattr(self.module, self.name + "_v")w = getattr(self.module, self.name + "_bar")return Trueexcept AttributeError:return Falsedef _make_params(self):w = getattr(self.module, self.name)height = w.data.shape[0]width = w.view(height, -1).data.shape[1]u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)u.data = l2normalize(u.data)v.data = l2normalize(v.data)w_bar = Parameter(w.data)del self.module._parameters[self.name]self.module.register_parameter(self.name + "_u", u)self.module.register_parameter(self.name + "_v", v)self.module.register_parameter(self.name + "_bar", w_bar)def forward(self, *args):self._update_u_v()return self.module.forward(*args)
2. 实现 SAGAN
2.1 生成器
生成器以噪声作为输入并经过多个上采样和卷积块,同时在中层插入自注意力,以生成具有全局一致性的细节:
class Generator(nn.Module):"""Generator."""def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):super(Generator, self).__init__()self.imsize = image_sizelayer1 = []layer2 = []layer3 = []last = []repeat_num = int(np.log2(self.imsize)) - 3mult = 2 ** repeat_num # 8layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))layer1.append(nn.BatchNorm2d(conv_dim * mult))layer1.append(nn.ReLU())curr_dim = conv_dim * multlayer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))layer2.append(nn.ReLU())curr_dim = int(curr_dim / 2)layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))layer3.append(nn.ReLU())if self.imsize == 64:layer4 = []curr_dim = int(curr_dim / 2)layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))layer4.append(nn.ReLU())self.l4 = nn.Sequential(*layer4)curr_dim = int(curr_dim / 2)self.l1 = nn.Sequential(*layer1)self.l2 = nn.Sequential(*layer2)self.l3 = nn.Sequential(*layer3)last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))last.append(nn.Tanh())self.last = nn.Sequential(*last)self.attn1 = Self_Attn( 128, 'relu')self.attn2 = Self_Attn( 64, 'relu')def forward(self, z):z = z.view(z.size(0), z.size(1), 1, 1)out=self.l1(z)out=self.l2(out)out=self.l3(out)out,p1 = self.attn1(out)out=self.l4(out)out,p2 = self.attn2(out)out=self.last(out)return out, p1, p2
2.2 判别器
判别器同样使用引入自注意力以捕捉全局依赖:
class Discriminator(nn.Module):"""Discriminator, Auxiliary Classifier."""def __init__(self, batch_size=64, image_size=64, conv_dim=64):super(Discriminator, self).__init__()self.imsize = image_sizelayer1 = []layer2 = []layer3 = []last = []layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))layer1.append(nn.LeakyReLU(0.1))curr_dim = conv_dimlayer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))layer2.append(nn.LeakyReLU(0.1))curr_dim = curr_dim * 2layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))layer3.append(nn.LeakyReLU(0.1))curr_dim = curr_dim * 2if self.imsize == 64:layer4 = []layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))layer4.append(nn.LeakyReLU(0.1))self.l4 = nn.Sequential(*layer4)curr_dim = curr_dim*2self.l1 = nn.Sequential(*layer1)self.l2 = nn.Sequential(*layer2)self.l3 = nn.Sequential(*layer3)last.append(nn.Conv2d(curr_dim, 1, 4))self.last = nn.Sequential(*last)self.attn1 = Self_Attn(256, 'relu')self.attn2 = Self_Attn(512, 'relu')def forward(self, x):out = self.l1(x)out = self.l2(out)out = self.l3(out)out,p1 = self.attn1(out)out=self.l4(out)out,p2 = self.attn2(out)out=self.last(out)return out.squeeze(), p1, p2
3. 模型训练
3.1 数据加载
本节中,我们将继续使用 Celeb A 人脸图像数据集构建 SAGAN
:
from torchvision import transforms
import torchvision.utils as vutils
import cv2, numpy as np
import torch
import os
from glob import glob
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"transform=transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])class Faces(Dataset):def __init__(self, folder):super().__init__()self.folder = folderself.images = sorted(glob(folder))def __len__(self):return len(self.images)def __getitem__(self, ix):image_path = self.images[ix]image = Image.open(image_path)image = transform(image)return imageds = Faces(folder='cropped_faces/*.jpg')
dataloader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=8)
3.2 训练流程
使用标准的 GAN
训练步骤。损失函数使用铰链损失,使用 Adam
优化器,生成器 (1e-4
) 和判别器 (4e-4
) 使用不同的初始学习率:
class Trainer(object):def __init__(self, data_loader):# Data loaderself.data_loader = data_loader# exact and lossself.adv_loss = 'wgan-gp'# Model hyper-parametersself.imsize = 64self.g_num = 5self.z_dim = 128self.g_conv_dim = 64self.d_conv_dim = 64self.parallel = Falseself.lambda_gp = 10self.total_step = 50000self.d_iters = 5self.batch_size = 32self.num_workers = 2self.g_lr = 0.0001self.d_lr = 0.0004self.lr_decay = 0.95self.beta1 = 0.0self.beta2 = 0.9self.dataset = data_loaderself.sample_path = 'sagan_samples'self.sample_step = 100self.log_step = 10self.build_model()def train(self):# Data iteratordata_iter = iter(self.data_loader)step_per_epoch = len(self.data_loader)# Fixed input for debuggingfixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))start = 0# Start timestart_time = time.time()for step in range(start, self.total_step):# ================== Train D ================== #self.D.train()self.G.train()try:real_images = next(data_iter)except:data_iter = iter(self.data_loader)real_images = next(data_iter)# Compute loss with real images# dr1, dr2, df1, df2, gf1, gf2 are attention scoresreal_images = tensor2var(real_images)d_out_real,dr1,dr2 = self.D(real_images)if self.adv_loss == 'wgan-gp':d_loss_real = - torch.mean(d_out_real)elif self.adv_loss == 'hinge':d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()# apply Gumbel Softmaxz = tensor2var(torch.randn(real_images.size(0), self.z_dim))fake_images,gf1,gf2 = self.G(z)d_out_fake,df1,df2 = self.D(fake_images)if self.adv_loss == 'wgan-gp':d_loss_fake = d_out_fake.mean()elif self.adv_loss == 'hinge':d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()# Backward + Optimized_loss = d_loss_real + d_loss_fakeself.reset_grad()d_loss.backward()self.d_optimizer.step()if self.adv_loss == 'wgan-gp':# Compute gradient penaltyalpha = torch.rand(real_images.size(0), 1, 1, 1).cuda().expand_as(real_images)interpolated = Variable(alpha * real_images.data + (1 - alpha) * fake_images.data, requires_grad=True)out,_,_ = self.D(interpolated)grad = torch.autograd.grad(outputs=out,inputs=interpolated,grad_outputs=torch.ones(out.size()).cuda(),retain_graph=True,create_graph=True,only_inputs=True)[0]grad = grad.view(grad.size(0), -1)grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)# Backward + Optimized_loss = self.lambda_gp * d_loss_gpself.reset_grad()d_loss.backward()self.d_optimizer.step()# ================== Train G and gumbel ================== ## Create random noisez = tensor2var(torch.randn(real_images.size(0), self.z_dim))fake_images,_,_ = self.G(z)# Compute loss with fake imagesg_out_fake,_,_ = self.D(fake_images) # batch x nif self.adv_loss == 'wgan-gp':g_loss_fake = - g_out_fake.mean()elif self.adv_loss == 'hinge':g_loss_fake = - g_out_fake.mean()self.reset_grad()g_loss_fake.backward()self.g_optimizer.step()# Print out log infoif (step + 1) % self.log_step == 0:elapsed = time.time() - start_timeelapsed = str(datetime.timedelta(seconds=elapsed))print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "" ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".format(elapsed, step + 1, self.total_step, (step + 1),self.total_step , d_loss_real.item(),self.G.attn1.gamma.mean().item(), self.G.attn2.gamma.mean().item() ))# Sample imagesif (step + 1) % self.sample_step == 0:fake_images,_,_= self.G(fixed_z)save_image(denorm(fake_images.data),os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))def build_model(self):self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()if self.parallel:self.G = nn.DataParallel(self.G)self.D = nn.DataParallel(self.D)# Loss and optimizerself.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr, [self.beta1, self.beta2])self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr, [self.beta1, self.beta2])self.c_loss = torch.nn.CrossEntropyLoss()# print networksprint(self.G)print(self.D)
模型训练完成后,使用生成器生成人脸图像: