揭秘:Wasserstein GAN与梯度惩罚(WGAN-GP)

2023-09-13 13:06:00

一、说明

        什么是梯度惩罚?为什么它比渐变裁剪更好?如何实施梯度惩罚?在提起GAN对抗网络中,就不能避免Wasserstein距离的概念,本篇为系列读物,目的是揭示围绕Wasserstein-GAN建模的一些重要概念进行探讨。

图1(左)使用配重裁剪时的梯度范数要么爆炸,要么消失,不使用GP。(右)与 GP 不同,权重裁剪将权重推向两个值。 

二、背景资料

        在这篇文章中,我们将研究带有梯度惩罚的Wasserstein GAN。虽然最初的Wasserstein GAN[2]提高了训练稳定性,但仍存在生成较差样本或无法收敛的情况。回顾一下,WGAN的成本函数为:

公式 1:WGAN 值函数。

        其中  1-利普希茨连续的。WGAN的问题主要是因为用于对批评者强制执行Lipschitz连续性的权重裁剪方法。WGAN-GP用对批评家的梯度范数的约束代替了权重裁剪,以强制执行Lipschitz的连续性。这允许比WGAN更稳定的网络训练,并且需要很少的超参数调优。WGAN-GP和这篇文章建立在Wasserstein GANs之上,这已经在揭秘系列的上一篇文章中讨论过。查看下面的帖子以了解 WGAN。

报表 1

可微的最优1-Lipschitz函数,最小化方程1的f*在Pr和Pg下几乎在任何地方都有单位梯度范数。

Pr 和 Pg 分别是真假分布。语句 1 的证明可以在 [1] 中找到。

三、渐变剪切问题

3.1 容量未充分利用

图2:WGAN评论家(上)使用梯度裁剪学习的值表面,(下)使用梯度惩罚学习的值表面。图片来源: [1]

使用权重裁剪来强制执行 k-Lipschitz 约束会导致批评者学习非常简单的函数。

从语句 1 中,我们知道最优批评者的梯度范数在 Pr 和 Pg 中几乎无处不在都是 1。在权重裁剪设置中,批评家试图达到其最大梯度范数 k并最终学习简单的函数。

图2显示了这种效果。批评者被训练收敛固定生成分布(Pg)作为实际分布(Pr)+单位高斯噪声。我们可以清楚地看到,使用权重裁剪训练的批评家最终学习了简单的函数并且未能捕捉到更高的时刻,而使用梯度惩罚训练的批评家则没有这个问题。

3.2 梯度爆炸和消失

权重约束和损失函数之间的相互作用使得WGAN的训练变得困难,并导致梯度爆炸或消失。

这在图1(左)中可以清楚地看到,其中注释器的权重在不同的削波值下爆炸或消失。图 1(右)还显示,渐变削波将注释器的权重推到两个极端削波值。另一方面,接受梯度惩罚训练的批评家不会遇到此类问题。

四、梯度惩罚

梯度惩罚的想法是强制执行一个约束,使得批评者输出的梯度与输入具有单位范数(语句 1)。

作者提出了该约束的软版本,对样本x̂∈P的梯度范数进行惩罚。新目标是

公式2:批评家损失函数

在方程 2 中,总和左侧的项是原始批评者损失,总和右侧的项是梯度惩罚。

Px̂ 是通过在实分布和生成的分布 Pr 和 Pg 之间沿直线均匀采样而获得的分布。这样做是因为最优注释器在从Pr和Pg耦合的样品之间具有单位梯度范数的直线。

λ,惩罚系数用于对梯度惩罚项进行加权。在论文中,作者为所有实验设置了λ = 10。

批规范化不再在注释中使用,因为批范数将一批输入映射到一批输出。在我们的例子中,我们希望能够找到每个输出的梯度,w.r.t它们各自的输入。

五、代码示例

5.1 梯度惩罚

 梯度惩罚的实现如下所示。

def compute_gp(netD, real_data, fake_data):
        batch_size = real_data.size(0)
        # Sample Epsilon from uniform distribution
        eps = torch.rand(batch_size, 1, 1, 1).to(real_data.device)
        eps = eps.expand_as(real_data)
        
        # Interpolation between real data and fake data.
        interpolation = eps * real_data + (1 - eps) * fake_data
        
        # get logits for interpolated images
        interp_logits = netD(interpolation)
        grad_outputs = torch.ones_like(interp_logits)
        
        # Compute Gradients
        gradients = autograd.grad(
            outputs=interp_logits,
            inputs=interpolation,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
        )[0]
        
        # Compute and return Gradient Norm
        gradients = gradients.view(batch_size, -1)
        grad_norm = gradients.norm(2, 1)
        return torch.mean((grad_norm - 1) ** 2)

5.2 关于WGAN-GP代码

训练 WGAN-GP 模型的代码可以在这里找到:

5.3 输出

图3:WGAN-GP模型生成的图像。请注意,结果是早期结果,一旦确认模型按预期训练,训练就会停止。

 

        图例.3显示了训练WGAN-GP的一些早期结果。请注意,图 3 中的图像是早期结果,一旦确认模型按预期训练,训练就会停止。该模型未经过训练以收敛。

六、结论

        Wasserstein GAN 在训练生成对抗网络方面提供了急需的稳定性。但是,使用梯度削波导致各种问题,例如梯度爆炸和消失等。梯度惩罚约束不受这些问题的影响,因此与原始WGAN相比,允许更容易的优化和收敛。这篇文章研究了这些问题,介绍了梯度惩罚约束,还展示了如何使用 PyTorch 实现梯度惩罚。最后,提供了训练WGAN-GP模型的代码以及一些早期阶段的输出。阿迪西亚·桑卡尔

七、引用

[1] Gulrajani, Ishaan, et al. “改进了 wasserstein gans 的训练”。arXiv预印本arXiv:1704.00028(2017)。

[2] 阿尔约夫斯基、马丁、苏米斯·钦塔拉和莱昂·博图。“Wasserstein generative adversarial networks。”机器学习国际会议。PMLR, 2017.

[3] GitHub - aadhithya/gan-zoo-pytorch: A zoo of GAN implementations

更多推荐

C++ 基本的输入输出

C++基本的输入输出C++标准库提供了一组丰富的输入/输出功能,我们将在后续的章节进行介绍。本章将讨论C++编程中最基本和最常见的I/O操作。C++的I/O发生在流中,流是字节序列。如果字节流是从设备(如键盘、磁盘驱动器、网络连接等)流向内存,这叫做输入操作。如果字节流是从内存流向设备(如显示屏、打印机、磁盘驱动器、网

【PX4】PX4第一个offborad例程

【PX4】PX4第一个offborad例程文章目录【PX4】PX4第一个offborad例程1.什么是OFFBOARD2.第一个offboard例程3.编写launch文件Reference1.什么是OFFBOARDPX4的OFFBOARD指的是外部控制模式,飞行器根据飞行控制栈外部(如机载计算机)提供的设定值控制位置

API安全

1API的简介API代表应用程序编程接口,它由一组允许软件组件进行通信的定义和协议组成。作为软件系统之间的中介,API使软件应用程序或服务能够共享数据和功能。但是API不仅仅提供连接基础,它还管理软件应用程序如何被允许进行通信和交互。API控制程序之间交换请求的类型、请求的方式以及允许的数据格式。例如,智能手机上的天气

小谈设计模式(2)—简单工厂模式

小谈设计模式(2)—简单工厂模式专栏介绍专栏地址专栏介绍简单工厂模式简单工厂模式组成抽象产品(AbstractProduct)具体产品(ConcreteProduct)简单工厂(SimpleFactory)三者关系核心思想Java代码实现首先,我们定义一个抽象产品接口Product,其中包含一个抽象方法use():然后

Redis的缓存、消息队列、计数器应用

目录一、redis的应用场景二、redis如何用于缓存三、redis如何用于消息队列四、redis如何用于计数器一、redis的应用场景Redis在实际应用中有广泛的应用场景,以下是一些常见的Redis应用场景:缓存:Redis可以用作缓存层,将频繁读取的数据存储在内存中,提高数据读取速度,减轻数据库负载。计数器:Re

Vulnhub系列靶机---HarryPotter-Fawkes-哈利波特系列靶机-3

文章目录信息收集主机发现端口扫描dirsearch扫描gobuster扫描漏洞利用缓冲区溢出edb-debugger工具msf-pattern工具docker容器内提权tcpdump流量分析容器外-sudo漏洞提权靶机文档:HarryPotter:Fawkes下载地址:Download(Mirror)难易程度:难上难信

Redis 集合操作实战(全)

目录SADD插入集合SCARD取元素数量SPOP随机移除元素SREM移除多个元素SMOVE移动元素到别的集合SMEMBERS取所有成员SRANDMEMBER取指定数量元素SISMEMBER判断元素是否存在SUNION多集合求并集SUNIONSTORE多集合求并集(存储)SINTER多集合求交集SINTERSTORE多集

PY32F003F18之比较器问题

PY32F003F18的模拟模块,其内部参考电压容易受到电源电压影响。当我连接"USB转串口的RXD"时,PC接收到模拟数据均正常;当我连接“USB转串口的TXD”时,发现内部参考电压的AD值为0xFFF。断开连接的“USB转串口的TXD”,模拟功能模块又恢复正常。于是用万用表测量“USB转串口的TXD”的电压,开路电

Spring高手之路10——解锁Spring组件扫描的新视角

文章目录1.组件扫描路径2.按注解过滤组件(包含)3.按注解过滤组件(排除)4.通过正则表达式过滤组件5.Assignable类型过滤组件6.自定义组件过滤器7.组件扫描的其他特性7.1组合使用组件扫描8.组件扫描的组件名称生成8.1Spring是如何生成默认bean名称的(源码分析)8.2生成默认bean名称的特殊情

一文巩固Spring MVC的Bean加载机制

目录一、什么是SpringMVC的Bean二、SpringMVC的Bean加载机制三、SpringMVC如何动态装载Bean一、什么是SpringMVC的Bean在SpringMVC中,Bean指的是在SpringIoC容器中创建和管理的对象。这些对象可以是普通的Java类,也可以是服务层组件、数据访问对象(DAO)或

手动实现 Spring 底层机制【初始化 IOC容器+依赖注入+BeanPostProcessor 机制+AOP】之实现任务阶段 5- bean 后置处理器

😀前言手动实现Spring底层机制【初始化IOC容器+依赖注入+BeanPostProcessor机制+AOP】的第五篇具体实现了任务阶段5-bean后置处理器🏠个人主页:尘觉主页🧑个人简介:大家好,我是尘觉,希望我的文章可以帮助到大家,您的满意是我的动力😉😉在csdn获奖荣誉:🏆csdn城市之星2名⁣⁣⁣

热文推荐