从原理到代码实践 | pytorch损失函数

2023-09-16 10:48:15


对于图像分类任务,模型最终是通过softmax操作输出一个概率分布向量的(各个类别和为1)

假设我们有三类别 [ 小车,小牛,小火箭 ],假设有两张图片,分别有两个模型来对这两张图片分别预测 ,我们将真实标签转换为概率分布——独热码,如下图所示

在这里插入图片描述

假设两个模型经过训练后输出概率分布如下图

比如模型一对图片一火箭的预测为 [0.3,0.3,0.4 ] 说明模型认为图片是小车的概率是0.3,小牛的概率是0.3,小火箭的概率是0.4 ,最大的概率值便是最后的预测——小火箭
在这里插入图片描述

所以

模型一对图片一,二的预测分别为 火箭 和 小牛

模型二对图片 一,二的预测也分别为 火箭和 小牛

1.损失函数原理

1.1 Classification Error(分类错误率)

如果我们以图片错误率作为损失函数

预测错误数 图片总数 \frac{预测错误数}{图片总数} 图片总数预测错误数

模型一二都预测对一张图片,预测错一张图片

那么模型一二 的损失都为0.5,这样无法区分两个模型的好坏

但实际上对于图片一,模型一预测小火箭的概率很低0.4 而模型二预测很高0.9,我们更期望模型二这样的情况出现

1.2. 均方差损失

在这里插入图片描述

模型一

​ 图片一

( 0.3 − 0 ) 2 + ( 0.3 − 0 ) 2 + ( 0.4 − 1 ) 2 = 0.54 (0.3-0)^2+(0.3-0)^2+(0.4-1)^2=0.54 (0.30)2+(0.30)2+(0.41)2=0.54

​ 图片二

( 0.1 − 1 ) 2 + ( 0.8 − 0 ) 2 + ( 0.1 − 0 ) 2 = 1.46 (0.1-1)^2+(0.8-0)^2+(0.1-0)^2=1.46 (0.11)2+(0.80)2+(0.10)2=1.46

M S E = 0.54 + 1.46 2 = 1 MSE=\frac{0.54+1.46}{2}=1 MSE=20.54+1.46=1

模型二

​ 图片一

( 0.1 − 0 ) 2 + ( 0.1 − 0 ) 2 + ( 0.9 − 1 ) 2 = 0.03 (0.1-0)^2+(0.1-0)^2+(0.9-1)^2=0.03 (0.10)2+(0.10)2+(0.91)2=0.03

​ 图片二

( 0.4 − 1 ) 2 + ( 0.5 − 0 ) 2 + ( 0.1 − 0 ) 2 = 0.62 (0.4-1)^2+(0.5-0)^2+(0.1-0)^2=0.62 (0.41)2+(0.50)2+(0.10)2=0.62

M S E = 0.03 + 0.62 2 = 0.325 MSE=\frac{0.03+0.62}{2}=0.325 MSE=20.03+0.62=0.325

我们发现,MSE能够判断出来模型2优于模型1,那为什么不采样这种损失函数呢?在分类中更多采用交叉熵损失函数呢

(1)在分类问题中,我们通常希望模型的输出概率分布尽可能地接近真实标签的概率分布,而交叉熵损失函数可以直接衡量这种差异。相比之下,MSE损失函数只能衡量模型输出和标签之间的距离,无法直接反映概率分布的差异。

(2)交叉熵损失函数对于模型的误差敏感度更高。在分类问题中,误差大部分发生在模型输出概率最大的那个类别上,而交叉熵损失函数在这种情况下的梯度更大,可以更快地更新模型参数,从而提高模型的准确性。相比之下,MSE损失函数在这种情况下的梯度相对较小,更新速度更慢。

1.3 交叉熵损失函数

1.3.1 数学原理

假设存在两个概率分布 P,Q

注意这里的log是以e为底

H ( p ) = − ∑ x p ( x ) l o g p ( x ) H(p)=-\sum_xp(x)logp(x) H(p)=xp(x)logp(x)

熵是信息论中用于衡量随机变量不确定性的指标,它表示一个随机变量的平均信息量。熵越大,表示随机变量的不确定性越大,即信息量越大。例如 [ 0 , 0 , 1 ] [0 , 0 ,1] [0,0,1]这个分布没啥信息量,因为他的不确定度很小

我们通过上面公式计算一下他的熵会发现为0

而对于分布[0.3,0.3.0.4]这个分布不确定性比较大,熵值就更大了

相对熵: K L ( p ∣ ∣ q ) = − ∑ x p ( x ) l o g q ( x ) p ( x ) KL(p||q)=-\sum_xp(x)log\frac{q(x)}{p(x)} KL(p∣∣q)=xp(x)logp(x)q(x)

相对熵,也叫KL散度用来度量两个分布的不相似性(这里不叫做距离,是因为距离的话P到q和q到p的距离应该是一样的)而这里的话有可能不一样

如果两个分布一样,则相对熵为0,如果两个分布差异越大,相对熵越大

比如分布P[0,0,1]为和分布Q为 [0.3,0.3,0.4] 的相对熵为0.39,说明他俩相差比较大

而分布P为[0,0,1]和分布Q为 [0,0.1,0.9]相对熵0.04.说明他俩相差较小

实际中用到更多的是交叉熵

交叉熵: H ( p , q ) = − ∑ x p ( x ) l o g q ( x ) H(p,q)=-\sum_xp(x)logq(x) H(p,q)=xp(x)logq(x)

因为三者存在这样一个关系

H ( p , q ) = H ( p ) + K L ( p ∣ ∣ q ) H(p,q)=H(p)+KL(p||q) H(p,q)=H(p)+KL(p∣∣q)

而如果P分布是标答,分布是独热码的形式,那么它的H§ 就等于0 ,这样的话,我们就可以用交叉熵来代表相对熵了,计算更简单

再代入计算刚刚的交叉熵损失函数

在这里插入图片描述

模型一

注意这里的log是以e为底

​ 图片一

− ( 0 ∗ l o g ( 0.3 ) + 0 ∗ l o g ( 0.3 ) + 1 ∗ l o g ( 0.4 ) ) = − l o g ( 0.4 ) = 0.9163 -(0*log(0.3)+0*log(0.3)+1*log(0.4))=-log(0.4)=0.9163 (0log(0.3)+0log(0.3)+1log(0.4))=log(0.4)=0.9163

​ 图片二

− ( 1 ∗ l o g ( 0.1 ) + 0 ∗ l o g ( 0.8 ) + 0 ∗ l o g ( 0.1 ) ) = − l o g ( 0.1 ) = 2.3026 -(1*log(0.1)+0*log(0.8)+0*log(0.1))=-log(0.1)=2.3026 (1log(0.1)+0log(0.8)+0log(0.1))=log(0.1)=2.3026

模型二

​ 图片一

− ( 0 ∗ l o g ( 0.1 ) + 0 ∗ l o g ( 0.1 ) + 1 ∗ l o g ( 0.9 ) ) = − l o g ( 0.9 ) = 0.1054 -(0*log(0.1)+0*log(0.1)+1*log(0.9))=-log(0.9)=0.1054 (0log(0.1)+0log(0.1)+1log(0.9))=log(0.9)=0.1054

​ 图片二

− ( 1 ∗ l o g ( 0.4 ) + 0 ∗ l o g ( 0.5 ) + 0 ∗ l o g ( 0.1 ) ) = − l o g ( 0.4 ) = 0.9163 -(1*log(0.4)+0*log(0.5)+0*log(0.1))=-log(0.4)=0.9163 (1log(0.4)+0log(0.5)+0log(0.1))=log(0.4)=0.9163

明显 模型二损失更低,要优于模型一

1.3.2 代码实现

观察发现实际输出就是真实标签概率的负对数

所以用pytorch库来一句话简单实现

import torch
#第一个模型对两张图片的预测
y_hat=torch.tensor([[0.3,0.3,0.4],[0.1,0.8,0.1]])
#真实标签 2代表第3类小火箭,0代表第1类小车
y=torch.tensor([2,0])  
def cross_entropy(y_hat,y):
    return -torch.log(y_hat[range(len(y_hat)),y])
print(cross_entropy(y_hat,y))

可能会对y_hat[range(len(y_hat)),y]操作比较迷糊,我当时也是这样的

这里参考李沐老师的动手深度学习,我们没有用for循环进行计算,而是用了索引,会更加高效,具体如下:

  1. y_hat[range(len(y_hat)),y]:对 y_hat 进行了索引操作,其中 y 是一个一维张量,包含了两个整数,用于选取每行中的一个元素。

    具体来说,range(len(y_hat)) 表示一个从 0 到 1 的整数序列,对应于 y_hat 中的两行,而 y 则表示选取每行中的一个元素的下标。因此,这个表达式的含义是选取 y_hat 中每行中下标为 y 的元素,返回一个一维张量。

当然,实际中,我们更多使用高级API

nn.CrossEntropyLoss

可以一步计算Softmax和交叉熵损失,同时可以解决溢出等问题

更多推荐

springboot和vue:三、web入门(spring-boot-starter- web+控制器+路由映射+参数传递)

spring-boot-starter-webSpringBoot将传统Web开发的mvc、json、tomcat等框架整合,提供了spring-boot-starter-web组件,简化了Web应用配置。创建SpringBoot项目勾选SpringWeb选项后,会自动将spring-boot-starter-web组

Android 回声消除

Android回声消除前言在语音聊天、语音通话、互动直播、语音转文字类应用或者游戏中,需要采集用户的麦克风音频数据,然后将音频数据发送给其它终端或者语音识别服务。如果直接使用采集的麦克风数据,就会存在回音问题。所谓回音就是在语音通话过程中,如果用户开着扬声器,那么自己讲话的声音和对方讲话的声音(即是扬声器的声音)就会混

c++基础:new函数

new函数new是用于动态分配内存的操作符。它用于在堆内存中创建一个新的对象或数据结构,并返回一个指向该内存的指针。这是C++中进行动态内存分配的主要方式之一,通常与delete操作符一起使用来释放先前分配的内存。以下是使用new操作符的一些示例:动态分配一个整数,并将其赋值给指针:int*pInt=newint;*p

Python基础运算分享

Python的运算符和其他语言类似(我们暂时只了解这些运算符的基本用法,方便我们展开后面的内容,高级应用暂时不介绍)数学运算>>>print1+9#加法>>>print1.3-4#减法>>>print3*5#乘法>>>print4.5/1.5#除法>>>print3**2#乘方>>>print10%3#求余数判断判断是

抖音开网店无货源怎么找

随着社交媒体的快速发展,抖音已经成为了一种极具潜力的电商平台。许多人想要利用这个平台开设网店,但是其中很多人面临的问题是如何找到货源。无货源的抖音网店经营固然具有一定的难度,但并非不可行。以下是一些帮助你在抖音开网店无货源的方法。代销合作:寻找与制造商或批发商的代销合作是一种常见的方式。你可以与他们签订协议,销售他们的

1787_函数指针的使用

全部学习汇总:GitHub-GreyZhang/c_basic:littlebitsofc.前阵子似乎写了不少错代码,因为对函数指针的理解还不够。今天晚上似乎总算是梳理出了一点眉目,在先前自己写过的代码工程中做一下测试。先前实现过一个归并排序算法,算法函数的一个传入参数是指向一个比较功能函数的指针。当时进行代码实现的时

数据治理-数据存储和操作-数据架构类型

数据库可以分为集中式数据库和分布式数据。集中式系统管理单一数据库,而分布式系统管理多个系统上的多个数据库。分布式系统组件可以根据组件系统的自治性分为两类:联邦的和非联邦的。集中式数据库集中式数据库将所有数据存放在一个地方的一套系统中,所有用户连接到这套系统进行数据访问。对某些访问受限的数据来说,集中化可能是理想的,但对

如何在Spring Boot中配置双数据源?

如何在SpringBoot中配置双数据源?背景双数据源优点技术用法添加依赖配置数据源创建实体类和存储库配置数据源和实体管理器配置事务管理器实现双数据源背景在许多应用程序中,可能会遇到需要连接多个数据库的情况。这些数据库可以是不同的类型,例如关系型数据库和NoSQL数据库,或者它们可以是相同类型但包含不同的数据。为了处理

Spring Boot 如何配置 Hikari 数据库连接池

目录一、SpringBoot介绍二、什么是数据库连接池三、Hikari介绍四、配置Hikari一、SpringBoot介绍SpringBoot是一个开源的Java框架,它简化了基于Spring的应用程序的开发和部署。它提供了一种快速、方便的方式来创建独立的、可扩展的、生产级别的Spring应用程序。SpringBoot

Matlab--微积分问题的计算机求解

目录1.单变量函数的极限问题1.1.公式例子1.2.对应例题12.多变量函数的极限问题3.函数导数的解析解4.多元函数的偏导数5.Jacobian函数6.Hessian矩阵7.隐函数的偏导8.不定积分问题的求解9.定积分的求解问题10.多重积分的问题求解1.单变量函数的极限问题1.1.公式例子%%%3.1.1.单变量函

Springboot 实践(21)服务熔断机制

在微服务架构中,服务众多,通常会涉及到多个服务层的调用,一旦基础服务发生故障,很可能会导致级联故障,继而造成整个系统不可用,这种现象被称为服务雪崩效应。服务熔断引入熔断器概念,熔断器如果在一段时间内侦测到许多类似错误,就会强迫其以后的多个调用快速失败,不在访问远程服务器,从而防止应用程序不断地尝试执行可能会失败的操作,

热文推荐