4-3 nn.functional和nn.Module

2023-09-14 21:09:57

一,nn.functional 和 nn.Module

前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API。利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数,模型层,损失函数)。
其实:Pytorch和神经网络相关的功能组件大多都封装在** torch.nn **模块下。
这些功能组件的绝大部分既有函数形式实现,也有类形式实现。
其中nn.functional(一般引入后改名为F)有各种功能组件的函数实现。例如:
激活函数:
F.relu
F.sigmoid
F.tanh
F.softmax
模型层:
F.linear
F.conv2d
F.max_pool2d
F.dropout2d
F.embedding
损失函数:
F.binary_cross_entropy
F.mse_loss
F.cross_entropy
为了便于对参数进行管理,一般通过继承 nn.Module 转换成为类的实现形式,并直接封装在 nn 模块下。例如:
激活函数:
nn.ReLU
nn.Sigmoid
nn.Tanh
nn.Softmax
模型层:
nn.Linear
nn.Conv2d
nn.MaxPool2d
nn.Dropout2d
nn.Embedding
损失函数:
nn.BCELoss
nn.MSELoss
nn.CrossEntropyLoss
实际上nn.Module除了可以管理其引用的各种参数,还可以管理其引用的子模块,功能十分强大。
简单举例:
image.png

二,使用nn.Module来管理参数(配合nn.Parameter使用)

在Pytorch中,模型的参数是需要被优化器训练的,因此,通常要设置参数为 requires_grad = True 的张量。
同时,在一个模型中,往往有许多的参数,要手动管理这些参数并不是一件容易的事情。
Pytorch一般将参数用nn.Parameter来表示,并且用nn.Module来管理其结构下的所有参数。

requires_grad = True

手动设置:
image.png
nn.Parameter 具有 requires_grad = True 属性:
image.png

nn.ParameterList

列表形式
image.png

nn.ParameterDict

字典形式
image.png

Module管理

image.png
image.png

三、nn.Module构建模块类

实践当中,一般通过继承nn.Module来构建模块类,并将所有含有需要学习的参数的部分放在构造函数中。
以下范例为Pytorch中nn.Linear的源码的简化版本
可以看到它将需要学习的参数放在了__init__构造函数中,并在forward中调用F.linear函数来实现计算逻辑。

class Linear(nn.Module):
    __constants__ = ['in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

四、使用nn.Module来管理子模块

一般情况下,我们都很少直接使用 nn.Parameter来定义参数构建模型,而是通过拼装一些常用的模型层来构造模型。
这些模型层也是继承自nn.Module的对象,本身也包括参数,属于我们要定义的模块的子模块。
nn.Module提供了一些方法可以管理这些子模块。
children() 方法: 返回生成器,包括模块下的所有子模块。
named_children()方法:返回一个生成器,包括模块下的所有子模块,以及它们的名字。
modules()方法:返回一个生成器,包括模块下的所有各个层级的模块,包括模块本身
named_modules()方法:返回一个生成器,包括模块下的所有各个层级的模块以及它们的名字,包括模块本身。
其中chidren()方法和named_children()方法较多使用。
modules()方法和named_modules()方法较少使用,其功能可以通过多个named_children()的嵌套使用实现。

class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        
        self.embedding = nn.Embedding(num_embeddings = 10000,embedding_dim = 3,padding_idx = 1)
        self.conv = nn.Sequential()
        self.conv.add_module("conv_1",nn.Conv1d(in_channels = 3,out_channels = 16,kernel_size = 5))
        self.conv.add_module("pool_1",nn.MaxPool1d(kernel_size = 2))
        self.conv.add_module("relu_1",nn.ReLU())
        self.conv.add_module("conv_2",nn.Conv1d(in_channels = 16,out_channels = 128,kernel_size = 2))
        self.conv.add_module("pool_2",nn.MaxPool1d(kernel_size = 2))
        self.conv.add_module("relu_2",nn.ReLU())
        
        self.dense = nn.Sequential()
        self.dense.add_module("flatten",nn.Flatten())
        self.dense.add_module("linear",nn.Linear(6144,1))
        
    def forward(self,x):
        x = self.embedding(x).transpose(1,2)
        x = self.conv(x)
        y = self.dense(x)
        return y
    
net = Net()

children

image.png

named_children

image.png

modules

image.png
image.png

冻结参数

下面我们通过named_children方法找到embedding层,并将其参数设置为不可训练(相当于冻结embedding层)。
image.png
image.png
image.png

参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

更多推荐

7.2、如何理解Flink中的水位线(Watermark)

目录0、版本说明1、什么是水位线?2、水位线使用场景?3、设计水位线主要为了解决什么问题?4、怎样在flink中生成水位线?4.1、自定义标记Watermark生成器4.2、自定义周期性Watermark生成器4.3、内置Watermark生成器-有序流水位线生成器4.4、内置Watermark生成器-乱序流水位线生成

基于springboot广场舞团系统springboot16

大家好✌!我是CZ淡陌。一名专注以理论为基础实战为主的技术博主,将再这里为大家分享优质的实战项目,本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目,希望你能有所收获,少走一些弯路,向着优秀程序员前行!🍅更多优质项目👇🏻👇🏻可点击下方获取🍅文章底部或评论区获取🍅Java项目精品实

Linux安装mysql数据库并实现主从搭建

一.环境说明【环境说明】:192.168.110.161mysql-master##网络配置到位,防火墙关闭,selinux关闭192.168.110.162mysql-slave##网络配置到位,防火墙关闭,selinux关闭两台主机,操作系统是centos7,提前网络配置好,关闭防火墙,selinux,修改主机名二

esbuild中文文档-路径解析配置项(Path resolution - External、Main fields)

文章目录路径解析配置项Pathresolution外部模块External主字段Mainfields对于包的开发者结语哈喽,大家好!我是「励志前端小黑哥」,我带着最新发布的文章又来了!老规矩,小手动起来~点赞关注不迷路!esbuild简单介绍esbuild为了突破了JavaScript语言的瓶颈,采用了Go语言编写,构

Vue 如何监听 localstorage的变化

一.说明在日常开发中,我们经常使用localStorage来存储一些变量。这些变量会存储在浏览中。对于localStorage来说,即使关闭浏览器,这些变量依然存储着,方便我们开发的时候在别的地方使用。二.localStorage的使用存储值:window.localStorage.setItem(‘键名’,值)读取值

紫禁之巅-Unity游戏开发教程:勇者斗恶龙之魔法石

说明开设了一个unity游戏开发课程,可以帮助对游戏开发有兴趣的小伙伴学习Unity游戏开发的知识和技术,课程地址第一节课的课件是游戏工程,第二节的课件是大纲,和文章内相同,其它章节的课件和第一节课的相同,不需要重复下载课程大纲课程简介开设课程是为了帮助对游戏开发感兴趣的小伙伴掌握游戏开发的思路、方法、技术。为了帮助学

ps打开找不到MSVCP140.dll重新安装方法,安装ps出现msvcp140.dll缺失解决方法

在计算机中,我们可能会遇到许多问题,其中之一就是找不到msvcp140.dll文件。msvcp140.dll是一个动态链接库文件,它是MicrosoftVisualC++2015Redistributable的一部分。当计算机找不到这个文件时,可能会导致程序无法正常运行。本文将为您提供多个解决方法,以及msvcp140

每日一练 | 华为认证真题练习Day114

1、如图所示,交换机GE0/0/1和GE0/0/2两个端口都进行不同的此Hybrid配置,下面说法正确是()。(多选)A.财务部门发出的数据帧在交换机中携带的Tag为VLAN20B.行政部门和财务部门不能互访,因为两部门所属的VLAN不相同C.如果交换机的GE0/0/1和GE0/0/2两个端口都修改为Trunk端口,则

头歌平台 | 逻辑函数及其描述工具logisim使用

文章目录1、根据布尔表达式绘制电路2、根据真值表绘制电路3、根据简化真值表绘制电路4、根据波形图绘制电路5、根据卡诺图绘制电路1、根据布尔表达式绘制电路任务描述本关任务:在Logisim中根据给定的布尔代数表达式(F=AB+BC+CA)绘制逻辑电路。案例场景举例举重比赛裁判电路。在举重比赛中,通常有三位裁判(A、B、C

【SQL数据分析 | 手把手教你做淘宝用户分析!】

SQL也能做分析?当然!常见的数据清洗,预处理,数据分类,数据筛选,分类汇总,以及数据透视等操作,用SQL一样可以实现(除了可视化,需要放到Excel里呈现)。SQL不仅可以从数据库中读取数据,还能通过不同的SQL函数语句直接返回所需要的结果,从而大大提高了自己在客户端应用程序中计算的效率。但是,这个过程需要很熟练掌握

TypeScript入门

目录一:语言特性二:TypeScript安装NPM安装TypeScript三:TypeScript基础语法第一个TypeScript程序四:TypeScript保留关键字空白和换行TypeScript区分大小写TypeScript注释TypeScript支持两种类型的注释五:TypeScript与面向对象六:TypeS

热文推荐