自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与Teacher forcing(教师强制)

2023-07-14 15:27:35

自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)


作者:安静到无声 个人主页

作者简介:人工智能和硬件设计博士生、CSDN与阿里云开发者博客专家,多项比赛获奖者,发表SCI论文多篇。

Thanks♪(・ω・)ノ 如果觉得文章不错或能帮助到你学习,可以点赞👍收藏📁评论📒+关注哦! o( ̄▽ ̄)d

欢迎大家来到安静到无声的 《基于pytorch的自然语言处理入门与实践》,如果对所写内容感兴趣请看《基于pytorch的自然语言处理入门与实践》系列讲解 - 总目录,同时这也可以作为大家学习的参考。欢迎订阅,请多多支持!

1. Scheduled Sampling(计划采样)

1.1 概念解释

Scheduled Sampling是一种用于训练序列生成模型的策略,旨在缓解曝光偏差(Exposure Bias)问题。曝光偏差是指模型在训练时接触到的数据分布与测试时的数据分布不一致,导致性能下降。

在Scheduled Sampling中,模型在每个时间步骤都有一定的概率选择使用真实目标序列中的单词作为输入,而不是使用前一个时间步骤生成的单词。这样可以使模型更好地适应真实数据分布,减少曝光偏差问题。

具体来说,Scheduled Sampling使用以下公式计算每个时间步骤生成当前单词的概率:

P ( y t ∣ y 1 , . . . , y t − 1 ) = ( 1 − ϵ ) ∗ P model ( y t ∣ y 1 , . . . , y t − 1 ) + ϵ ∗ P data ( y t ∣ y 1 , . . . , y t − 1 ) P(y_t|y_1, ..., y_{t-1}) = (1 - \epsilon) * P_{\text{model}}(y_t|y_1, ..., y_{t-1}) + \epsilon * P_{\text{data}}(y_t|y_1, ..., y_{t-1}) P(yty1,...,yt1)=(1ϵ)Pmodel(yty1,...,yt1)+ϵPdata(yty1,...,yt1)其中, P ( y t ∣ y 1 , . . . , y t − 1 ) P(y_t|y_1, ..., y_{t-1}) P(yty1,...,yt1)表示在给定前面的生成序列条件下生成当前单词 y t y_t yt的概率, P model ( y t ∣ y 1 , . . . , y t − 1 ) P_{\text{model}}(y_t|y_1, ..., y_{t-1}) Pmodel(yty1,...,yt1)表示模型生成该单词的概率, P data ( y t ∣ y 1 , . . . , y t − 1 ) P_{\text{data}}(y_t|y_1, ..., y_{t-1}) Pdata(yty1,...,yt1)表示真实目标序列中该单词的概率。参数 ϵ \epsilon ϵ用于控制采样策略,可以随着训练的进行而逐渐增加。

1.2 代码实现

下面是一个使用Python实现Scheduled Sampling的示例代码:

在这里插入图片描述

其中, P ( y t ∣ y < t , x ) P(y_t | y_{<t}, x) P(yty<t,x)表示在给定前文和输入的条件下,生成当前时间步的输出的概率。 P model P_{\text{model}} Pmodel表示由模型生成的概率分布, P prev P_{\text{prev}} Pprev表示根据上一个时间步的真实输出计算得到的概率分布。sample是从均匀分布中采样得到的一个随机数,threshold是一个控制Scheduled Sampling引入程度的超参数。

2. Teacher forcing(教师强制)

2.1 概念解释

Teacher forcing(教师强制)是一种在序列生成模型中使用的训练技术。具体来说,当使用RNN(循环神经网络)或类似架构的模型进行序列生成时,每个时间步都会根据前一个时间步的输入和隐藏状态生成输出。在训练期间,如果使用teacher forcing,那么每个时间步的输入将是真实的目标序列(而不是模型自身生成的序列)。这意味着模型在每个时间步都能够观察到正确的答案,从而更容易地学习到正确的模式和规律。

2.1 代码实现

import torch
import torch.nn as nn

# 定义序列到序列模型
class Seq2SeqModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Seq2SeqModel, self).__init__()
        self.hidden_dim = hidden_dim
        
        # 定义编码器
        self.encoder = nn.RNN(input_dim, hidden_dim)
        
        # 定义解码器
        self.decoder = nn.RNN(output_dim, hidden_dim)
        
        # 定义全连接层,将解码器的输出映射为目标序列
        self.fc = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, input_seq, target_seq):
        # 编码器计算输入序列的隐藏状态
        _, hidden_state = self.encoder(input_seq)
        
        # 解码器初始化隐藏状态
        decoder_hidden_state = hidden_state
        
        # 用真实目标序列作为输入来指导解码器的生成过程
        decoder_outputs, _ = self.decoder(target_seq, decoder_hidden_state)
        
        # 对解码器的输出应用全连接层进行映射
        output_seq = self.fc(decoder_outputs)
        
        return output_seq

# 创建模型实例
input_dim = 10
hidden_dim = 20
output_dim = 10
model = Seq2SeqModel(input_dim, hidden_dim, output_dim)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    # 步骤1:将模型设为训练模式
    model.train()
    
    # 步骤2:清零梯度
    optimizer.zero_grad()
    
    # 步骤3:前向传播
    input_seq = torch.randn(5, 3, input_dim)  # 输入序列
    target_seq = torch.randn(5, 3, output_dim)  # 目标序列
    output_seq = model(input_seq, target_seq)
    
    # 步骤4:计算损失
    loss = criterion(output_seq, target_seq)
    
    # 步骤5:反向传播和优化
    loss.backward()
    optimizer.step()
    
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

上述代码中,我们定义了一个简单的序列到序列模型Seq2SeqModel,其中包括一个RNN编码器、一个RNN解码器和一个全连接层。在forward方法中,我们首先使用编码器计算输入序列的隐藏状态,然后将隐藏状态作为解码器的初始隐藏状态。接下来,我们使用真实目标序列来指导解码器的生成过程,并将解码器的输出映射为目标序列。在训练阶段,我们使用真实目标序列作为输入来指导模型的生成过程。最后,我们定义了损失函数和优化器,并进行训练。

需要注意的是,在实际应用中,模型的推理阶段并不会使用真实目标序列来指导生成过程。在推理阶段,可以将前一个时间步的模型输出作为下一个时间步的输入,从而进行序列的自我生成。

--------推荐专栏--------
🔥 手把手实现Image captioning
💯CNN模型压缩
💖模式识别与人工智能(程序与算法)
🔥FPGA—Verilog与Hls学习与实践
💯基于Pytorch的自然语言处理入门与实践

参考

Scheduled Sampling的搜索结果_百度图片搜索 (baidu.com)
Teacher forcing RNN的搜索结果_百度图片搜索 (baidu.com)

在这里插入图片描述

更多推荐

<Altium Designer> 将.DSN文件导入并转换成SchDoc文件

目录01使用向导方式导入.DSN02消除UniqueIdentifiersErrors03文章总结大家好,这里是程序员杰克。一名平平无奇的嵌入式软件工程师。本文主要是总结和分享将OrCADCapture画的原理图文件(.DSN)导入到AltiumDesigner,转换成对应的原理图文件(SchDoc)的方法。本文所使用

MySQL正则表达式:模式匹配、中文匹配、替换、提取字符串

在MySQL中,使用REGEXP或RLIKE操作符进行正则表达式匹配,而使用NOTREGEXP或NOTRLIKE操作符进行不匹配。一些常用的MySQL正则表达式语法:匹配字符:.:匹配任意字符(除了换行符)。[]:匹配方括号中的任意字符。[^]:匹配不在方括号中的任意字符。匹配重复:*:匹配零个或多个前面的字符。+:匹

【C++从0到王者】第三十一站:map与set

文章目录一、关联式容器二、pair键值对三、set1.set的介绍2.set的部分接口以及应用3.count4.lower_bound和upper_bound5.equal_range6.multiset容器四、map1.map的介绍2.map的一些常见接口以及使用3.map的[]运算符重载4.使用map改进一些题5.

代理IP和Socks5代理:跨界电商与爬虫的智能引擎

跨界电商,作为全球市场的一部分,对数据的需求越来越大。同时,随着互联网的发展,爬虫技术也在不断演进,成为了跨界电商的关键工具之一。然而,随之而来的是网站的反爬虫机制和网络安全风险。在这种情况下,代理IP和Socks5代理应运而生,为企业提供了数据采集的解决方案和网络安全的保护。本文将深入研究代理IP和Socks5代理在

应用平台 - OPPO敏感权限

那天在OPPO平台更新app时,发现平台权限升级,新增了敏感权限校验,而且还是必填项…Google从Android6.0开始就对权限做了分类适配,粗浅来看将权限分为了普通权限、危险权限(运行时权限、敏感权限),如果需要用到危险权限除了需要在AndroidManifest(清单文件)注册之外,我们还需要进行申请动态权限有

【配电变电站的最佳位置和容量】基于遗传算法的最优配电变电站放置(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。⛳️座右铭:行百里者,半于九十。📋📋📋本文目录如下:🎁🎁🎁目录💥1概述📚2运行结果🎉3参考文献🌈4Matlab代码实现💥1概述基于遗传算法的最优配电变电站放置为了实现配电变电站

接口测试之文件上传

在日常工作中,经常有上传文件功能的测试场景,因此,本文介绍两种主流编写上传文件接口测试脚本的方法。首先,要知道文件上传的一般原理:客户端根据文件路径读取文件内容,将文件内容转换成二进制文件流的格式传输给服务端,而服务端接受客户端传过来的二进制文件流以及文件名称等信息(此时这些二进制文件流存储在内存中),然后将其写入存储

物联网网络安全:保护物理世界和数字世界的融合

我们正在见证数字技术如何成为我们日常生活和经济系统的一部分,从而提高福利并增强竞争力。尽管如此,新的尖端互联技术的迅速出现和采用也对政府、企业和整个社会构成了重大威胁。长期以来,网络安全威胁一直是电影行业的一个现成的灵感来源,它设想了一些令人担忧的场景,在这些场景中,滥用技术和数据会危及社会、企业和政府。然而,被描绘成

角度回归——角度编码方式

文章目录1.为什么研究角度的编码方式?1.1角度本身具有周期性1.2深度学习的损失函数因为角度本身的周期性,在周期性的点上可能产生很大的Loss,造成训练不稳定1.3那么如何处理边界问题呢:(以θ的边界问题为例)1.3顺时针(CW)1.4逆时针(CCW)2角度回归的方式2.1长边定义法,强制W<H,range范围[-9

如何更好的选择服务器硬盘?

一.选择服务器硬盘时,可以考虑以下几个因素:1.容量需求:首先确定您的服务器对存储容量的需求。评估您预计需要存储的数据量、应用程序和文件的大小,以及未来的扩展需求。确保选择的硬盘能够满足服务器的存储需求,并有足够的空间用于备份和增长。2.性能要求:考虑您的服务器对性能的需求。如果服务器需要处理大量的读写操作、高速数据传

解决@vueup/vue-quill图片上传、视频上传问题

Editor.vue<template><el-upload:action="uploadUrl":before-upload="handleBeforeUpload":on-success="handleUploadSuccess"name="files":on-error="handleUploadError":s

热文推荐