【深度学习实验】线性模型(二):使用NumPy实现线性模型:梯度下降法

2023-09-17 16:39:45

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入库

1. 初始化参数

2. 线性模型 linear_model

3. 损失函数loss_function

4. 梯度计算函数compute_gradients

5. 梯度下降函数gradient_descent

6. 调用函数


一、实验介绍

        使用NumPy实现线性模型:梯度下降法

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

         线性模型梯度下降法是一种常用的优化算法,用于求解线性回归模型中的参数。它通过迭代的方式不断更新模型参数,使得模型在训练数据上的损失函数逐渐减小,从而达到优化模型的目的。

        梯度下降法的基本思想是沿着损失函数梯度的反方向更新模型参数。在每次迭代中,根据当前的参数值计算损失函数的梯度,然后乘以一个学习率的因子,得到参数的更新量。学习率决定了参数更新的步长,过大的学习率可能导致错过最优解,而过小的学习率则会导致收敛速度过慢。

具体而言,对于线性回归模型,梯度下降法的步骤如下:

  1. 初始化模型参数,可以随机初始化或者使用一些启发式的方法。

  2. 循环迭代以下步骤,直到满足停止条件(如达到最大迭代次数或损失函数变化小于某个阈值):

    a. 根据当前的参数值计算模型的预测值。

    b. 计算损失函数关于参数的梯度,即对每个参数求偏导数。

    c. 根据梯度和学习率更新参数值。

    d. 计算新的损失函数值,并检查是否满足停止条件。

  3. 返回优化后的模型参数。

       本实验中,gradient_descent函数实现了梯度下降法的具体过程。它通过调用initialize_parameters函数初始化模型参数,然后在每次迭代中计算模型预测值、梯度以及更新参数值。

0. 导入库

import numpy as np

1. 初始化参数

        在梯度下降算法中,需要初始化待优化的参数,即权重 w 和偏置 b。可以使用随机初始化的方式。

def initialize_parameters():
    w = np.random.randn(5)
    b = np.random.randn(5)
    return w, b

2. 线性模型 linear_model

def linear_model(x, w, b):
    output = np.dot(x, w) + b
    return output

3. 损失函数loss_function

         该函数接受目标值y和模型预测值prediction,计算均方误差损失。

def loss_function(y, prediction):
    loss = (prediction - y) * (prediction - y)
    return loss

4. 梯度计算函数compute_gradients

        为了使用梯度下降算法,需要计算损失函数关于参数 w 和 b 的梯度。可以使用数值计算的方法来近似计算梯度。

def compute_gradients(x, y, w, b):
    h = 1e-6  # 微小的数值,用于近似计算梯度
    grad_w = (loss_function(y, linear_model(x, w + h, b)) - loss_function(y, linear_model(x, w - h, b))) / (2 * h)
    grad_b = (loss_function(y, linear_model(x, w, b + h)) - loss_function(y, linear_model(x, w, b - h))) / (2 * h)
    return grad_w, grad_b

5. 梯度下降函数gradient_descent

        根据梯度计算的结果更新参数 w 和 b,从而最小化损失函数。

def gradient_descent(x, y, learning_rate, num_iterations):
    w, b = initialize_parameters()
    for i in range(num_iterations):
        prediction = linear_model(x, w, b)
        grad_w, grad_b = compute_gradients(x, y, w, b)
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
        loss = loss_function(y, prediction)
        print("Iteration", i, "Loss:", loss)
    return w, b

6. 调用函数

        执行梯度下降优化:调用 gradient_descent 函数并传入数据 x 和 y,设置学习率和迭代次数进行优化。

x = np.random.rand(5)
y = np.array([1, -1, 1, -1, 1]).astype('float')
learning_rate = 0.1
num_iterations = 100
w_optimized, b_optimized = gradient_descent(x, y, learning_rate, num_iterations)

        在上述代码中,每一次迭代都会打印出当前迭代次数和对应的损失值。通过不断更新参数 w 和 b,使得损失函数逐渐减小,达到最小化损失函数的目的。

希望这个详细解析能够帮助你优化代码并使用梯度下降算法最小化损失函数。如果还有其他问题,请随时提问!

更多推荐

HelpLook全新升级!定制AI问答机器人,企业内容中心焕新

一直以来,企业都在努力解决内外部“企业知识管理”问题:从纸质手册发放,转线上电子文档传阅(pdf/ppt/word等),再到整理客户常见问题(FAQ)和内部知识库(wiki),但始终没有找到一套完整方案将“企业知识”很好地集中管理及分享查阅。持续困扰大家的⬇️❌要么是软件系统更新太困难、或搭建费用太昂贵❌要么是没人知道

MybatisX快速生成代码(mybatis plus模板)

文章目录1、概述2、基本使用2.1、插件安装2.2、集成数据库1、概述MybatisX是一款基于IDEA的快速开发插件,为效率而生。在开发过程中,相信大家都遇到过一个数据库内有着十几张或比之更多的数据表的情况。而面对这众多的数据表,实体类、服务类、服务实现类、Mapper接口及其对应的XML文件更是头大,这无疑是成倍增

IPV6真的神

ipv6地址短缺的现实,万物互联的未来<全局可达性>1、路由表更小。地址分配遵循聚类原则,路由表用Entry的路由表示一片子网。2、更强的组播以及流控制。为媒体服务质量QoS。控制提供了良好的网络平台。3、DHCPv6,自动配置地址。使得网(尤其是局域网)的管理更加方便和快捷。4、自带IPSec,端对端安全。在网络层的

通讯网关软件001——利用CommGate X2Access-U实现OPC UA数据转储Access

本文介绍利用CommGateX2ACCESS-U实现从OPCUAServer读取数据并同步转储至ACCESS数据库。CommGateX2ACCESS-U是宁波科安网信开发的网关软件,软件可以登录到网信智汇(http://wangxinzhihui.com)下载。【案例】如下图所示,实现从OPCUAServer实时读取数

linux vim操作汇总

汇总起来,备忘查看~目录1、复制复制一行包括换行符复制光标开始到行末的文本复制光标开始到行首的文本复制当前单词复制单行或多行到指定行后2、粘贴、剪贴3、移动4、删除删除整行删除光标所在行删除光标所在行开始的3行删除一行带复制(当前光标所在行)删除当前行开始的几行(包括当前行)删除到本行行首/行尾删除字符删除单词/符号5

探索Go语言在机器学习领域的应用局限与前景

🌷🍁博主猫头虎带您GotoNewWorld.✨🍁🦄博客首页——猫头虎的博客🎐🐳《面试题大全专栏》文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺🌊《IDEA开发秘籍专栏》学会IDEA常用操作,工作效率翻倍~💐🌊《100天精通Golang(基础入门篇)》学会Golang语言,畅玩云原生,走遍大

Pytest系列-数据驱动@pytest.mark.parametrize(7)

简介unittest和pytest参数化对比:pytest与unittest的一个重要区别就是参数化,unittest框架使用的第三方库ddt来参数化的而pytest框架:前置/后置处理函数fixture,它有个参数params专门与request结合使用来传递参数,也可以用parametrize结合request来传

【linux】paramiko介绍 + 路由器设置tc命令使用

背景:要给网络灵活的设置各种带宽限制,通过对路由器下发tc命令实现。设置python脚本的ssh链接+tc脚本下发+针对某一个id进行配置。Paramiko是一个用于在Python中进行SSH(SecureShell)协议通信的库。它提供了在远程服务器上执行命令、上传和下载文件、建立SSH连接等功能,使得开发者可以轻松

MySQL---优化&日志

目录一、MySQL优化3、mysqlserver上的优化3.1、MySQL查询缓存3.2、索引和数据缓存3.2、线程缓存二、MySQL日志2.1、redolog重做日志2.2、undolog回滚日志2.3、错误日志2.4、查询日志2.5、二进制日志2.5.1、基于binlog数据恢复实践操作六、慢查询日志一、MySQL

渗透测试信息收集方法和工具分享

文章目录一、域名收集1.OneForAll2.子域名挖掘机3.subdomainsBurte4.ssl证书查询二、获取真实ip1.17CE2.站长之家ping检测3.如何寻找真实IP4.纯真ip数据库工具5.c段,旁站查询三、端口扫描1.端口扫描站长工具2.masscan(全端口扫描)+nmap扫描3.scanport

科大讯飞分类算法挑战赛2023的一些经验总结

引言:ResNet是hekaiming大佬的早年神作,当年直接刷榜各大图像分类任务。ResNet是一种残差网络,咱们可以把它理解为一个子网络,这个子网络经过堆叠可以构成一个很深的网络,而ResNext在其基础上,进行了一定修改完善,通过引入Cardinatity后,模型性能得到了大幅度提升。(下图是经典ResNet残差

热文推荐