当前位置: 首页 > news >正文

李宏毅机器学习笔记15

目录

摘要

Abstract

1.如何训练GAN

2.距离如何计算

3.GAN训练的小技巧


摘要

本周继续学习李宏毅老师2025春季机器学习课程,本周学习内容是如何训练GAN相关的知识,以及训练过程中相关参数的计算方法等。

Abstract

This week, I continued with Prof. Hung-yi Lee’s 2025 Spring Machine Learning Course. The study content covered how to train GANs (Generative Adversarial Networks), along with the calculation methods for relevant parameters during the training process.

1.如何训练GAN

先弄清楚我们训练的目标是什么,我们在训练network的时候,通常是定义一个loss function,之后用gradient design调整参数,minimize我们的loss function即可。那么在generation之中我们应该minimize是什么呢?

实际上,我们想要minimize的东西是这样的,我们有一个generator,给他一堆的normal distribution sample(正态分布样本)的东西,之后我们的generator会产生一个比较复杂的distribution,我们称它为Pg,我们有一些真正的数据,他们形成另外一个distribution,我们称他为Pdata,我们期待Pg和Pdata越接近越好。

用一个一维的例子说明,假设generator的输入是一维向量,generator的输出也是一维向量,真正的数据也是一维向量,那么我们的normal distribution如上图左侧橙色点所示,输入到generator后,每一个点的位置都会改变,就会产生一个新的distribution如上图绿色点所示,真正的数据分布可能更极端,如上图蓝色点所示。

我们期待Pg和Pdata接近,写成式子可以如上图所示,Div表示Pg和Pdata之间的某种距离,越大就代表Pg和Pdata越不像。所以我们找一个generator实际是找一组generator中的参数,让产生出来的Pg和Pdata距离最小。

2.距离如何计算

实际上,我们可能知道的一些divergence的式子,例如kl divergence,js divergence,这些divergence用在continuous distribution(连续分布)上,会做一个很复杂,不知道怎么算的积分。我们根本无法把divergence算出来,如何找一个G去minimize这个divergence,这就是GAN所遇到的问题,也是我们在训练中会遇到的问题。

但是在GAN中,你不需要知道Pg和Pdata他们实际上算距离的公式,只要你能sample东西出来就可以计算divergence。怎么sample呢,从图库中随机sample一些图片出来就可以了,那generator则是输入一些normal distribution sample出来的向量输入到generator中,让generator产生一些图片。

想要计算距离,我们要依靠discriminator。discriminator是如何训练出来的呢?我们有一大堆的真实数据(从Pdata中sample出来的结果),还有一大堆的生成数据(从Pg中sample出来的结果),根据这些数据我们会去训练一个discriminator,训练的目标是看到真实数据就给较高的分数,看到生成数据就给较低的分数,也可以写成式子,这个discriminator可以去maximize某一个function,这里称为objective function,我们要找一个D,它可以maximize这个objective function,objective function如上图中V()所示,我们有一堆y从Pdata中sample出来的是真正的图片,把真正的图片y丢入到D中得到一个分数再取log,另一方面我们我们有一堆y从Pg中sample出来的是生成的图片,同样把生成的图片y丢入到D中得到一个分数,用1减去这个分数,再取log。 

objective function的最大值跟divergence是相关的,所以我们即使不知道如何计算divergence,在训练完discriminator之后看看它的objective function的最大值是多少。如果两组data很不像,他们的divergence很大,那discriminator就可以很轻松的把他们分开,此时discriminator的objective function的最大值就可以冲的很大。

3.GAN训练的小技巧

最知名的是WGAN,在此之前,我们先了解js divergence有什么问题。

首先看一下Pg和Pdata有什么样的特性,他们往往重叠的部分非常少。为什么这么说呢?有两种理由。一种是Pg和Pdata都是高维空间里的低维部分,例如二维空间(平面)中的一条线,二维空间中,两条线的重叠部分是很少的。另一种是,Pg和Pdata本身是有较多重叠部分的,但是他们都是sample出来的点,如果sample的不够多不够密,对discriminator来说也是不重叠的。

对于js divergence来说,它有一个特性对于两个没有重叠的分布,js divergence算出来永远是log2。,上图中情况1与情况2对于js divergence都是一样坏,但实际上情况2是更好的。

因此有了wasserstein distance,计算方法是,想象你在开一台推土机,把P当作是土堆,把Q当作是堆放的目的地,把P的土移动到Q的平均距离是resistant distance,在这个例子中,假设P集中在一个点,Q集中在一个点,两点距离为d,那Q与P的resistant distance就是d。

但是如果是更复杂的distribution,要计算resistant distance就有点困难了。因为我们的做法有无穷多种,用不同的做法,我们算出来的距离就不一样。为了让resistant distance只有一个值,wasserstein distance的方法是穷举所有的做法,看哪一个推土的方法的平均移动距离最短,这个最短的平均距离就是resistant distance。

回到上面的例子,如果是用wasserstein distance我们就可以发现从左到右generator是越来越好的,但观察discriminator会发现观察不到任何东西,对discriminator而言每一个case算出来的js divergence都一样。

WGAN实际上就是用wasserstein distance取代js divergence的时候这个GAN就叫WGAN,实际上wasserstein distance怎么计算呢?如上图所示的式子,x~Pdata,x~Pg代表来自Pdata和Pg的图片,我们要计算的是D(x)的期望值并相减,我们要maximize这个函数就要做到,如果x是从Pdata中sample出来的D(x)的输出就越大越好,如果x是从Pg中sample出来的D(x)的输出就越小越好。

但是还有一个限制,D不能是随便一个function,必须是一个足够平滑的函数。举一个例子,如上图所示,只要两部分没有重叠,那么为了满足maximize就会让真正的数据部分无限大,生成的部分无穷大的负值。会导致实际上两组数据只要不重叠无论是否相近,他们的maximum都是无限大。

http://www.dtcms.com/a/419858.html

相关文章:

  • 数字化转型:开发者思维破局之道
  • 网站会员功能介绍营销背景包括哪些内容
  • 【NCS随笔】peripheral_hids_mouse例程修改为不使用PIN码绑定
  • 第三方软件验收测试:【AutoIt与Selenium结合测试文件上传/下载等Windows对话框】
  • 网站的二级目录是什么10个不愁销路的小型加工厂
  • K8S中关于容器对外提供服务网络类型
  • 建设网站需要虚拟空间嘛专业网站制作公司采用哪些技术制作网站?
  • 超声波水表:原理、实现与核心技术解析
  • 怎样 建设电子商务网站直播网站app开发
  • Nginx 核心功能配置:访问控制、用户认证、HTTPS 与 URL 重写等
  • 大模型显存占用完全指南:从训练到推理的计算公式与实战案例(建议收藏)
  • 惠州做网站采招网招标官网
  • 烟台做网站找哪家好哪个网站可以做海报
  • 【星海出品】计算机科学之磁盘数据读取时间逻辑
  • 模力通AI风格仿写 让公文写作告别“风格焦虑”
  • 构建AI智能体:四十七、Agent2Agent多智能体系统:基础通信与任务协作实现
  • 天猫建设网站的意义张家港网站建设做网站
  • python爬虫进阶版练习(只说重点,selenium)
  • 东莞网站设计教程为企业做好服务保障
  • 福州网站建设q.479185700強网页翻译网站
  • 134、【OS】【Nuttx】【周边】效果呈现方案解析:端口映射(三)
  • 网站开发 报价单网站源码asp
  • Java HHH000490: Using JtaPlatform implementation
  • 网站关键词检测郑州外贸网站推广
  • 苏州网站开发的企业wordpress 结合qq
  • 在Linux中安装应用
  • 【高级语言范型介绍】
  • android 权限申请封装类
  • 个人习惯的各类chat大模型的使用场景
  • 济南网站建设与优化网站建设验收报告范本