第九章(1):循环神经网络与pytorch示例(RNN实现股价预测)

2023-07-13 20:13:44

第九章(1):循环神经网络与pytorch示例(RNN实现股价预测)


作者:安静到无声 个人主页

作者简介:人工智能和硬件设计博士生、CSDN与阿里云开发者博客专家,多项比赛获奖者,发表SCI论文多篇。

Thanks♪(・ω・)ノ 如果觉得文章不错或能帮助到你学习,可以点赞👍收藏📁评论📒+关注哦! o( ̄▽ ̄)d

欢迎大家来到安静到无声的 《基于pytorch的自然语言处理入门与实践》,如果对所写内容感兴趣请看《基于pytorch的自然语言处理入门与实践》系列讲解 - 总目录,同时这也可以作为大家学习的参考。欢迎订阅,请多多支持!

1. 概述

循环神经网络(Recurrent Neural Network,RNN)是一种基于神经网络的机器学习模型,主要用于处理序列数据。与传统的前馈神经网络不同,RNN引入了循环连接,使得模型能够捕捉到输入序列中的上下文信息和时间依赖关系。

假设给定一个序列, x 1 : T = ( x 1 , x 2 , … , x t , … , x T ) x_{1:T}=(x_{1},x_{2},\ldots,x_{t},\ldots,x_{T}) x1:T=(x1,x2,,xt,,xT),RNN神经网络通过下面公式更新带反馈边的隐藏层的活性值 h t h_t ht
h t = f ( h t − 1 , x t ) , \boldsymbol{h}_{t}=f(\boldsymbol{h}_{t-1},\boldsymbol{x}_{t}), ht=f(ht1,xt),

其中 h 0 = 0 {}_{\boldsymbol{h}_{0}=0} h0=0 f ( ⋅ ) f(\cdot) f()为一个非线性的函数,例如前馈神经网络。

下图给出了循环神经网络的示例,其中“延时器”为一个虚拟单元,记录神经元的最近一次或几次活性值。
在这里插入图片描述

从数学上讲,上文公式可以看成一个动力系统。隐藏层的活性值 h t h_t ht,在很多文献上也称为状态( State ) 或隐状态( Hidden State )。

2. 简单的循环神经网络

简单循环网络( Simple Recurrent Network,SRN)是一个非常简单的循环神经网络,只有一个隐藏层的神经网络. 在一个两层的前馈神经网络中,连接存在相邻的层与层之间,隐藏层的节点之间是无连接的。而简单循环网络增加了从隐藏层到隐藏层的反馈连接。

向量 x t ∈ R M {\boldsymbol{x}}_{t}\in\mathbb{R}^{M} xtRM表示在 t t t时刻的一个输入, h t ∈ R D h_t\in\mathbb{R}^D htRD表示一个隐藏层的状态,这时 h t h_t ht与当前时刻的 x t x_t xt有关系,而且也和上一时刻的 h t − 1 h_{t-1} ht1有关系,简单的循环神经网络在 t t t时刻的更新公式如下所示:
z t = U h t − 1 + W x t + b z_{t}=U\boldsymbol{h}_{t-1}+Wx_{t}+\boldsymbol{b} zt=Uht1+Wxt+b h t = f ( z t ) h_t=f(\boldsymbol{z}_t) ht=f(zt)其中 z t z_{t} zt为隐藏层的净输入, U ∈ R D × D U\in\mathbb{R}^{D\times D} URD×D为状态-状态的矩阵, W ∈ R D × M W\in\mathbb{R}^{D\times M} WRD×M为状态-输入的权重矩阵, b ∈ R D b\in\mathbb{R}^{D} bRD为偏置项, f ( ⋅ ) f(\cdot) f()为一个非线性的激活函数。上述公式可以和写为:
h t = f ( U h t − 1 + W x t + b ) . \boldsymbol{h}_{t}=f(\boldsymbol{Uh}_{t-1}+\boldsymbol{W}\boldsymbol{x}_{t}+\boldsymbol{b}). ht=f(Uht1+Wxt+b).

如果我们把每个时刻的状态都看作前馈神经网络的一层,循环神经网络可以看作在时间维度上权值共享的神经网络。下给出了按时间展开的循环神经网络。

在这里插入图片描述

3. RNN实现股价预测

建立RNN基于zgpa train.csv数据,建立RNN模型,预测股价。

  • 完成数据预处理,将序列数据转化为可用子RNN输入的数据
  • 对新数据zgpa_test.csv进行预测,可视化结果
    模型结构:RNN 输出有120个神经元,每次使用前8个数据预测第9个数据。
import pandas as pd
import numpy as np
data = pd.read_csv(r'./zgpa_train.csv')
data.head()

在这里插入图片描述

price = data.loc[:,'close']
price.head()
0    28.78
1    29.23
2    29.26
3    28.50
4    28.67
Name: close, dtype: float64
#归一化处理
price_norm = price/max(price)

%matplotlib inline
from matplotlib import pyplot as plt
fig1 = plt.figure(figsize=(8,5))
plt.plot(price)
plt.title('close price')
plt.xlabel('time')
plt.ylabel('price')
plt.show()

在这里插入图片描述

#define X and y
#define method to extract X and y
def extract_data(data,time_step):
    X = []
    y = []
    #0,1,2,3...9:10个样本;time_step=8;0,1...7;1,2...8;2,3...9三组(两组样本)
    for i in range(len(data)-time_step):
        X.append([a for a in data[i:i+time_step]])
        y.append(data[i+time_step])
    X = np.array(X)
    X = X.reshape(X.shape[0],X.shape[1],1)
    return X, y
time_step = 8
X,y = extract_data(price_norm,time_step)


# 转换为Tensor
import torch
input_data = torch.tensor(X, dtype=torch.float32)
target_data = torch.tensor(y, dtype=torch.float32).unsqueeze(-1)
# 定义RNN模型
import torch.nn as nn
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):  #分别是输入,隐藏和输出的维度
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)
        out = self.fc(out[:, -1, :])
        return out


input_size = 1
hidden_size = 120
output_size = 1

# 创建模型实例
rnn = RNN(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.01)

# 训练模型
num_epochs = 1000
for epoch in range(num_epochs):
    outputs = rnn(input_data)
    loss = criterion(outputs, target_data)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 100 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.10f}")


Epoch [100/1000], Loss: 0.0004473718
Epoch [200/1000], Loss: 0.0006229825
Epoch [300/1000], Loss: 0.0003415935
Epoch [400/1000], Loss: 0.0002980608
Epoch [500/1000], Loss: 0.0002818418
Epoch [600/1000], Loss: 0.0002672504
Epoch [700/1000], Loss: 0.0002543367
Epoch [800/1000], Loss: 0.0002437856
Epoch [900/1000], Loss: 0.0002345658
Epoch [1000/1000], Loss: 0.0002437943
predict = rnn(input_data)
predict_out = []
for i in predict:
    predict_out.append(i)
    
fig2 = plt.figure(figsize=(8,5))
plt.plot(y,label='real price')
plt.plot(predict_out,label='predict price')
plt.title('close price')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-s6UNaWPo-1689238498868)(/imgs/2023-07-13/b7Gtk9DkbhFTkyjr.png)]

--------推荐专栏--------
🔥 手把手实现Image captioning
💯CNN模型压缩
💖模式识别与人工智能(程序与算法)
🔥FPGA—Verilog与Hls学习与实践
💯基于Pytorch的自然语言处理入门与实践

参考

邱锡鹏,神经网络与深度学习,机械工业出版社,https://nndl.github.io/, 2020.

更多推荐

C语言实现 cortex-A7核 点LED灯 (附 汇编实现、使用C语言 循环实现、使用C语言 封装函数实现【重要、常用】)

1汇编实现textglobal_startstart:**************LED1点灯--->PE10**************/**************RCC章节初始化**************/CC_INIT:@1.使能GPIOE组控制器,通过RCC_MP_AHB4ENSETR寄存器设置GPIOE组

openGauss学习笔记-72 openGauss 数据库管理-创建和管理分区表

文章目录openGauss学习笔记-72openGauss数据库管理-创建和管理分区表72.1背景信息72.2操作步骤72.2.1使用默认表空间72.2.1.1创建分区表(假设用户已创建tpcdsschema)72.2.1.2插入数据72.2.1.3修改分区表行迁移属性72.2.1.4删除分区72.2.1.5增加分区7

GSMA SGP.21协议学习

GSMASGP.21协议学习1简介1.1概述本文档提供了一种体系结构方法,作为所有市场中设备的远程SIM配置的建议解决方案。体系结构的主要目标是为设备的远程SIM配置提供必要的凭据以获取移动网络访问权限。该版本专注于消费类市场的设备。请注意,SGP.21V1.0[23]尚未弃用。1.2范围本文档的目的是定义一个通用架构

图像识别在自动驾驶和智能安防中的关键应用

图像识别在自动驾驶和智能安防中的关键应用随着人工智能和深度学习技术的发展,图像识别已经成为了自动驾驶和智能安防领域的关键应用之一。图像识别技术能够通过处理和分析图像数据,帮助自动驾驶车辆和智能安防系统实现更准确、更高效的运行。本文将介绍图像识别在自动驾驶和智能安防中的关键应用及其相关技术。一、图像识别在自动驾驶中的应用

设计模式实战:模版方法

1.模版方法概述在面向对象程序设计过程中,程序员常常会遇到这种情况:设计一个系统时知道了算法所需的关键步骤,而且确定了这些步骤的执行顺序,但某些步骤的具体实现还未知,或者说某些步骤的实现与具体的环境相关。例如,去银行办理业务一般要经过以下4个流程:取号、排队、办理具体业务、对银行工作人员进行评分等,其中取号、排队和对银

并发编程系列-分而治之思想Forkjoin

我们介绍过一些有关并发编程的工具和概念,包括线程池、Future、CompletableFuture和CompletionService。如果仔细观察,你会发现这些工具实际上是帮助我们从任务的角度来解决并发问题的,而不是让我们陷入线程之间如何协作的繁琐细节(比如等待和通知等)。对于简单的并行任务,你可以使用“线程池+F

数据库顶会 VLDB 2023 论文解读 - Krypton: 字节跳动实时服务分析 SQL 引擎设计

“Krypton源于DC宇宙中的氪星,它是超人的故乡,以氪元素命名”。引言近些年,在复杂的分析需求之外,字节内部的业务对于实时数据的在线服务能力也提出了更高的要求。大部分业务不得不采用多套系统来应对不同的Workload,虽然能满足需求,但也带来了不同系统数据一致性的问题,多个系统之间的ETL也浪费了大量的资源,同时对

区块链(1):区块链简介

区快链是通过密码技术保护的分布式数据库这是比特币背后的技术。本文将逐步带您了解区块链。1区块链BLOCKCHAIN的类的定义区块链有一个区块列表。它从一个单独的块开始,称为genesisblock【创世区块】2区块链BLOCK的类的定义第一个区块叫做Genesis[创世]block,每个块存储以下信息:IndexTim

人工智能如何提高转录效率

人工转录已经以某种形式存在了数百年,甚至数千年。近年来,在人工智能(AI)技术推动下,转录取得长足发展。转录文稿本身是音频内容的文本形式;借此,读者无需再听一遍录音便可了解一段时间内所讲述的内容或所发生的情况。转录对于记录保存、知识共享和改善可访问性至关重要。过去几年,随着AI的发展,人们越来越依赖于一种称为自动语音识

详解Nacos和Eureka的区别

文章目录Eureka是什么Nacos是什么Nacos的实现原理Nacos和Eureka的区别CAP理论连接方式服务异常剔除操作实例方式自我保护机制Eureka是什么Eureka是SpringCloud微服务框架默认的也是推荐的服务注册中心,由Netflix公司与2012将其开源出来,Eureka基于REST服务开发,主

设计模式再探——宏观篇

目录一、背景介绍二、思路&方案三、过程1.宏观介绍2.目的与意义3.七大原则的定义与边界4.思路由来四、总结五、升华一、背景介绍最近在做产品技术建模的过程中,一些地方刻意用到了设计模式,而一些地方也用到了但是并不是很明确。于是乎就带着这个疑惑来再探设计模式的宏观;也查阅了自己的博文:1.14年有宏观(第一层看山是山,知

热文推荐