神经网络 05(损失函数)

2023-09-12 16:51:41

一、损失函数

在深度学习中, 损失函数是用来衡量模型参数的质量的函数, 衡量的方式是比较网络输出和真实输出的差异,损失函数在不同的文献中名称是不一样的,主要有以下几种命名方式:

损失函数 (loss function)
代价函数(cost function)
目标函数(objective function)
误差函数(error function) 

二、分类任务

在深度学习的分类任务中使用最多的是 交叉熵损失函数

2.1 多分类任务

在多分类任务通常使用softmax将logits转换为概率的形式,所以多分类的交叉熵损失也叫做softmax损失,它的计算方法是:

其中,y 是样本 x 属于某一个类别的真实概率(onehot编码,0或者1),而 f(x) 是样本属于某一类别的预测分数,S 是 softmax 函数,L 用来衡量 p,q 之间差异性的损失结果。

举例:

从概率角度理解,我们的目的是最小化正确类别所对应的预测概率的对数的负值,如下图所示:

在tf.keras中使用CategoricalCrossentropy实现,如下所示:

# 导入相应的包
import tensorflow as tf
# 设置真实值和预测值
y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]# 两个交叉熵求平均值
# 实例化交叉熵损失
cce = tf.keras.losses.CategoricalCrossentropy()
# 计算损失结果
print(cce(y_true, y_pred).numpy()) # 1.176939

2.2 二分类任务

在处理二分类任务时,我们不再使用 softmax 激活函数,而是使用 sigmoid 激活函数,那损失函数也相应的进行调整,使用二分类的交叉熵损失函数

跟逻辑回归的损失函数是一样的

其中,y是样本x属于某一个类别的真实概率,而y^是样本属于某一类别的预测概率,L用来衡量真实值与预测值之间差异性的损失结果。

在 tf.keras 中实现时使用 BinaryCrossentropy(),如下所示:

# 导入相应的包
import tensorflow as tf
# 设置真实值和预测值
y_true = [[0], [1]]
y_pred = [[0.4], [0.6]]
# 实例化二分类交叉熵损失
bce = tf.keras.losses.BinaryCrossentropy() # 两个交叉熵求平均值

# 计算损失结果
print(bce(y_true, y_pred).numpy()) # 0.5108254

三、回归任务

3.1 MAE损失

Mean absolute loss(MAE)也被称为 L1 Loss,是以绝对误差作为距离:

特点是:

由于 L1 loss 具有稀疏性,为了惩罚较大的值,因此常常 将其作为正则项添加到其他 loss中作为约束

L1 loss 的最大问题是梯度在零点不平滑(不可导),导致会跳过极小值。

在 tf.keras 中使用MeanAbsoluteError 实现,如下所示:

# 导入相应的包
import tensorflow as tf
# 设置真实值和预测值
y_true = [[0.], [0.]]
y_pred = [[1.], [1.]]
# 实例化MAE损失
mae = tf.keras.losses.MeanAbsoluteError()
# 计算损失结果
print(mae(y_true, y_pred).numpy()) # 1.0

3.2 MSE损失

 Mean Squared Loss/ Quadratic Loss(MSE loss) 也被称为 L2 loss,或欧氏距离,它以误差的平方和作为距离

特点是:L2 loss 也常常作为正则项。当预测值与目标值相差很大时, 梯度容易爆炸

在 tf.keras 中通过 MeanSquaredError 实现:

# 导入相应的包
import tensorflow as tf
# 设置真实值和预测值
y_true = [[0.], [1.]]
y_pred = [[1.], [1.]]
# 实例化MSE损失
mse = tf.keras.losses.MeanSquaredError()
# 计算损失结果
print(mse(y_true, y_pred).numpy()) # 0.5

3.3 smooth L1 损失

其中:𝑥=f(x)−y 为真实值和预测值的差值

从上图中可以看出,该函数实际上就是一个分段函数,在[-1,1]之间实际上就是L2损失,这样解决了L1的不光滑问题,在[-1,1]区间外,实际上就是L1损失,这样就解决了离群点梯度爆炸的问题。通常在目标检测中使用该损失函数。

在 tf.keras 中使用 Huber 计算该损失,如下所示:

# 导入相应的包
import tensorflow as tf
# 设置真实值和预测值
y_true = [[0], [1]]
y_pred = [[0.6], [0.4]]
# 实例化smooth L1损失
h = tf.keras.losses.Huber()
# 计算损失结果
h(y_true, y_pred).numpy() # 0.18

更多推荐

【算法】相向双指针

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。推荐:kuan的首页,持续学习,不断总结,共同进步,活到老学到老导航檀越剑指大厂系列:全面总结java核心技术点,如集合,jvm,并发编程redis,kaf

NSSCTF web 刷题记录2

文章目录前言题目[广东强网杯2021团队组]love_Pokemon[NCTF2018]Easy_Audit[安洵杯2019]easy_web[NCTF2018]全球最大交友网站prize_p2[羊城杯2020]easyser[FBCTF2019]rceservice方法一方法二[WUSTCTF2020]颜值成绩查询前

SpringCLoud——服务的拆分和远程调用

服务拆分服务拆分注意事项一般是根据功能的不同,将不同的服务按照功能的不同而分开。微服务拆分注意事项不同微服务,不要重复开发相同业务微服务数据独立,不要访问其他微服务的数据库微服务可以将自己的业务暴露为接口,供其他微服务调用远程调用对于远程调用,之前我们说过,微服务之所以不能像单个服务那样互相的调用各自服务的信息,是因为

REC 系列 Visual Grounding with Transformers 论文阅读笔记

REC系列VisualGroundingwithTransformers论文阅读笔记一、Abstract二、引言三、相关工作3.1视觉定位3.2视觉Transformer四、方法4.1基础的视觉和文本编码器4.2定位编码器自注意力的文本分支文本引导自注意力的视觉分支4.3定位解码器定位query自注意力编码器-解码器自

告别加班烦恼:前端表格构建公式树的绝妙应用!

还在为满屏的公式而“内牛满面”吗?还在为长串的公式解析而发愁吗?还在定位错误的公式而苦恼吗?上班要写代码,加班还要分析这又长又臭的公式。你的发际线还好吗?本葡萄来拯救你的发际线啦!带来的不是洗发水,而是公式追踪!这一章,让本葡萄带你用前端电子表格的公式追踪构建公式树,快(解)速(救)分(你)析(的)公(脱)式(发)问题

【物联网】ROM、RAM和FLASH的区别

引言在计算机领域,我们经常听到ROM、FLASH和RAM这些术语,它们是计算机中不同类型的存储器。虽然它们都用于存储数据,但它们之间有着明显的区别。本文将详细介绍ROM、FLASH和RAM的区别,并给出具体的例子和解释。文章目录引言ROM(只读存储器)RAM(随机存储器)FLASH(闪存)总结ROM(只读存储器)ROM

【web开发】12、Django知识点回顾

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档文章目录总结、Django知识点回顾提示:以下是本篇文章正文内容,下面案例可供参考总结、Django知识点回顾安装Djangopipinstalldjango创建Django项目django-adminstartprojectmysite注意:Pychar

stm32---定时器输入捕获

一、输入捕获介绍在定时器中断实验章节中我们介绍了通用定时器具有多种功能,输入捕获就是其中一种。STM32F1除了基本定时器TIM6和TIM7,其他定时器都具有输入捕获功能。输入捕获可以对输入的信号的上升沿,下降沿或者双边沿进行捕获,通常用于测量输入信号的脉宽、测量PWM输入信号的频率及占空比。输入捕获的工作原理比较简单

新概念英语(第二册)复习——Lesson 6 - Lesson10

前言在学习6-10之前,确保1-5已经可以脱口而出,否则不需要学习6-10文章目录前言Lesson6-PercyButtons原文译文单词Lesson7-Toolate原文译文单词Lesson8-Thebestandtheworst原文译文单词Lesson9-Acoldwelcome译文单词Lesson10-NotFo

数组初学者向导:使用Python从零开始制作经典战舰游戏

引言战舰游戏,一个广受欢迎的经典游戏,为玩家提供了策略与猜测的完美结合。这个游戏的核心思想是通过猜测敌方船只的位置并尝试击沉它们来赢得比赛。在这篇文章中,我们将使用Python语言和数组来构建这款游戏,让你更加了解数组的操作和实用性。1.游戏概述在战舰游戏中,每位玩家都有一个10x10的网格,可以放置5艘船只。这些船只

C语言实现三子棋游戏(详解)

目录引言:1.游戏规则:2.实现步骤:2.1实现菜单:2.2创建棋盘并初始化:2.3绘制棋盘:2.4玩家落子:2.5电脑落子:2.6判断胜负:3.源码:结语:引言:《三子棋》是一款古老的民间传统游戏,又被称为OOXX棋、黑白棋,井字棋,九宫棋等。游戏分为双方对战,双方依次在九宫格棋盘上摆放棋子,率先将自己的三个棋子连成

热文推荐