生成对抗网络(GAN)
1. 入门简介
gan整体的损失函数
minGmaxDV(G,D)=Ex−PdatalogD(x)+Ez−Pzlog(1−D(G(z))) 训练时,先训练Discriminator、然后训练Generator,迭代直至目标函数收敛。
需要注意的是,一切损失计算都是在D(判别器)输出处产生的,而D的输出一般是fake/true的判断,所以整体上采用的是二分类交叉熵函数。
首先看一下maxD部分,因为训练一般是先保持G(生成器)不变训练D的。D的训练目标是正确区分fake/true,如果我们以1/0代表true/fake,则对第一项E因为输入采样自真实数据所以我们期望D(x)趋近于1,也就是第一项更大。同理第二项E输入采样自G生成数据,所以我们期望D(G(z))趋近于0更好,也就是说第二项又是更大。所以是这一部分是期望训练使得整体更大了,也就是maxD
的含义了。
第二部分保持D不变,训练G,这个时候只有第二项E有用了,关键来了,因为我们要迷惑D,所以这时将label设置为1(我们知道是fake,所以才叫迷惑),希望D(G(z))输出接近于1,也就是这一项越小越好,这就是minG。当然判别器D哪有这么好糊弄,所以这个时候判别器就会产生比较大的误差,误差会更新G,那么G就会变得更好了,这次没有骗过你,只能下次更努力了。
Discriminator的损失函数
maxDlog[D(x)]+log[1−D(G(z))] Generator的损失函数
minGlog[1−D(G(z))] 在(近似)最优判别器下,最小化生成器的loss等价于最小化Pr与Pg之间的JS散度。
下图中可以发现,所有的loss都是由判别器产生的。如果没有D,G不知道自己生成的结果如何,便得不到权重更新。
1 | import torch |
1 | class Generator(nn.Module): |
2. 各式各样的GAN
2.1DCGAN
深度卷积生成对抗网络,在生成器中,对输入的一维向量不断进行转置卷积(上采样)最终生成对应的图像。在判别器中,则将输入的图像经过多层卷积最后经过sigmod函数进行二分类,判断这是原始数据图片还是生成器产生的图片。
1 | def weights_init(m): |
2.2 Conditional GAN
CGAN的目标函数与原始的并无太大不同,只不过加了一个限定条件。
minGmaxDV(D,G)=Ex−pdata[log(D(x|y))]+Ez−pz[log[1−D(G(z|y))]]1 | # G(z) |
结合介绍的两种,可以定义cDCNGAN
模型(就是把Linear全连接层换为了ConvTranspose2d或Conv2d卷积层)。
2.3 Bidirectional GAN
讲述BiGAN的两篇论文分别为:
Donahue, Jeff, Philipp Krähenbühl, and Trevor Darrell. “Adversarial feature learning.” arXiv preprint arXiv:1605.09782 (2016).
Dumoulin, Vincent, et al. “Adversarially learned inference.” arXiv preprint arXiv:1606.00704 (2016).
网络架构
- 目标函数minG,EmaxDV(D,E,G)
代码参考:https://github.com/fmu2/Wasserstein-BiGAN
2.4 WGAN
Arjovsky, M., Chintala, S., & Bottou, L. (2017). Wasserstein gan. arXiv preprint arXiv:1701.07875.(gradient clipping)
Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., & Courville, A. C. (2017). Improved training of wasserstein gans. In Advances in neural information processing systems (pp. 5767-5777).(gradient penalty)
参考:https://zhuanlan.zhihu.com/p/25071913(令人拍案叫绝的WGAN)。
Wasserstein距离也被称为Earthmover′s距离(推土机距离)。Wasserstein距离相比KL散度、JS散度的优越性在于,即便两个分布没有重叠,Wasserstein距离仍然能够反映它们的远近。
我们可以构造一个含参数w、最后一层不是非线性激活层的判别器网络fw,在限制w不超过某个范围的条件下,使得
L=Ex−Pr[fw(x)]−Ex−PG[fw(x)]尽可能取到最大,此时L就会近似真实分布与生成分布之间的Wasserstein距离(忽略常数倍数K)。
注:判别器要迭代训练多次。而生成器只训练一次。
WGAN在原生的GAN做出的改进:
- G和D的损失函数不用对数
- 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行
- D最后一层去掉sigmod二分类函数
- 采用gradient clipping和gradient penalty(改进)
原始GAN存在的问题:
- 判别器越好,生成器越容易产生梯度消失。
- 训练不稳定,容易导致collapsemode。
2.5 StackGAN由文本生成高分辨率图像
Zhang, Han, et al. “Stackgan: Text to photo-realistic image synthesis with stacked generative adversarial networks.” Proceedings of the IEEE international conference on computer vision. 2017.
2.6 GANomaly异常检测
- 网络架构:
可以看出,模型包含两个encoder、一个decoder(相当于生成器)和一个判别器。模型划分为三个部分:第一部分为一个自动编码器,包含一个encoder(GE)、一个decoder(GD),这一部分被记为G;第二部分为一个encoder,记为E;第三部分为一个判别器网络,记为D。前两部分也被称为G-Net。
输入图片数据x经过一个encoder(GE)编码为向量z,decoder(GD)将向量z还原为原尺寸图像数据ˆx,另一个encoder(E)将ˆx又编码为向量ˆz。将x和ˆx输入判别器网络(D)判断图片是原始图片还是生成器生成的图片。
- 损失函数
损失函数共分为三部分,第一部分是EnocderLoss,衡量两个encoder编码向量的损失;第二部分是ContextualLoss,衡量原图像与生成器生成图像的损失,第三部分是AdversialLoss,是常规的GAN中判别网络的损失,这里采用的是二分类的交叉熵损失。
优化D-net,采用AdversialLoss。
优化G-net时,采用三部分损失函数的加权和。
- 异常检测
原理:由于训练输入的都是正常数据,第一个encoder学习到的是正常数据的分布,经过生成器的重建后再经过encoder编码差异不会很大,当输入异常数据时,encoder编码后会损失部分信息,经过生成器重建后再编码会与原来的数据差异很大,从而进行异常检测。
A(ˆx)=||GE(ˆx)−E(G(ˆx))||1 当异常得分A大于某一阈值时,模型就会判定该数据为异常数据。(异常检测并没有用到判别器)。
2.7 DiscoGAN关联分析
模型主要由两个生成器和两个判别器构成。
GAB:输入A领域(domain)图片,生成B领域图片
GBA:输入B领域图片,生成A领域图片
DA:判别A领域原始图像和GBA生成的A领域图像
DB:判别B领域原始图像和GAB生成的B领域图像