当前位置: 首页 > news >正文

【课堂笔记】生成对抗网络 Generative Adversarial Network(GAN)

文章目录

  • 问题背景
  • 原理
  • 更新过程
    • 判别器
    • 生成器

问题背景

  一方面,许多机器学习任务需要大量标注数据,但真实数据可能稀缺或昂贵(如医学影像、稀有事件数据)。如何在少量数据中达到一个很好的训练效果是一个很重要的问题。
  另一方面,传统生成模型(如变分自编码器VAE)生成的样本往往模糊或缺乏多样性,难以捕捉真实数据的复杂分布(如高分辨率图像、复杂文本等)。
  生成式对抗网络(GAN)提出了用生成器(Generator)和判别器(Discriminator),通过对抗训练相互竞争来提高性能。这样能够生成与真实数据分布相似的合成数据,用于数据增强;同时通过生成器和判别器的对抗训练,生成器学习到真实数据的概率分布,生成的样本更加逼真、细节丰富。

原理

  GAN由两个神经网络组成:
(1)生成器 G \mathbf{G} G:输入随机噪声 z ∼ p G ( z ) z \sim p_G(z) zpG(z)(通常是正态或均匀分布),输出生成的假数据 G ( z ) \mathbf{G}(z) G(z),试图模仿真实数据分布 p data p_{\text{data}} pdata
(2)判别器 D \mathbf{D} D:输入数据(真实数据 x ∼ p data x \sim p_{\text{data}} xpdata或假数据 p data p_{\text{data}} pdata),输出概率 D ( x ) ∈ [ 0 , 1 ] \mathbf{D}(x) \in [0, 1] D(x)[0,1],表示数据为真实的概率。
  这两个神经网络是对抗性的,生成器 G \mathbf{G} G企图让假数据更逼真,来让 D \mathbf{D} D犯错;而判别器 D \mathbf{D} D试图最大化区分真假数据的准确性。

  基于这个目的,我们构造一个损失函数:
(1)对于真实数据 x ∼ p data x \sim p_{\text{data}} xpdata,我们希望 D ( x ) → 1 \mathbf{D}(x) \rightarrow 1 D(x)1,定义损失为 − log ⁡ D ( x ) -\log\mathbf{D}(x) logD(x)
(2)对于生成数据 G ( z ) ∼ p G \mathbf{G}(z) \sim p_G G(z)pG,我们希望 D ( G ( z ) ) → 0 \mathbf{D}(\mathbf{G}(z))\rightarrow 0 D(G(z))0,定义损失为 − log ⁡ ( 1 − D ( G ( z ) ) ) -\log(1-\mathbf{D}(\mathbf{G}(z))) log(1D(G(z)))
  判别器的目标是最大化正确分类的概率,即最小化以下损失:
L D = − E x ∼ p data [ log ⁡ D ( x ) ] − E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_D = - \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] - \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] LD=Expdata[logD(x)]Ezpz[log(1D(G(z)))]
  生成器的目标是欺骗判别器,即最小化以下损失:
L G = E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_G = \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] LG=Ezpz[log(1D(G(z)))]
  结合两者,我们可以写出GAN的整体目标函数:
min ⁡ G max ⁡ D ( E x ∼ p data [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] ) \min_G \max_D \left(\mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right]\right) GminDmax(Expdata[logD(x)]+Ezpz[log(1D(G(z)))])
  接下来去解决这个目标,为了叙述方便定义记号 V ( N , G ) V(N, G) V(N,G),并改写为积分形式:
V ( D , G ) : = E x ∼ p data [ log ⁡ D ( x ) ] + E z ∼ p z [ log ⁡ ( 1 − D ( G ( z ) ) ) ] = ∫ x p data ( x ) log ⁡ D ( x ) d x + ∫ x p g ( x ) log ⁡ ( 1 − D ( x ) ) d x = ∫ x f ( D ( x ) ) d x f ( D ( x ) ) : = p data ( x ) log ⁡ D ( x ) + p g ( x ) log ⁡ ( 1 − D ( x ) ) \begin{align*} V(D, G) &:= \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D(x) \right] + \mathbb{E}_{z \sim p_z} \left[ \log (1 - D(G(z))) \right] \\ &=\int_x p_{\text{data}}(x) \log D(x) \, dx + \int_x p_g(x) \log (1 - D(x)) \, dx \\ &=\int_x f(D(x))dx \\ f(D(x)) &:= p_{\text{data}}(x) \log D(x) + p_g(x) \log (1 - D(x)) \end{align*} V(D,G)f(D(x)):=Expdata[logD(x)]+Ezpz[log(1D(G(z)))]=xpdata(x)logD(x)dx+xpg(x)log(1D(x))dx=xf(D(x))dx:=pdata(x)logD(x)+pg(x)log(1D(x))
  首先我们要找最大化 V ( D , G ) V(D, G) V(D,G) D ∗ D^* D,于是对 D D D求导:
∂ f ∂ D ( x ) = p data ( x ) D ( x ) − p g ( x ) 1 − D ( x ) = 0 ⇒ D ∗ ( x ) = p data ( x ) p data ( x ) + p g ( x ) \frac{\partial f}{\partial D(x)} = \frac{p_{\text{data}}(x)}{D(x)} - \frac{p_g(x)}{1 - D(x)} = 0 \\ \Rightarrow D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)} D(x)f=D(x)pdata(x)1D(x)pg(x)=0D(x)=pdata(x)+pg(x)pdata(x)
  这个结果表面,最有判别器 D ∗ D^* D输出真实数据和生成数据分布的相对概率。
  接下来将 D ∗ D^* D代入:
V ( D ∗ , G ) = ∫ x [ p data ( x ) log ⁡ ( p data ( x ) p data ( x ) + p g ( x ) ) + p g ( x ) log ⁡ ( p g ( x ) p data ( x ) + p g ( x ) ) ] d x V(D^*, G) = \int_x \left[ p_{\text{data}}(x) \log \left( \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)} \right) + p_g(x) \log \left( \frac{p_g(x)}{p_{\text{data}}(x) + p_g(x)} \right) \right] dx V(D,G)=x[pdata(x)log(pdata(x)+pg(x)pdata(x))+pg(x)log(pdata(x)+pg(x)pg(x))]dx
  这个式子比较复杂,经过推导可以证明:
V ( D ∗ , G ) = − log ⁡ 4 + 2 ⋅ JS ( p data ∥ p g ) V(D^*, G) = - \log 4 + 2 \cdot \text{JS}(p_{\text{data}} \| p_g) V(D,G)=log4+2JS(pdatapg)
  其中 J S \mathbf{JS} JS是Jensen-Shannon 散度,它与 K L \mathbf{KL} KL散度的关系为:
JS ( p data ∥ p g ) = 1 2 KL ( p data ∥ p data + p g 2 ) + 1 2 KL ( p g ∥ p data + p g 2 ) \text{JS}(p_{\text{data}} \| p_g) = \frac{1}{2} \text{KL} \left( p_{\text{data}} \| \frac{p_{\text{data}} + p_g}{2} \right) + \frac{1}{2} \text{KL} \left( p_g \| \frac{p_{\text{data}} + p_g}{2} \right) JS(pdatapg)=21KL(pdata2pdata+pg)+21KL(pg2pdata+pg)
  这个结果是合理的。当 p g = p d a t a p_g = p_{data} pg=pdata时, J S \mathbf{JS} JS散度为0,此时目标函数达到最小值 − log ⁡ 4 -\log 4 log4 D ∗ ( x ) = 0.5 \mathbf{D}^*(x) = 0.5 D(x)=0.5,将无法区分数据的真假。
  对于生成器 G \mathbf{G} G的优化等价于最小化这个 J S \mathbf{JS} JS散度。

更新过程

  在上述推导中,对随机分布进行了期望积分,但实际操作过程中直接计算上述积分是不可行的,我们会采用蒙特卡洛方法近似期望值,于是下面的 L D L_D LD L G L_G LG是用约等于。
  蒙特卡洛方法:核心是利用随机性和大数定律,通过从分布 p ( x ) p(x) p(x)中采集大量样本点 x 1 , . . . , x n x_1, ..., x_n x1,...,xn,然后计算样本均值来近似期望值:
E [ f ( X ) ] ≈ 1 n ∑ i = 1 n f ( x i ) \mathbb{E}[f(X)] \approx \frac{1}{n} \sum_{i=1}^n f(x_i) E[f(X)]n1i=1nf(xi)

判别器

  在理论分析中,我们得到了最优判别器 D ∗ ( x ) = p data ( x ) p data ( x ) + p g ( x ) \mathbf{D}^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)} D(x)=pdata(x)+pg(x)pdata(x),然而我们不知道数据实际分布 p data p_{\text{data}} pdata,通常采用梯度下降等方式来拟合:
(1)从真实数据中采集一批 x 1 , . . . , x m x_1, ..., x_m x1,...,xm,从生成器中生成一批 G ( z 1 ) , . . . , G ( z m ) G(z_1), ..., G(z_m) G(z1),...,G(zm)
(2)使用梯度下降优化损失 L D L_D LD θ D \theta_D θD是神经网络 D \mathbf{D} D的参数:
L D ≈ − 1 m ∑ i = 1 m [ log ⁡ D ( x i ) + log ⁡ ( 1 − D ( G ( z i ) ) ) ] θ D ← θ D + η ⋅ ∇ θ D L D L_D \approx -\frac{1}{m} \sum_{i=1}^m \left[ \log D(x_i) + \log (1 - D(G(z_i))) \right] \\ \theta_D \gets \theta_D + \eta \cdot \nabla_{\theta_D} L_D LDm1i=1m[logD(xi)+log(1D(G(zi)))]θDθD+ηθDLD

生成器

  生成器的训练和判别器交替进行,同样采用梯度下降等方法来拟合:
(1)从生成器中生成一批 G ( z 1 ) , . . . , G ( z m ) G(z_1), ..., G(z_m) G(z1),...,G(zm)
(2)使用当前判别器 D \mathbf{D} D(已部分训练)计算生成器损失的近似:
L G ≈ − 1 m ∑ i = 1 m log ⁡ D ( G ( z i ) ) L_G \approx -\frac{1}{m} \sum_{i=1}^m \log D(G(z_i)) LGm1i=1mlogD(G(zi))
(3)计算梯度并更新参数:
∇ θ G L G ≈ − 1 m ∑ i = 1 m ∇ θ G log ⁡ D ( G ( z i ) ) θ G ← θ G − η ⋅ ∇ θ G L G \nabla_{\theta_G} L_G \approx -\frac{1}{m} \sum_{i=1}^m \nabla_{\theta_G} \log D(G(z_i)) \\ \theta_G \gets \theta_G - \eta \cdot \nabla_{\theta_G} L_G θGLGm1i=1mθGlogD(G(zi))θGθGηθGLG

相关文章:

  • 图像处理篇---face_recognition库实现人脸检测
  • Vue3+SpringBoot全栈开发:从零实现增删改查与分页功能
  • 字节golang后端二面
  • 用dayjs解析时间戳,我被提了bug
  • 在IIS上无法使用PUT等请求
  • 基于机器学习的心脏病预测模型构建与可解释性分析
  • 西瓜书第十章——聚类
  • buuctf-web
  • unix/linux source 命令,其历史争议、兼容性、生态、未来展望
  • 在Flutter中定义全局对象(如$http)而不需要import
  • JVM学习(七)--JVM性能监控
  • Tomcat优化篇
  • ASP.NET Core SignalR 身份认证集成指南(Identity + JWT)
  • Axure组件即拖即用:垂直折叠菜单(动态展开/收回交互)
  • APM32主控键盘全功能开发实战教程:软件部分
  • 【Java基础】Java入门教程
  • DeepSeek 赋能智慧消防:以 AI 之力筑牢城市安全 “防火墙”
  • 归一化相关
  • 大模型备案中语料安全详细说明
  • Ubuntu终端性能监视工具
  • 想在网上做开发网站接活儿/好的营销网站设计公司
  • 网站开发及维护是什么/数据分析网站
  • 南宁市网站维护与推广公司/网站优化要多少钱
  • 开发软件需要学什么专业/石家庄seo结算
  • 青岛做企业网站的公司/定制网站+域名+企业邮箱
  • 中咨城建设计有限公司 网站/域名收录查询