PyTorch深度学习(一)【线性模型、梯度下降、随机梯度下降】

2023-09-15 19:43:16

这个系列是实战(刘二大人讲的pytorch)

建议把代码copy下来放在编译器查看(因为很多备注在注释里面)

线性模型(Linear Model):

import numpy as npimport matplotlib.pyplot as plt  #绘图的包x_data = [1.0, 2.0, 3.0]  #这两行代表数据集,一般x_data,y_data是要把它分开保存的,x表示输入样本y_data = [2.0, 4.0, 6.0]  #相同的索引表示一组样本,就是(1.0,2.0)表示一对样本,(2.0,4.0)表示一对样本。def forward(x):   #定义模型,取名“前馈模型”return x * w  #用x与w相乘返回(Linear Model)def loss(x, y):   #定义损失函数    y_pred = forward(x)return (y_pred - y) ** 2# 穷举法w_list = []       #权重值mse_list = []     #对应权重损失值for w in np.arange(0.0, 4.1, 0.1):  #权重间隔为0.1,从0.0开始取,4.0结束。[0.0,0.1,...,4.0]    print("w=", w)    l_sum = 0for x_val, y_val in zip(x_data, y_data):  #把x_data,y_data这两个列表里边的数据拿出来用zip拼成真实数据的x,y值        y_pred_val = forward(x_val)           #首先计算预测(可以不计算,主要只是打印一下结果看一下);y_pred_val是预测值,loss函数计算会用到        loss_val = loss(x_val, y_val)         #计算损失        l_sum += loss_val                     #损失求和        print('\t', x_val, y_val, y_pred_val, loss_val)    print('MSE=', l_sum / 3)    w_list.append(w)    mse_list.append(l_sum / 3)                #除以3,转成mse(均方误差)plt.plot(w_list, mse_list)        #绘图plt.ylabel('Loss')plt.xlabel('w')plt.show()

使用到的损失函数如下:

"y_pred"就是在求

zip函数:

【番外:可视化常用到的工具---visdom】

练习:实现线性模型(y=wx+b)并输出loss的3D图像。

import numpy as npimport matplotlib.pyplot as pltfrom mpl_toolkits.mplot3d import Axes3D#这里设函数为y=3x+2x_data = [1.0,2.0,3.0]y_data = [5.0,8.0,11.0]def forward(x):return x * w + bdef loss(x,y):y_pred = forward(x)return (y_pred-y)*(y_pred-y)mse_list = []W=np.arange(0.0,4.1,0.1)B=np.arange(0.0,4.1,0.1)[w,b]=np.meshgrid(W,B)l_sum = 0for x_val, y_val in zip(x_data, y_data):y_pred_val = forward(x_val)print(y_pred_val)loss_val = loss(x_val, y_val)l_sum += loss_valfig = plt.figure()ax = Axes3D(fig)fig.add_axes(ax)ax.plot_surface(w, b, l_sum/3)plt.show(block=True)

梯度下降

(鞍点:梯度为0。陷入鞍点没办法迭代)

cost公式:(cost function是对所有的样本)

代码:

import matplotlib.pyplot as plt# 准备训练集数据x_data = [1.0, 2.0, 3.0]  #两个列表,分别表示x和y的值,(1.0,2.0)表示第一条数据样本y_data = [2.0, 4.0, 6.0]  #(2.0,4.0)表示第二条数据样本# initial guess of weightw = 1.0   #初始权重猜测# define the model linear model y = w*xdef forward(x):   #定义前馈计算return x * w    #y^# define the cost function MSEdef cost(xs, ys):   #把所有的数据都拿进来    cost = 0for x, y in zip(xs, ys):        y_pred = forward(x)    #算y^        cost += (y_pred - y) ** 2return cost / len(xs)      #MSE(平均损失的计算)# define the gradient function  gddef gradient(xs, ys):   #求梯度    grad = 0for x, y in zip(xs, ys):        grad += 2 * x * (x * w - y)return grad / len(xs)epoch_list = []cost_list = []print('predict (before training)', 4, forward(4))for epoch in range(100):    #训练过程(100轮)    cost_val = cost(x_data, y_data)  #计算当前损失值,也就是cost    grad_val = gradient(x_data, y_data)   #求梯度    w -= 0.01 * grad_val  # 0.01 learning rate    #学习率*梯度    print('epoch:', epoch, 'w=', w, 'loss=', cost_val)    epoch_list.append(epoch)    cost_list.append(cost_val)print('predict (after training)', 4, forward(4))plt.plot(epoch_list, cost_list)plt.ylabel('cost')plt.xlabel('epoch')plt.show()

【番外:“指数加权均值”方法能够将cost变得更平滑】

在大多数的情况下,得到的loss图像形状趋势都是如上图所示,如果出现右边有又上去了的情况,则说明训练发散了,这次训练失败了。

训练失败的情况有很多,其中最常见的是:学习率取得太大。(可以将学习率调小再看看效果)

随机梯度下降

(只用一个样本,即使陷入了鞍点,也也有可能跨过这个鞍点向前推进找最优点)

公式:(单个样本的损失函数对权重求导,然后进行更新)

代码:

import matplotlib.pyplot as pltx_data = [1.0, 2.0, 3.0]y_data = [2.0, 4.0, 6.0]w = 1.0def forward(x):return x * w# calculate loss functiondef loss(x, y):    y_pred = forward(x)   #y^return (y_pred - y) ** 2    #loss# define the gradient function  sgddef gradient(x, y):return 2 * x * (x * w - y)      #梯度epoch_list = []loss_list = []print('predict (before training)', 4, forward(4))for epoch in range(100):for x, y in zip(x_data, y_data):        grad = gradient(x, y)    #对每一个样本求梯度,loss对w求梯度        w = w - 0.01 * grad  # update weight by every grad of sample of training set    更新        print("\tgrad:", x, y, grad)        l = loss(x, y)    #计算现在的损失    print("progress:", epoch, "w=", w, "loss=", l)    epoch_list.append(epoch)    loss_list.append(l)print('predict (after training)', 4, forward(4))plt.plot(epoch_list, loss_list)plt.ylabel('loss')plt.xlabel('epoch')plt.show()

性能好,但时间复杂度太高,没有并行性。

【番外:Batch。(性能和时间复杂度上取折中)批量的随机梯度下降法。

就是说如果你全都丢到一起,性能不好;全都分开呢,时间复杂度不好。

因此可以若干个分为一组,每次用这一租样本去求相应的梯度,然后进行更新。这个就叫做Batch。】

更多推荐

代码签名证书品牌哪家好?选微软推荐机构

代码签名证书是保护软件代码完整性及来源可信的重要方式,软件程序要在操作系统中运行,就需要使用权威合规的代码签名证书,对软件代码进行数字签名,确保软件来源可信、未被非法篡改,消除操作系统“未知发布者”警告,让软件能够顺畅运行。众多代码签名证书厂商中,哪些厂商提供的代码签名证书才是获得操作系统信任的呢?本文将为大家介绍,如

【JVM】类加载器

类与类加载器类加载器虽然只用于实现类的加载动作,但它在Java程序中起到的作用却远超类加载阶段。对于任意一个类,都必须由加载它的类加载器和这个类本身一起共同确立其在Java虚拟机中的唯一性,每一个类加载器,都拥有一个独立的类名称空间。这句话可以表达得更通俗一些:比较两个类是否“相等”,只有在这两个类是由同一个类加载器加

springboot实战(七)之jackson配置前后端交互下划线转驼峰&对象序列化与反序列化

目录环境:1.驼峰转下划线配置1.1单个字段命名转化使用@JsonProperty注解1.2单个类进行命名转化使用@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)注解3.全局命名策略配置2.序列化以及反序列化2.1序列化2.2反序列化3.自定义序

ref和reactive区别

使用区别reactive定义引用数据类型,ref定义基本类型reactive定义的变量直接使用,ref定义的变量使用时需要.value模板中均可直接使用,vue帮我们判断了是reactive还是ref定义的(通过__v_isRef属性),从而自动添加了.value。//定义letcount=ref(0);letobj=

好用的记笔记app选哪个?

当你在日常生活中突然获得了一个灵感,或者需要记录会议的重要内容,或者是学校课堂上的笔记,你通常会拿出手机,因为它总是在你身边,随时可用。这时候,一款好的记笔记App可以让你事半功倍。敬业签是一款全面的云端备忘记事软件,支持在Windows/Web/Android/iOS/Mac/HarmonyOS等端口同步和编辑记事内

机器学习技术(十)——决策树算法实操,基于运营商过往数据对用户离网情况进行预测

机器学习技术(十)——决策树算法实操文章目录机器学习技术(十)——决策树算法实操一、引言二、数据集介绍三、导入相关依赖库四、读取并查看数据1、读取数据2、查看数据五、数据预处理1、选择数据2、数据转码六、建模与参数优化1、训练模型2、评估模型3、调参优化七、模型可视化八、决策树实操总结一、引言决策树部分主要包含基于py

django_model_一对一映射

settings相关配置#settings.py...DATABASES={'default':{'ENGINE':'django.db.backends.mysql','NAME':'djangos','USER':'root','PASSWORD':'990212','HOST':'localhost','PORT

【TCP】滑动窗口、流量控制 以及拥塞控制

滑动窗口、流量控制以及拥塞控制1.滑动窗口(效率机制)2.流量控制(安全机制)3.拥塞控制(安全机制)1.滑动窗口(效率机制)TCP使用确认应答策略,对每一个发送的数据段,都要给一个ACK确认应答。收到ACK后再发送下一个数据段。这样做有一个比较大的缺点,就是性能较差。尤其是数据往返的时间较长的时候。既然这样一发一收的

大厂超全安全测试--关于安全测试的分类及如何测试

安全测试(总结)1.jsonNP劫持(其实json劫持和jsonNP劫持属于CSRF跨站请求伪造)的攻击范畴,解决方法和CSRF一样定义:构造带有jsonp接口的恶意页面发给用户点击,从而将用户的敏感信息通过jsonp接口传输到攻击者服务器json语法规则:数据在名称/值对中、数据由逗号分隔、花括号保存对象、方括号保存

循环神经网络——中篇【深度学习】【PyTorch】【d2l】

文章目录6、循环神经网络6.4、循环神经网络(`RNN`)6.4.1、理论部分6.4.2、代码实现6.5、长短期记忆网络(`LSTM`)6.5.1、理论部分6.5.2、代码实现6.6、门控循环单元(`GRU`)6.6.1、理论部分6.6.2、代码实现6、循环神经网络6.4、循环神经网络(RNN)6.4.1、理论部分原理

18.3 【Linux】登录文件的轮替(logrotate)

18.3.1logrotate的配置文件logrotate主要是针对登录文件来进行轮替的动作,他必须要记载“在什么状态下才将登录文件进行轮替”的设置。logrotate这个程序的参数配置文件在:/etc/logrotate.conf/etc/logrotate.d/logrotate.conf才是主要的参数文件,至于l

热文推荐