当前位置: 首页 > news >正文

TensorFlow2 Python深度学习 - 生成对抗网络(GAN)实例

锋哥原创的TensorFlow2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1X5xVz6E4w/

课程介绍

本课程主要讲解基于TensorFlow2的Python深度学习知识,包括深度学习概述,TensorFlow2框架入门知识,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

TensorFlow2 Python深度学习 - 生成对抗网络(GAN)实例

我们以生成手写数字数据集为示例:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import keras
from keras import layers, Input
import matplotlib.pyplot as plt
import time
​
# 使用手写字体或单品样本做训练  这里注意的是 我们只需要训练数据,不需要答案和测试数据集。
(train_images, _), (_, _) = keras.datasets.mnist.load_data()
​
# 因为卷积层的需求,增加色深维度
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
# 规范化为-1 - +1
train_images = (train_images - 127.5) / 127.5
​
BUFFER_SIZE = 60000  # 以供60000个样本
BATCH_SIZE = 256  # 256张为一组
# 创建数据集
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
​
​
# 生成器网络
def make_generator_model():  # 根据长度为100的随机数组,生成一张28,28,1的矩阵model = tf.keras.Sequential()model.add(Input(shape=(100,)))# 全联接层,输入纬度为[[100],[n]],  输出为7*7*256 = 12544的节点  use_bias=False不使用偏差model.add(layers.Dense(7 * 7 * 256, use_bias=False))# BatchNormalization层:该层在每个batch上将前一层的激活值重新规范化,即使得其输出数据的均值接近0,其标准差接近1# 该层作用:(1)加速收敛(2)控制过拟合,可以少用或不用Dropout和正则(3)降低网络对初始化权重不敏感(4)允许使用较大的学习率model.add(layers.BatchNormalization())# ReLU是将所有的负值都设为零,相反,Leaky ReLU是给所有负值赋予一个非零斜率(负数)model.add(layers.LeakyReLU())# 将平铺的节点转为7*7*256的shapemodel.add(layers.Reshape((7, 7, 256)))# 通俗的讲这个解卷积,也就做反卷积,也叫做转置卷积(最贴切),我们就叫做反卷积吧,它的目的就是卷积的反向操作# 个人理解,正常的卷积是提取卷积核特征,反卷积就是用卷积核反向修改图像,风格迁移应该也是这么回事,那么问题来了在这个gan中,卷积特征从哪来?model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())# 64, (5, 5), strides=(2, 2), 希望得到64个特征核,步长2,2# model.output_shape == (None, 14, 14, 64) 输出的节点数64就是上面的特征核,由于padding='same',所以卷积后无变化,# 14,14 是因为步长 2,2  所以7*2model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))return model
​
​
# 判别器网络
def make_discriminator_model():model = tf.keras.Sequential()model.add(Input(shape=(28, 28, 1)))# 将 28.28.1的图像卷积 输出64个节点model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))# 接着卷积出128个节点model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))# 激活函数 为非0的斜率model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))# 平铺 并输出一个数字model.add(layers.Flatten())model.add(layers.Dense(1))return model
​
​
generator = make_generator_model()
discriminator = make_discriminator_model()
​
# 交叉熵损失函数
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
​
​
# 辨别模型损失函数
def discriminator_loss(real_output, fake_output):# 样本图希望结果趋近1real_loss = cross_entropy(tf.ones_like(real_output), real_output)# 自己生成的图希望结果趋近0fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)# 总损失total_loss = real_loss + fake_lossreturn total_loss
​
​
# 生成模型的损失函数
def generator_loss(fake_output):# 生成模型期望最终的结果越来越接近1,也就是真实样本return cross_entropy(tf.ones_like(fake_output), fake_output)
​
​
# 优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
​
EPOCHS = 100  # 训练轮数
noise_dim = 100  # 噪声向量的维度
num_examples_to_generate = 16  # 生成图片数量
​
# 初始化16个种子向量,用于生成4x4的图片  seed shape: 16, 100
seed = tf.random.normal([num_examples_to_generate, noise_dim])
​
​
def train_step(images):  # 更新 模型权重数据的核心方法# 随机生成一个批次的种子向量 BATCH_SIZE = 256   noise_dim = 100  ,256个长度为100的噪音响亮noise = tf.random.normal([BATCH_SIZE, noise_dim])  # noise shape:[256],[100]
​# 查看每一次epoch参数更新  这个GradientTape 是每次梯度更新都会调用的,这个取代了model.fit的训练计算with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:# 生成一个批次的图片generated_images = generator(noise, training=True)
​# 辨别一个批次的真实样本real_output = discriminator(images, training=True)# 辨别一个批次的生成图片fake_output = discriminator(generated_images, training=True)
​# 计算两个损失值gen_loss = generator_loss(fake_output)disc_loss = discriminator_loss(real_output, fake_output)
​# 根据损失值调整模型的权重参量gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
​# 计算出的参量应用到模型   梯度修剪,用于改变值, 梯度修剪主要避免训练梯度爆炸和消失问题# zIP是个格式转换函数 例如:a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]; zip(*a) = [(1, 4, 7), (2, 5, 8), (3, 6, 9)]generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
​
​
# 训练
def train(dataset, epochs):for epoch in range(epochs + 1):start = time.time()
​# 训练for image_batch in dataset:train_step(image_batch)
​# 保存图片# 每个训练批次生成一张图片作为阶段成功print("=======================================")generate_and_save_images(generator, epoch + 1, seed)
​print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
​
​
# 生成图片
def generate_and_save_images(model, epoch, test_input):# 设置为非训练状态,生成一组图片predictions = model(test_input, training=False)
​# 4格x4格拼接for i in range(predictions.shape[0]):plt.subplot(4, 4, i + 1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')
​# 保存为pngplt.savefig('image_at_epoch_{:04d}.png'.format(epoch))plt.close()
​
​
# 以训练模式运行,进入训练状态
train(train_dataset, EPOCHS)

运行会生成100个训练图片,每个图片有16个数字小图。

越后面的图片,数字辨识度越高。

第1张,基本无法识别。

第16张,稍微有点辨识度:

第70张,基本有辨识度了:

http://www.dtcms.com/a/507276.html

相关文章:

  • 利用jmeter完成简单的压力测试
  • 做网站用什么编程软件黄页88网能不能发免费的广告
  • 电子商务网站开发合同网页设计基础教程第二版课后答案
  • 基于Vite创建一个Vue2
  • 小皮面板的MySQL点击启动后马上又停止了
  • 【Python入门】第5篇:数据结构初探(列表、元组、字典、集合)​
  • Redis的List数据结构底层实现
  • 基于半桥结构的双极性脉冲电源的研究
  • openEuler安装mysql
  • ADC 模拟量转数字量
  • 网络广告是什么网站优化外包费用
  • 【IEEE/EI/Scopus检索】2026年第六届信息技术与云计算国际会议(ITCC 2026)
  • 赋能天然产物科学研究:多模态大模型与知识图谱的革新之旅
  • 用C语言实现原型模式时,如何确定需要深拷贝还是浅拷贝?
  • Spring Boot 3零基础教程,WEB 开发 Thymeleaf 属性优先级 行内写法 变量选择 笔记42
  • Go语言:对其语法的一些见解
  • Go Web 编程快速入门 · 04 - 请求对象 Request:头、体与查询参数
  • 伦教九江网站建设辽宁工程建筑信息网
  • Deep End-to-End Alignment and Refinement for Time-of-Flight RGB-D Module,2019
  • Ubuntu 安装 Gitea
  • 通达信灵活屏
  • 亚马逊云代理商:AWS怎么通过加密实现数据保护目标?
  • C标准库--C99--控制浮点环境<fenv.h>
  • 【Linux】“ 权限 “ 与相关指令
  • webrtc弱网-ReceiveSideCongestionController类源码分析及算法原理
  • 通达信--主题投资分析
  • 揭阳专业做网站天台县建设规划局网站
  • 福海网站制作关键词堆砌的作弊网站
  • sql特训
  • LeetCode 刷题【126. 单词接龙 II】