博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
NVIDIA新作解读:用GAN生成前所未有的高清图像(附PyTorch复现) | PaperDaily #15
阅读量:6215 次
发布时间:2019-06-21

本文共 5595 字,大约阅读时间需要 18 分钟。

今天要介绍的文章是 NVIDIA 投稿 ICLR 2018 的一篇文章,Progressive Growing of GANs for Improved Quality, Stability, and Variation[1],姑且称它为 PG-GAN。

从行文可以看出文章是临时赶出来的,毕竟这么大的实验,用 P100 都要跑 20 天,更不用说调参时间了,不过人家在 NVIDIA,不缺卡。作者放出了基于 Lasagna 的代码,今天我也会简单解读一下代码。另外,我也在用 PyTorch 做复现。

在 PG-GAN 出来以前,训练高分辨率图像生成的 GAN 方法主要就是 LAPGAN[2] 和 BEGAN[6]。后者主要是针对人脸的,生成的人脸逼真而不会是鬼脸。

这里也提一下,生成鬼脸的原因是 Discriminator 不再更新,它不能再给予 Generator 其他指导,Generator 找到了一种骗过 Discriminator 的方法,也就是生成鬼脸,而且很大可能会 mode collapse。

下图是我用 PyTorch 做的 BEGAN 复现,当时没有跑很高的分辨率,但是效果确实比其他 GAN 好基本没有鬼脸。

a11d32d72aa7891e926879814d84509f2c467cbf

PG-GAN 能够稳定地训练生成高分辨率的 GAN。我们来看一下 PG-GAN 跟别的 GAN 不同在哪里。

1. 训练方式

作者采用 progressive growing 的训练方式,先训一个小分辨率的图像生成,训好了之后再逐步过渡到更高分辨率的图像。然后稳定训练当前分辨率,再逐步过渡到下一个更高的分辨率。

ab5376e507b284de3bf741db9873c2c5015ea801

如上图所示。更具体点来说,当处于 fade in(或者说 progressive growing)阶段的时候,上一分辨率(4x4)会通过 resize+conv 操作得到跟下一分辨率(8x8)同样大小的输出,然后两部分做加权,再通过 to_rgb 操作得到最终的输出。

这样做的一个好处是它可以充分利用上个分辨率训练的结果,通过缓慢的过渡(w 逐渐增大),使得训练生成下一分辨率的网络更加稳定。

上面展示的是 Generator 的 growing 阶段。下图是 Discriminator 的 growing,它跟 Generator 的类似,差别在于一个是上采样,一个是下采样。这里就不再赘述。

b1528b2f1eb484b0cea3c9634fe42a666adf1d06

不难想象,网络在 growing 的时候,如果不引入 progressive (fade in),那么有可能因为比较差的初始化,导致原来训练的进度功亏一篑,模型不得不从新开始学习,如此一来就没有充分利用以前学习的成果,甚至还可能误导。我们知道 GAN的训练不稳定,这样的突变有时候是致命的。所以 fade in 对训练的稳定性来说至关重要。

说到 growing 的训练方式,我们很容易想到 autoencoder 也有一种类似的训练方式:先训各一层的 encoder 和 decoder,训好了以后再过渡到训练各两层的 encoder 和 decoder,这样的好处是避免梯度消失,导致离 loss 太远的层更新不够充分。PG-GAN 的做法可以说是这种 autoencoder 训练方式在 GAN 训练上的应用。

此外,训练 GAN 生成高分辨率图像,还有一种方法,叫 LAPGAN[2]。LAPGAN 借助 CGAN,高分辨率图像的生成是以低分辨率图像作为条件去生成残差,然后低分辨率图上采样跟残差求和得到高分辨率图,通过不断堆叠 CGAN 得到我们想要的分辨率。

077599d3eef33e0e6a23e8f03324721c0e051dc3

LAPGAN 是多个 CGAN 堆叠一起训练,当然可以拆分成分阶段训练,但是它们本质上是不同的,LAPGAN 学的是残差,而 PG-GAN 存在 stabilize 训练阶段,学的不是残差,而直接是图像。

作者在代码中设计了一个 LODSelectLayer 来实现 progressive growing。对于 Generator,每一层插入一个 LODSelectLayer,它实际上就是一个输出分支,实现在特定层的输出。

从代码来看,作者应该是这样训练的(参见这里的 train_gan 函数),先构建 4x4 分辨率的网络,训练,然后把网络存出去。再构建 8x8 分辨率的网络,导入原来 4x4 的参数,然后训 fade in,再训 stabilize,再存出去。我在复现的时候,根据文章的意思,修改了 LODSelectLayer 层,因为 PyTorch 是动态图,能够很方便地写 if-else 逻辑语句。

借助这种 growing 的方式,PG-GAN 的效果超级好。另外,我认为这种 progressive growing 的方法比较适合 GAN 的训练,GAN 训练不稳定可以通过 growing 的方式可以缓解。

不只是在噪声生成图像的任务中可以这么做,在其他用到 GAN 的任务中都可以引入这种训练方式。我打算将 progressive growing 引入到 CycleGAN 中,希望能够得到更好的结果。

2. 增加生成多样性

增加生成样本的多样性有两种可行的方法:通过 loss 让网络自己调整、通过设计判别多样性的特征人为引导。

WGAN 属于前者,它采用更好的分布距离的估计(Wasserstein distance)。模型收敛意味着生成的分布和真实分布一致,能够有多样性的保证。PG-GAN 则属于后者。

作者沿用 improved GAN 的思路,通过人为地给 Discriminator 构造判别多样性的特征来引导 Generator 生成更多样的样本。Discriminator 能探测到 mode collapse 是否产生了,一旦产生,Generator 的 loss 就会增大,通过优化 Generator 就会往远离 mode collapse 的方向走,而不是一头栽进坑里。

Improved GAN 引入了 minibatch discrimination 层,构造一个 minibatch 内的多样性衡量指标。它引入了新的参数。

2d5f7e6d2b7e6fde41bf93ba8488951a26d3d07c

而 PG-GAN 不引入新的参数,利用特征的标准差作为衡量标准。

84150e86ed99fe9f40758caefee4152cb3cd6fbb

这里啰嗦地说明上面那张图做了什么。我们有 N 个样本的 feature maps(为了画图方便,不妨假设每个样本只有一个 feature map),我们对每个空间位置求标准差,用 numpy 的 std 函数来说就是沿着样本的维度求 std。这样就得到一张新的 feature map(如果样本的 feature map 不止一个,那么这样构造得到的 feature map 数量应该是一致的),接着 feature map 求平均得到一个数。

这个过程简单来说就是求 mean std,作者把这个数复制成一张 feature map 的大小,跟原来的 feature map 拼在一起送给 Discriminator。

从作者放出来的代码来看,这对应 averaging=“all”的情况。作者还尝试了其他的统计量:“spatial”,“gpool”,“flat”等。它们的主要差别在于沿着哪些维度求标准差。至于它们的作用,等我的代码复现完成了会做一个测试。估计作者调参发现“all”的效果最好。

3. Normalization

从 DCGAN[3]开始,GAN 的网络使用 batch (or instance) normalization 几乎成为惯例。使用 batch norm 可以增加训练的稳定性,大大减少了中途崩掉的情况。作者采用了两种新的 normalization 方法,不引入新的参数(不引入新的参数似乎是 PG-GAN 各种 tricks 的一个卖点)。

第一种 normalization 方法叫 pixel norm,它是 local response normalization 的变种。Pixel norm 沿着 channel 维度做归一化,这样归一化的一个好处在于,feature map 的每个位置都具有单位长度。这个归一化策略与作者设计的 Generator 输出有较大关系,注意到 Generator 的输出层并没有 Tanh 或者 Sigmoid 激活函数,后面我们针对这个问题进行探讨。

48abe70aba74ac035c7ccd01da9b30d9d8b7d326

第二种 normalization 方法跟凯明大神的初始化方法[4]挂钩。He 的初始化方法能够确保网络初始化的时候,随机初始化的参数不会大幅度地改变输入信号的强度。

ce7af75de8064cfcadd9f8c13dd1f3383b463c7b

根据这个式子,我们可以推导出网络每一层的参数应该怎样初始化。可以参考 PyTorch 提供的接口。

作者走得比这个要远一点,他不只是初始化的时候对参数做了调整,而是动态调整。初始化采用标准高斯分布,但是每次迭代都会对 weights 按照上面的式子做归一化。作者 argue 这样的归一化的好处在于它不用再担心参数的 scale 问题,起到均衡学习率的作用(euqalized learning rate)。

4. 有针对性地给样本加噪声

通过给真实样本加噪声能够起到均衡 Generator 和 Discriminator 的作用,起到缓解 mode collapse 的作用,这一点在 WGAN 的前传中就已经提到[5]。尽管使用 LSGAN 会比原始的 GAN 更容易训练,然而它在 Discriminator 的输出接近 1 的适合,梯度就消失,不能给 Generator 起到引导作用。

针对 D 趋近 1 的这种特性,作者提出了下面这种添加噪声的方式:

7d463232346e686eb1bd53732134b014eb698b33

其中,
?tp=webp&wxfrom=5&wx_lazy=1
分别为第 t 次迭代判别器输出的修正值、第 t-1 次迭代真样本的判别器输出。 

从式子可以看出,当真样本的判别器输出越接近 1 的时候,噪声强度就越大,而输出太小(<=0.5)的时候,不引入噪声,这是因为 0.5 是 LSGAN 收敛时,D 的合理输出(无法判断真假样本),而小于 0.5 意味着 D 的能力太弱。

文章还有其他很多 tricks,有些 tricks 不是作者提出的,如 Layer norm,还有一些比较细微的 tricks,比如每个分辨率训练好做 sample 的时候学习率怎么 decay,每个分辨率的训练迭代多少次等等,我们就不再详细展开。具体可以参见官方代码,也可以看我复现的代码。

目前复现的结果还在跑,现在训练到了 16x16 分辨率的 fade in 阶段,放一张当前的结果图,4 个方格的每个方格左边 4 列是生成的图,右边 4 列是真实样本。现在还处于训练早期,分辨率太低,过几天看一下高分辨率的结果。

7518816bea365273c824fe84896ee9fdd106b1c2

5. 相关代码

官方 Lasagna 代码:

https://github.com/tkarras/progressive_growing_of_gans

作者 PyTorch 复现:

https://github.com/github-pengge/PyTorch-progressive_growing_of_gans

6. 参考文献

[1]. Karras T, Aila T, Laine S, et al. Progressive Growing of GANs for Improved Quality, Stability, and Variation[J]. arXiv preprint arXiv:1710.10196, 2017.

[2]. Denton E L, Chintala S, Fergus R. Deep Generative Image Models using a Laplacian Pyramid of Adversarial Networks[C]//Advances in neural information processing systems. 2015: 1486-1494.

[3]. Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks[J]. arXiv preprint arXiv:1511.06434, 2015.

[4]. He K, Zhang X, Ren S, et al. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification[C]//Proceedings of the IEEE international conference on computer vision. 2015: 1026-1034.

[5]. Arjovsky M, Bottou L. Towards principled methods for training generative adversarial networks[J]. arXiv preprint arXiv:1701.04862, 2017.

[6]. Berthelot D, Schumm T, Metz L. Began: Boundary equilibrium generative adversarial networks[J]. arXiv preprint arXiv:1703.10717, 2017.

原文发布时间为:2017-11-16

本文作者:洪佳鹏

本文来自云栖社区合作伙伴“”,了解相关信息可以关注“”微信公众号

转载地址:http://owvja.baihongyu.com/

你可能感兴趣的文章
[置顶] 步步辨析JS中的对象成员
查看>>
上门送装家具建材已成趋势 还在“跑断腿”就OUT了!
查看>>
双11黑科技,阿里百万级服务器自动化运维系统StarAgent揭秘
查看>>
聊聊sentinel的SimpleHttpCommandCenter
查看>>
前端面试CSS
查看>>
CSS图形绘制总结
查看>>
FutureTask源码解析(1)——预备知识
查看>>
二叉树的非递归前序遍历
查看>>
animationend 事件
查看>>
JS进阶篇--JS中的反柯里化( uncurrying)
查看>>
关于多电脑布署hexo博客,和在线更新文章
查看>>
Python迭代器、生成器、装饰器深入解读
查看>>
Azure Service Fabric正式发布
查看>>
区块链安全:2019年我们走了多远?
查看>>
JHipster 3.0发布,新增微服务支持
查看>>
池化.NET内存流以解决大内存堆分配问题
查看>>
从Handler.post(Runnable r)再一次梳理Android的消息机制(以及handler的内存泄露)
查看>>
TIOBE 7 月编程语言榜:TypeScript 进入前 50 名
查看>>
python进程注入shellcode
查看>>
WPFの阴影效果
查看>>