深度学习归一化原理及代码实现(BatchNorm2d,LayerNorm,InstanceNorm,GroupNorm)

2023-09-18 19:27:25

概述

本文记录总结pytorch中四种归一化方式的原理以及实现方式。方便后续理解和使用。
本文原理理解参考自

https://zhuanlan.zhihu.com/p/395855181

形式

四种归一化的公式都是相同的,即
在这里插入图片描述
其实就是普通的归一化公式,

((x-均值)/标准差)*γ +β

γ和β是可学习参数,代表着对整体归一化值的缩放(scale) γ和偏移(shift) β。

四种不同形式的归一化归根结底还是归一化维度的不同。

形式原始维度均值/方差的维度
BatchNorm2dNCHW1C11
LayerNormNCHWN111
InstanceNormNCHWNC11
GroupNormNCHWNG11 (G=1,LN,G=C,IN)

原理理解

在这里插入图片描述

  1. BatchNorm2d 从维度上分析,就是在NHW维度上分别进行归一化,保留特征图的通道尺寸大小进行的归一化。
    由上图理解,蓝色位置代表一个归一化的值,BN层的目的就是将每个batch的hw都归一化,而保持通道数不变。抽象的理解就是结合不同batch的通道特征。因此这种方式比较适合用于分类,检测等模型,因为他需要对多个不同的图像有着相同的理解。
  2. LayerNorm 从维度上分析,就是在CHW上对对象的归一化,该归一化的目的可以保留每个batch的自有特征。抽象上来理解,就是通过layernorm让每个batch都有不同的值,有不同的特征,因此适用于图像生成或RNN之类的工作
  3. InstanceNorm从维度上来分析,就是将HW归一化为一个值,保留在通道上C和batch上的特征N。相当于对每个batch每个通道做了归一化。可以保留原始图像的信号而不混杂,因此常用于风格迁移等工作。
  4. GroupNorm从维度上来分析,近似于IN和LN,但是就是在通道上可以分成若干组(G),当G代表权重通道时就变成了LN,当G代表单通道就变成了IN,我也不清楚为什么用这个,但是G通常好像设置为32.

源代码实现

结合以上理解,就可以从原理上实现pytorch中封装的四个归一化函数。如下所示。

1.BatchNorm2d

import torch
import torch.nn as nn

class CustomBatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,scale=1,shift=0):
        super(CustomBatchNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        # 可训练参数
        self.scale = scale
        self.shift = shift
        
        # 不可训练的运行时统计信息
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)
    
    def forward(self, x):
        # 计算输入张量的均值和方差
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        print("mean.shape",mean.shape)
        var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
        print("var.shape",var.shape)
        
        # 更新运行时统计信息 (Batch Normalization在训练和推理模式下的行为不同)
        if self.training:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.squeeze()
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze()
        
        # 归一化输入张量
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        
        # 应用 scale 和 shift 参数
        scaled_x = self.scale.view(1, -1, 1, 1) * x_normalized + self.shift.view(1, -1, 1, 1)
        
        return scaled_x
if __name__ =="__main__":
    # 创建示例输入张量
    x = torch.randn(16, 3, 32, 32)  # 示例输入数据

    scale = nn.Parameter(torch.randn(x.size(1)))
    shift = nn.Parameter(torch.randn(x.size(1)))

    # 创建自定义批量归一化层
    custom_batchnorm = CustomBatchNorm2d(num_features=3,scale=scale,shift=shift)

    # 调用自定义批量归一化层
    normalized_x_custom = custom_batchnorm(x)

    # 创建官方的批量归一化层
    official_batchnorm = nn.BatchNorm2d(num_features=3)
    official_batchnorm.weight=scale
    official_batchnorm.bias=shift

    # 调用官方批量归一化层
    normalized_x_official = official_batchnorm(x)

    # 检查自定义层和官方层的输出是否一致
    are_equal = torch.allclose(normalized_x_custom, normalized_x_official, atol=1e-5)
    print("自定义批量归一化和官方批量归一化是否一致:", are_equal)


2.LayerNorm

import torch
import torch.nn as nn

class CustomLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5,scale=1,shift=0):
        super(CustomLayerNorm, self).__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        
        # 可训练参数
        # self.scale = nn.Parameter(torch.ones(normalized_shape))
        # self.shift = nn.Parameter(torch.zeros(normalized_shape))
        self.scale = scale
        self.shift = shift
    
    def forward(self, x):
        # 计算输入张量 x 的均值和方差
        mean = x.mean(dim=(1,2,3), keepdim=True)
        variance = x.var(dim=(1,2,3), unbiased=False, keepdim=True)
        
        # 归一化输入张量
        x_normalized = (x - mean) / torch.sqrt(variance + self.eps)
        
        # # 应用 scale 和 shift 参数
        
        scaled_x = self.scale * x_normalized + self.shift
        # 应用 scale 和 shift 参数
        #scaled_x = self.scale.view(-1, 1, 1, 1) * x_normalized + self.shift.view(-1, 1, 1, 1)
        
        return scaled_x

# 创建示例输入张量
x = torch.randn(16, 3, 32, 32)  # 示例输入数据

scale = nn.Parameter(torch.randn(3,32,32))
shift = nn.Parameter(torch.randn(3,32,32))

# 创建自定义 Layer Normalization 层
#custom_layernorm = CustomLayerNorm(normalized_shape=16)
custom_layernorm = CustomLayerNorm(normalized_shape=(3,32,32),scale=scale,shift=shift)

# 调用自定义 Layer Normalization 层
normalized_x_custom = custom_layernorm(x)

# 创建官方的 Layer Normalization 层
#official_layernorm = nn.LayerNorm(normalized_shape=3)
official_layernorm = nn.LayerNorm(normalized_shape=(3,32,32))
official_layernorm.weight=scale
official_layernorm.bias=shift
#official_layernorm = nn.LayerNorm(normalized_shape=(0,2,3))

# 调用官方 Layer Normalization 层
normalized_x_official = official_layernorm(x)
#print(normalized_x_official.shape)

# # 检查自定义层和官方层的输出是否一致
are_equal = torch.allclose(normalized_x_custom, normalized_x_official, atol=1e-5)
print("自定义 Layer Normalization 和官方 Layer Normalization 是否一致:", are_equal)

3.InstanceNorm

import torch
import torch.nn as nn

class CustomInstanceNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5,scale=1,shift=0):
        super(CustomInstanceNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        
        # 不可训练参数
        # self.scale = nn.Parameter(torch.ones(num_features))
        # self.shift = nn.Parameter(torch.zeros(num_features))
        self.scale = scale
        self.shift = shift
    
    def forward(self, x):
        # 计算输入张量 x 的均值和方差
        mean = x.mean(dim=(2, 3), keepdim=True)
        variance = x.var(dim=(2, 3), unbiased=False, keepdim=True)
        
        # 归一化输入张量
        x_normalized = (x - mean) / torch.sqrt(variance + self.eps)
        
        # 应用 scale 和 shift 参数
        scaled_x = self.scale.view(1, -1, 1, 1) * x_normalized + self.shift.view(1, -1, 1, 1)
        
        return scaled_x

# 创建示例输入张量
x = torch.randn(16, 3, 32, 32)  # 示例输入数据

# 创建自定义 Instance Normalization 层
scale = nn.Parameter(torch.randn(3))
shift = nn.Parameter(torch.randn(3))
custom_instancenorm = CustomInstanceNorm(num_features=3,scale=scale,shift=shift)

# 调用自定义 Instance Normalization 层
normalized_x_custom = custom_instancenorm(x)

# 创建官方的 Instance Normalization 层
official_instancenorm = nn.InstanceNorm2d(num_features=3)
official_instancenorm.weight=scale
official_instancenorm.bias=shift

# 调用官方 Instance Normalization 层
normalized_x_official = official_instancenorm(x)

# # 检查自定义层和官方层的输出是否一致
are_equal = torch.allclose(normalized_x_custom, normalized_x_official, atol=1e-5)
print("自定义 Layer Normalization 和官方 Layer Normalization 是否一致:", are_equal)



4.GroupNorm

import torch
import torch.nn as nn

class CustomGroupNorm(nn.Module):
    def __init__(self, num_groups, num_channels, eps=1e-5,scale=1,shift=0):
        super(CustomGroupNorm, self).__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        
        # 不可训练参数
        self.scale = scale
        self.shift = shift
    
    def forward(self, x):
        # 将输入张量 x 分成 num_groups 个组
        # 注意:这里假定 num_channels 可以被 num_groups 整除
        group_size = self.num_channels // self.num_groups
        x = x.view(-1, self.num_groups, group_size, x.size(2), x.size(3))
        
        # 计算每个组的均值和方差
        mean = x.mean(dim=(2, 3, 4), keepdim=True)
        variance = x.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
        
        # 归一化输入张量
        x_normalized = (x - mean) / torch.sqrt(variance + self.eps)
        
        # 将组合并并应用 scale 和 shift 参数
        x_normalized = x_normalized.view(-1, self.num_channels, x.size(3), x.size(4))
        scaled_x = self.scale.view(1, -1, 1, 1) * x_normalized + self.shift.view(1, -1, 1, 1)
        
        return scaled_x

# 创建示例输入张量
x = torch.randn(16, 6, 32, 32)  # 示例输入数据,有6个通道

# 创建自定义 Group Normalization 层
scale = nn.Parameter(torch.randn(6))
shift = nn.Parameter(torch.randn(6))
custom_groupnorm = CustomGroupNorm(num_groups=3, num_channels=6,scale=scale,shift=shift)

# 调用自定义 Group Normalization 层
normalized_x_custom = custom_groupnorm(x)

# 创建官方的 Group Normalization 层
official_groupnorm = nn.GroupNorm(num_groups=3, num_channels=6)
official_groupnorm.weight = scale
official_groupnorm.bias = shift

# 调用官方 Group Normalization 层
normalized_x_official = official_groupnorm(x)

# # 检查自定义层和官方层的输出是否一致
are_equal = torch.allclose(normalized_x_custom, normalized_x_official, atol=1e-5)
print("自定义 Layer Normalization 和官方 Layer Normalization 是否一致:", are_equal)

如果有用帮忙点个赞哦

更多推荐

webpack、vue.config.js

一、webpack学习简述webpack是一个静态资源打包工具,它会以一个或多个文件作为打包的入口,将我们整个项目的文件编译组合成一个或多个文件输出出去。输出的文件就是编译好的文件,可以运行在浏览器中。一般的我们将webpack输出的文件叫做bundle为什么需要打包工具随着现在前端技术的发展,我们会使用各种框架(Vu

化工DCS/SIS/MIS系统时钟同步(NTP服务器)建设

化工DCS/SIS/MIS系统时钟同步(NTP服务器)建设化工DCS/SIS/MIS系统时钟同步(NTP服务器)建设目前计算机网络中各主机和服务器等网络设备的时间基本处于无序的状态。随着计算机网络应用的不断涌现,计算机的时间同步问题成为愈来愈重要的事情。以Unix系统为例,时间的准确性几乎影响到所有的文件操作。如果一台

python经典百题之判断回文数

题目:一个5位数,判断它是不是回文数即12321是回文数,个位与万位相同,十位与千位相同程序分析回文数是指一个数从左向右和从右向左读是一样的,例如:12321。我们需要编写一个程序来判断一个5位数是否是回文数。方法1:转换成字符串defis_palindrome(num):num_str=str(num)returnn

2023年海南省职业院校技能大赛(高职组)信息安全管理与评估赛项规程

2023年海南省职业院校技能大赛(高职组)信息安全管理与评估赛项规程一、赛项名称赛项名称:信息安全管理与评估英文名称:InformationSecurityManagementandEvaluation赛项组别:高等职业教育赛项归属产业:电子与信息大类二、竞赛目标为全面贯彻落实国家网络强国战略,对接新一代信息技术产业,

java中mysql事务嵌套回滚

在Java开发中,MySQL事务嵌套回滚时经常会遇到。本文将介绍如何在Java中处理MySQL事务嵌套回滚的问题。在开始之前,我们需要先了解什么是事务嵌套回滚。当在一个事务中嵌套了其他事务并且其中一个事务回滚时,该事务及其所有嵌套的事务都会被回滚。这可以保持数据的一致性。但是,重要的是,要正确处理异常和回滚。下面是Ja

良好的测试环境应该怎么搭建?对软件产品起到什么作用?

为了确保软件产品的高质量,搭建一个良好的测试环境是至关重要的。在本文中,我们将从多个角度出发,详细描述良好的测试环境的搭建方法、注意事项以及对软件产品的作用。一、软件测试环境的搭建1、从硬件设备的选择与配置开始。对于大型软件产品的测试,建议使用高性能的服务器以及分布式测试平台。在选择服务器时,要考虑产品的特性、测试需求

【结构型】享元模式(Flyweight)

目录享元模式(Flyweight)适用场景享元模式实例代码(Java)享元模式(Flyweight)运用共享技术有效地支持大量细粒度的对象。(业务模型的对象进行细分得到科学合理的更多对象)适用场景一个应用程序使用了大量的对象。完全由于使用大量的对象,造成很大的存储开销。对象的大多数状态都可变为外部状态。如果删除对象的外

【2023集创赛】加速科技杯作品:高光响应的二硫化铼光电探测器

本文为2023年第七届全国大学生集成电路创新创业大赛(“集创赛”)加速科技杯西北赛区二等奖作品分享,参加极术社区的【有奖征集】分享你的2023集创赛作品,秀出作品风采,分享2023集创赛作品扩大影响力,更有丰富电子礼品等你来领!团队介绍参赛单位:西北工业大学队伍名称:噜啦噜啦咧指导老师:李伟参赛队员:程琳,韩笑,尹天乐

苹果手机怎么录屏?1分钟轻松搞定

虽然一直使用苹果手机,但是对它的录屏功能还不是很会使用。苹果手机怎么录屏?录屏可以录制声音吗?麻烦大家教教我!苹果手机为用户提供了十分便捷的内置录屏功能,可以让您随时随地录制手机上的内容。但是很多小伙伴在第一次使用苹果手机时,找不到苹果手机的录屏工具在哪,所以不知道该如何进行录屏。那么,苹果手机怎么录屏呢?下面将给大家

python+nodejs+php+springboot+vue 学生选课程作业提交教学辅助管理系统

二、项目设计目标与原则1、关于课程作业管理系统的基本要求(1)功能要求:可以管理首页、个人中心、公告信息管理、班级管理、学生管理、教师管理、课程类型管理、课程信息管理、学生选课管理、作业布置管理、作业提交管理、作业评分管理、课程评价管理、课程资源管理等功能模块。(2)性能:在不同操作系统上均能无差错实现在不同类型的用户

[答疑]角色和状态的区别

DDD领域驱动设计批评文集“软件方法建模师”不再考查基础题《软件方法》各章合集jeri2023-9-1013:09设备关联角色,设备也有子类(车辆/设备),按书中的解释,设备是一个抽象类,角色类名像是带了状态名的类,如在使用的设备/在维护的设备,设备和这几个角色是关联关系,而且是0.1的关系,潘老师的观点是泛化关系还是

热文推荐