当前位置: 首页 > news >正文

AIGC实战——CycleGAN详解与实现

AIGC实战——CycleGAN详解与实现

    • 0. 前言
    • 1. CycleGAN 基本原理
    • 2. CycleGAN 模型分析
    • 3. 实现 CycleGAN
    • 小结
    • 系列链接

0. 前言

CycleGAN 是一种用于图像转换的生成对抗网络(Generative Adversarial Network, GAN),可以在不需要配对数据的情况下将一种风格的图像转换成另一种风格,而无需为每一对输入-输出图像配对训练数据。CycleGAN 的核心思想是利用两个生成器和两个判别器,它们共同学习两个域之间的映射关系。例如,将马的图像转换成斑马的图像,或者将苹果图像转换为橙子图像。在本节中,我们将学习 CycleGAN 的基本原理,并实现该模型用于将夏天的风景图像转换成冬天的风景图像,或反之将冬天的风景图像转换为夏天的风景图像。

1. CycleGAN 基本原理

CycleGAN 是一种无需配对的图像转换技术,它可以将一个图像域中的图像转换为另一个图像域中的图像,而不需要匹配这两个域中的图像。它使用两个生成器和两个判别器,其中一个生成器将一个域中的图像转换为另一个域中的图像,而第二个生成器将其转换回来。这个过程被称为循环一致性,转换过程是可逆的。
CycleGAN 可以用于执行从一个类别到另一个类别的图像转换,而无需提供相匹配的输入-输出图像对来训练模型,只需要在两个不同的文件夹中提供这两个类别的图像。在本节中,我们将学习如何训练 CycleGAN 将夏天的风景图像转换成冬天的风景图像,或反之将冬天的风景图像转换为夏天的风景图像,CycleGAN 中的 Cycle 是指将图像从一个类别转换到另一个类别,然后再转换回原始类别的过程。
为了实现图像转换,使用两个 GAN,每个 GAN 的生成器执行从一个域到另一个域的图像转换。具体来说,假设输入是 X X X,那么第一个 GAN 的生成器执行映射 G : X → Y G:X\rightarrow Y G:XY,其输出为 Y = G ( X ) Y = G(X) Y=G(X);第二个 GAN 的生成器执行逆映射 F : Y → X F:Y\rightarrow X F:YX,结果为 X = F ( Y ) X = F(Y) X=F(Y)。每个判别器都训练用于区分真实图像和生成图像:

CycleGAN

为了训练 CycleGAN,除了传统的对抗损失外,还添加了循环一致性损失,用于确保给定图像 X X X 作为输入,那么经过两次转换 F ( G ( X ) ) ∼ X F(G(X)) \sim X F(G(X))X 后得到的图像与 X X X 相同,类似地,需要损失确保 G ( F ( Y ) ) ∼ Y ) G(F(Y)) \sim Y) G(F(Y))Y)
总体而言,在 CycleGAN 中,需要使用三种不同的损失值:

  • 鉴别器损失:用于区分真实图像和伪造图像
  • 循环一致性损失:由于 CycleGAN 使用了两个生成器,因此需要确保转换是可逆的,循环一致性损失通过将转换过的图像再次传递到原始的生成器中,并将生成的图像与原始图像进行比较来实现
  • 恒等损失 (Identity loss):确保生成器在不进行转换的情况下仍然能够生成与原始图像相似的图像,通过将原始图像传递到生成器中,并计算生成图像与原始图像之间的差异

2. CycleGAN 模型分析

CycleGAN 模型构建策略如下:

  1. 导入数据集并进行预处理
  2. 定义 UNet 架构用于构建生成器和判别器网络
  3. 定义两个生成器:
    • G_AB:将类别 A 图像转换为类别 B 图像的生成器
    • G_BA:将类别 B 图像转换为类别 A 图像的生成器
  4. 定义恒等损失:
    • 如果将一张橘子的图像输入到橙子生成器,理想情况下,如果生成器完全理解橙子的所有信息,它不应该改变图像,而应该“生成”完全相同的图像,据此,我们可以创建一个恒等变换
    • 当类别 A (real_A) 的图像通过 G_BA 并与 real_A 进行比较时,恒等损失应该是最小的
    • 当类别 B (real_B) 的图像通过 G_AB 并与 real_B 进行比较时,恒等损失应该是最小的
  5. 定义GAN损失:
    • real_Afake_A 的判别器和生成器损失(当 real_B 图像通过 G_BA 时得到 fake_A)
    • real_Bfake_B 的判别器和生成器损失(当 real_A 图像通过 G_AB 时得到 fake_B)
  6. 定义循环一致性损失:
    • 一张苹果图像需要通过橙子生成网络进行转换,生成伪造的橘子图像,然后再通过苹果生成网络将伪造的橙子图像转换回苹果图像
    • fake_Breal_A 通过 G_AB 时的输出,当 fake_B 通过 G_BA 时应该重新生成 real_A
    • fake_Areal_B 通过 G_BA 时的输出,当 fake_A 通过 G_AB 时应该重新生成 real_B
  7. 优化三个损失函数的加权和

3. 实现 CycleGAN

在本节中,我们使用 TensorFlow 实现 CycleGAN 模型。

(1) 导入所需模块,并使用 tensorflow_datasets 加载数据集,并使用tensorflow_examples 库中预定义的 pix2pix 模型的生成器和鉴别器:

import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

tensorflow_examples 包含一组适用于 CycleGAN 的数据集,如马-斑马、苹果-橘子等。在本节中,我们将使用 summer2winter_yosemite 数据集,包含了夏季图像和冬季图像,训练 CycleGAN 将输入的夏季图像转换为冬季图像,或反之将冬季图像转换为夏季图像。

(2) 加载数据,并获取训练和测试图像:

import os
import time
import matplotlib.pyplot as plt
import tensorflow as tf
from glob import glob

AUTOTUNE = tf.data.AUTOTUNE

train_summer = tf.data.Dataset.list_files('summer2winter_yosemite/trainA/*.jpg')
train_winter = tf.data.Dataset.list_files('summer2winter_yosemite/trainB/*.jpg')
test_summer = tf.data.Dataset.list_files('summer2winter_yosemite/testA/*.jpg')
test_winter = tf.data.Dataset.list_files('summer2winter_yosemite/testB/*.jpg')

(3) 设置超参数:

BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256

(4) 训练网络之前,对图像进行预处理,为了获得更好的性能,对训练图像添加随机抖动。执行归一化后,将图像调整为 286 x 286,然后随机裁剪为 256 x 256,最后应用随机抖动:

def random_crop(image):
    cropped_image = tf.image.random_crop(
        image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

    return cropped_image

# normalizing the images to [-1, 1]
def normalize(image):
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    return image

def random_jitter(image):
    # resizing to 286 x 286 x 3
    image = tf.image.resize(image, [286, 286],
                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # randomly cropping to 256 x 256 x 3
    image = random_crop(image)

    # random mirroring
    image = tf.image.random_flip_left_right(image)

    return image

def load(image):
    image = tf.io.read_file(image)
    image = tf.image.decode_jpeg(image)
    input_image = tf.cast(image, tf.float32)
    return input_image

(5) 数据增强(随机裁剪和抖动)仅对训练图像进行,因此需要分别定义训练数据和测试数据的图像预处理函数:

def preprocess_image_train(image):
    image = load(image)
    image = random_jitter(image)
    image = normalize(image)
    return image

def preprocess_image_test(image):
    image = load(image)
    image = normalize(image)
    return image

(6) 将以上函数应用于图像时,会将其归一化到范围 [-1,1] 之间,并对训练图像进行数据增强。在训练和测试数据集上应用以上函数,并创建一个数据加载器,用于批量提供训练图像:

train_summer = train_summer.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

train_winter = train_winter.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

test_summer = test_summer.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

test_winter = test_winter.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

在以上代码中,参数 num_parallel_calls 用于指定需要利用系统中的 CPU 核心数量,可以将其值设置为系统中的 CPU 全部核心数。可以使用 AUTOTUNE = tf.data.AUTOTUNE 值,以便 TensorFlow 动态确定合适的 CPU 核心数量。

(7) 使用在 tensorflow_examples 模块中定义的 pix2pix 模型的生成器和鉴别器,定义两个生成器和两个鉴别器:

sample_summer = next(iter(train_summer))
sample_winter = next(iter(train_winter))

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

(8) 查看示例图像,每张图像在绘制之前都会执行归一化处理:

to_winter = generator_g(sample_summer)
to_summer = generator_f(sample_winter)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_summer, to_winter, sample_winter, to_summer]
title = ['Summer', 'To Winter', 'Winter', 'To Summer']


for i in range(len(imgs)):
    plt.subplot(2, 2, i+1)
    plt.title(title[i])
    if i % 2 == 0:
        plt.imshow(imgs[i][0] * 0.5 + 0.5)
    else:
        plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()

示例图像

(9) 定义损失函数和优化器,使用与 DCGAN 相同的生成器和鉴别器的损失函数:

LAMBDA = 10

loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
    real_loss = loss_obj(tf.ones_like(real), real)

    generated_loss = loss_obj(tf.zeros_like(generated), generated)

    total_disc_loss = real_loss + generated_loss

    return total_disc_loss * 0.5

def generator_loss(generated):
    return loss_obj(tf.ones_like(generated), generated)

(10) 由于 CycleGAN 包含四个模型,两个生成器和两个鉴别器,因此需要定义四个优化器:

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

(11) 此外,在 CycleGAN 中,还需要定义两个额外的损失函数。首先是循环一致性损失,用于确保生成结果接近原始输入:

def calc_cycle_loss(real_image, cycled_image):
    loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
    
    return LAMBDA * loss1

另外还需要定义一个恒等损失,用于确保如果将图像 Y Y Y 输入生成器 G : X → Y G:X\rightarrow Y G:XY,它会输出类似于 Y Y Y 的图像。因此,如果给夏季图像生成器一个夏季的图像作为输入,它不应该对其进行过多修改:

def identity_loss(real_image, same_image):
    loss = tf.reduce_mean(tf.abs(real_image - same_image))
    return LAMBDA * 0.5 * loss

(12) 定义函数训练生成器和鉴别器。两个鉴别器和两个生成器将通过 tape 梯度进行训练。训练步骤可以分为 4 步:

  • 从两个生成器中获取输出图像
  • 计算损失
  • 计算梯度
  • 最后,应用梯度
@tf.function
def train_step(real_x, real_y):
    # persistent is set to True because the tape is used more than
    # once to calculate the gradients.
    with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y
        # Generator F translates Y -> X.
        
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # same_x and same_y are used for identity loss.
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # calculate the loss
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)
        
        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
        
        # Total generator loss = adversarial loss + cycle loss
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
    
    # Calculate the gradients for generator and discriminator
    generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                            generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                            generator_f.trainable_variables)
    
    discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                                discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                                discriminator_y.trainable_variables)
    
    # Apply the gradients to the optimizer
    generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                                generator_g.trainable_variables))

    generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                                generator_f.trainable_variables))
    
    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                    discriminator_x.trainable_variables))
    
    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                    discriminator_y.trainable_variables))

(13) 训练网络 200 个epoch:

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

定义检查点保存模型权重。由于训练一个优秀的 CycleGAN 可能需要大量时间,保存检查点能够用于确保模型从上次中断的地方继续学习,只需在下次开始时加载现有的检查点。

(14) 查看 CycleGAN 生成的图像。生成器 A 以夏季照片作为输入,将它们转换为冬季照片,而生成器 B 以冬季照片作为输入,将它们转换为夏季照片:

EPOCHS = 100

def generate_images(model, test_input):
    prediction = model(test_input)
        
    plt.figure(figsize=(12, 12))

    display_list = [test_input[0], prediction[0]]
    title = ['Input Image', 'Predicted Image']

    for i in range(2):
        plt.subplot(1, 2, i+1)
        plt.title(title[i])
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
    plt.savefig('image.png')
plt.show()

for epoch in range(EPOCHS):
    start = time.time()

    n = 0
    for image_x, image_y in tf.data.Dataset.zip((train_summer, train_winter)):
        train_step(image_x, image_y)
        if n % 10 == 0:
            print ('.', end='')
        n += 1

    # Using a consistent image (sample_horse) so that the progress of the model
    # is clearly visible.
    generate_images(generator_g, sample_summer)

    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                            ckpt_save_path))

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))
    
to_winter = generator_g(sample_summer)
to_summer = generator_f(sample_winter)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_summer, to_winter, sample_winter, to_summer]
title = ['Summer', 'To Winter', 'Winter', 'To Summer']


for i in range(len(imgs)):
    plt.subplot(2, 2, i+1)
    plt.title(title[i])
    if i % 2 == 0:
        plt.imshow(imgs[i][0] * 0.5 + 0.5)
    else:
        plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()

for inp in test_summer.take(5):
generate_images(generator_g, inp)

转换结果

可以尝试使用 TensorFlow CycleGAN 数据集中其他的数据集,如 apple2orange 数据集。

小结

CycleGAN 是一种用于无监督图像转换的深度学习模型,它通过两个生成器和两个判别器的组合来学习两个不同域之间的映射关系。生成器负责将一个域的图像转换成另一个域的图像,而判别器则用于区分生成的图像和真实的图像。CycleGAN 引入循环一致性损失,确保图像转换是可逆的,从而提高生成图像的质量。通过对抗训练和循环一致性损失,CycleGAN 可以实现在没有配对标签的情况下进行图像域转换。

系列链接

AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)
AIGC实战——WGAN(Wasserstein GAN)
AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)
AIGC实战——自回归模型(Autoregressive Model)
AIGC实战——改进循环神经网络
AIGC实战——像素卷积神经网络(PixelCNN)
AIGC实战——归一化流模型(Normalizing Flow Model)
AIGC实战——能量模型(Energy-Based Model)
AIGC实战——扩散模型(Diffusion Model)
AIGC实战——GPT(Generative Pre-trained Transformer)
AIGC实战——Transformer模型
AIGC实战——ProGAN(Progressive Growing Generative Adversarial Network)
AIGC实战——StyleGAN(Style-Based Generative Adversarial Network)
AIGC实战——VQ-GAN(Vector Quantized Generative Adversarial Network)
AIGC实战——基于Transformer实现音乐生成
AIGC实战——MuseGAN详解与实现
AIGC实战——多模态模型DALL.E 2
AIGC实战——多模态模型Flamingo
AIGC实战——世界模型(World Model)
AIGC实战——生成式人工智能总结与展望

http://www.dtcms.com/a/112144.html

相关文章:

  • NVIDIA AgentIQ 详细介绍
  • 从Keep-Alive到页面关闭:解决Vue和React生命周期函数不触发的实战技巧
  • 相干光信号处理的一些基础知识
  • Spring依赖注入最佳实践:应对接口多实现的挑战
  • Centos7.9怎样安装Mysql 5.7
  • MySQL数据库如何在线修改表结构及字段类型?
  • FreeRTOS/任务创建和删除的API函数
  • HTML表单属性1
  • 线程同步与互斥(上)
  • 计算机通识
  • NB-IoT单灯控制器:智慧照明的“神经末梢”
  • 蓝桥杯嵌入式第15届真题-个人理解+解析
  • 【系统】换硬盘不换系统,使用WIN PE Ghost镜像给电脑无损扩容换硬盘
  • Python3.13安装教程-2025最新版超级详细图文安装教程(附所需安装包环境)
  • PhotoShop学习04
  • 详解大模型四类漏洞
  • Vue2+Vue3 45-90集学习笔记
  • P12013 [Ynoi April Fool‘s Round 2025] 牢夸 Solution
  • CMAKE中使用外部动态库
  • C++中,应尽可能将引用形参声明为const
  • Smart Link 技术全面解析
  • 使用人工智能大模型腾讯元宝和ttsmp3工具,免费使用文字进行配音
  • Python入门(6):Python序列结构-元组
  • FastAPI-Cache2: 高效Python缓存库
  • Linux系统调用编程
  • 嵌入式开发中栈溢出的处理方法
  • MySQL学习笔记(一)——MySQL下载安装配置
  • 一文全面了解GEO中的知识图谱
  • leetcode数组-长度最小的子数组
  • 【Git】“warning: LF will be replaced by CRLF”的解决办法