LSTM从入门到精通(形象的图解,详细的代码和注释,完美的数学推导过程)
创始人
2025-05-29 02:26:17
0

先附上这篇文章的一个思维导图


  1. 什么是RNN

按照八股文来说:RNN实际上就是一个带有记忆的时间序列的预测模型

RNN的细胞结构图如下:

softmax激活函数只是我举的一个例子,实际上得到y也可以通过其他的激活函数得到

其中a代表t-1时刻隐藏状态,a代表经过X这一t时刻的输入之后,得到的新的隐藏状态。公式主要是a = tanh(Waa * a + Wax * X + b1) ;大白话解释一下就是,X是今天的吊针,a是昨天的发烧度数39,经过今天这一针之后,a变成38度。这里的记忆体现在今天的38度是在前一天的基础上,通过打吊针来达到第二天的降温状态。


1.1 RNN的应用

由于RNN的记忆性,我们最容易想到的就是RNN在自然语言处理方面的应用,譬如下面这张图,提前预测出下一个字。

除此之外,RNN的应用还包括下面的方向:

  1. 语言模型:RNN被广泛应用于语言模型的建模中,例如自然语言处理、机器翻译、语音识别等领域。

  1. 时间序列预测:RNN可以用于时间序列预测,例如股票价格预测、气象预测、心电图信号预测等。

  1. 生成模型:RNN可以用于生成模型,例如文本生成、音乐生成、艺术创作等。

  1. 强化学习:RNN可以用于强化学习中,例如在游戏、机器人控制和决策制定等领域。


1.2 RNN的缺陷

想必大家一定听说过LSTM,没错,就是由于RNN的尿性,所以才出现LSTM这一更精妙的时间序列预测模型的设计。但是我们知己知彼才能百战百胜,因此我还是决定详细分析一下RNN的缺陷,看过一些资料,但是只是肤浅的提到了梯度消失和梯度爆炸,没有实际的数学推导,这可不是一个求学之人应该有的态度!

主要的缺陷是两个:

  1. 长期依赖问题导致的梯度消失:众所周知RNN模型是一个具有记忆的模型,每一次的预测都和当前输入以及之前的状态有关,但是我们试想,如果我们的句子很长,他在第1000个记忆细胞还能记住并很好的利用第1个细胞的记忆状态吗?答案显然是否定的

  1. 梯度爆炸


1.2.1 梯度消失和梯度爆炸的详细公式推导

敲黑板(手写公式推导,大家最迷糊的地方):

根据下面图示的例子,我手写并反复检查了自己的过程(下图),请各位看官务必认真看看,理解起来并不难,对于别的文章随口一提的梯度消失和梯度爆炸实在是透彻太多啦!!!

我们假设损失函数 ,Y是实际值,O是预测值;首先,我们假设只有三层,然后通过三层我们就能以此类推找出规律。反向传播我们需要对Wo,Wx,Ws,b四个变量都求偏导,在这里我们主要对Wx求偏导,其他三个以此类推,就很简单了。为了表示更清晰,笔者使用紫色的x表示乘法。

根据推导的公式我们得到一个指数函数,我们在高中时候就学到过指数函数的变化系数是极大的,因此在t趋于比较大的时候(也就是我们的句子比较长的时候),如果比1小不少,那么模型的一部分梯度会趋于0,因此优化会几乎停止;同理,如果比1大一些,那么模型的部分梯度会极大,导致模型和的变化无法控制,优化毫无意义。


  1. 什么是LSTM

八股文解释:LSTM(长短时记忆网络)是一种常用于处理序列数据的深度学习模型,与传统的 RNN(循环神经网络)相比,LSTM引入了三个门(输入门、遗忘门、输出门,如下图所示)和一个细胞状态(cell state),这些机制使得LSTM能够更好地处理序列中的长期依赖关系。注意:小蝌蚪形状表示的是sigmoid激活函数
Ct是细胞状态(记忆状态),是输入的信息,是隐藏状态(基于得到的)

用最朴素的语言解释一下三个门,并且用两门考试来形象的解释一下LSTM:

  1. 遗忘门:通过x和ht的操作,并经过sigmoid函数,得到0,1的向量,0对应的就代表之前的记忆某一部分要忘记,1对应的就代表之前的记忆需要留下的部分 ===>代表复习上一门线性代数所包含的记忆,通过遗忘门,忘记掉和下一门高等数学无关的内容(比如矩阵的秩)

  1. 输入门:通过将之前的需要留下的信息和现在需要记住的信息相加,也就是得到了新的记忆状态。===>代表复习下一门科目高等数学的时候输入的一些记忆(比如洛必达法则等等),那么已经线性代数残余且和高数相关的部分(比如数学运算)+高数的知识=新的记忆状态

  1. 输出门:整合,得到一个输出===>代表高数所需要的记忆,但是在实际的考试不一定全都发挥出来考到100分。因此,则代表实际的考试分数

为了便于大家理解,附上几张非常好的图供大家理解完整的数据处理的流程:

遗忘门:

输入门:

细胞状态:

输出门:


2.1 LSTM的模型结构

这里有两张别的博主的很好的图,我在初学的时候也是恍然大悟:

图的出处

解释一下pytorch训练lstm所使用的参数:

  1. 这是利用pytorch调用LSTM所使用的参数

output,(h_n,c_n) = lstm (x, [ht_1, ct_1]),一般直接放入x就好,后面中括号的不用管
  1. 这是作为x(输入)喂给LSTM的参数

x:[seq_length, batch_size, input_size],这里有点反人类,batch_size一般都是放在开始的位置
  1. 这是pytorch简历LSTM是所需参数

lstm = LSTM(input_size,hidden_size,num_layers)

2.2 LSTM相比RNN的优势

LSTM的反向传播的数学推导很繁琐,因为涉及到的变量很多,但是LSTM确实是可以在一定程度上解决梯度消失和梯度爆炸的问题。我简单说一下,RNN的连乘主要是W的连乘,而W是一样的,因此就是一个指数函数(在梯度中出现指数函数并不是一件友好的事情);相反,LSTM的连乘是的偏导的不断累乘,如果前后的记忆差别不大,那偏导的值就是1,那就是多个1相乘。当然,也可能出现某一一些偏导的值很大,但是一定不会很多(换句话说,一句话的前后没有逻辑,那完全没有训练的必要)。


2.3 pytorch实现LSTM对股票的预测(实战)

需要安装一下tushare的金融方面的数据集,代码的注解我已经写的很清楚了

#!/usr/bin/python3
# -*- encoding: utf-8 -*-import matplotlib.pyplot as plt
import numpy as np
import tushare as ts
import pandas as pd
import torch
from torch import nn
import datetime
import timeDAYS_FOR_TRAIN = 10class LSTM_Regression(nn.Module):"""使用LSTM进行回归参数:- input_size: feature size- hidden_size: number of hidden units- output_size: number of output- num_layers: layers of LSTM to stack"""def __init__(self, input_size, hidden_size, output_size=1, num_layers=2):super().__init__()self.lstm = nn.LSTM(input_size, hidden_size, num_layers)self.fc = nn.Linear(hidden_size, output_size)def forward(self, _x):x, _ = self.lstm(_x)  # _x is input, size (seq_len, batch, input_size)s, b, h = x.shapex = x.view(s * b, h)x = self.fc(x)x = x.view(s, b, -1)  # 把形状改回来return xdef create_dataset(data, days_for_train=5) -> (np.array, np.array):"""根据给定的序列data,生成数据集数据集分为输入和输出,每一个输入的长度为days_for_train,每一个输出的长度为1。也就是说用days_for_train天的数据,对应下一天的数据。若给定序列的长度为d,将输出长度为(d-days_for_train+1)个输入/输出对"""dataset_x, dataset_y = [], []for i in range(len(data) - days_for_train):_x = data[i:(i + days_for_train)]dataset_x.append(_x)dataset_y.append(data[i + days_for_train])return (np.array(dataset_x), np.array(dataset_y))if __name__ == '__main__':t0 = time.time()data_close = ts.get_k_data('000001', start='2019-01-01', index=True)['close']  # 取上证指数的收盘价data_close.to_csv('000001.csv', index=False) #将下载的数据转存为.csv格式保存data_close = pd.read_csv('000001.csv')  # 读取文件df_sh = ts.get_k_data('sh', start='2019-01-01', end=datetime.datetime.now().strftime('%Y-%m-%d'))print(df_sh.shape)data_close = data_close.astype('float32').values  # 转换数据类型plt.plot(data_close)plt.savefig('data.png', format='png', dpi=200)plt.close()# 将价格标准化到0~1max_value = np.max(data_close)min_value = np.min(data_close)data_close = (data_close - min_value) / (max_value - min_value)# dataset_x# 是形状为(样本数, 时间窗口大小)# 的二维数组,用于训练模型的输入# dataset_y# 是形状为(样本数, )# 的一维数组,用于训练模型的输出。dataset_x, dataset_y = create_dataset(data_close, DAYS_FOR_TRAIN)  # 分别是(1007,10,1)(1007,1)# 划分训练集和测试集,70%作为训练集train_size = int(len(dataset_x) * 0.7)train_x = dataset_x[:train_size]train_y = dataset_y[:train_size]# 将数据改变形状,RNN 读入的数据维度是 (seq_size, batch_size, feature_size)train_x = train_x.reshape(-1, 1, DAYS_FOR_TRAIN)train_y = train_y.reshape(-1, 1, 1)# 转为pytorch的tensor对象train_x = torch.from_numpy(train_x)train_y = torch.from_numpy(train_y)model = LSTM_Regression(DAYS_FOR_TRAIN, 8, output_size=1, num_layers=2)  # 导入模型并设置模型的参数输入输出层、隐藏层等model_total = sum([param.nelement() for param in model.parameters()])  # 计算模型参数print("Number of model_total parameter: %.8fM" % (model_total / 1e6))train_loss = []loss_function = nn.MSELoss()optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)for i in range(200):out = model(train_x)loss = loss_function(out, train_y)loss.backward()optimizer.step()optimizer.zero_grad()train_loss.append(loss.item())# 将训练过程的损失值写入文档保存,并在终端打印出来with open('log.txt', 'a+') as f:f.write('{} - {}\n'.format(i + 1, loss.item()))if (i + 1) % 1 == 0:print('Epoch: {}, Loss:{:.5f}'.format(i + 1, loss.item()))# 画loss曲线plt.figure()plt.plot(train_loss, 'b', label='loss')plt.title("Train_Loss_Curve")plt.ylabel('train_loss')plt.xlabel('epoch_num')plt.savefig('loss.png', format='png', dpi=200)plt.close()# torch.save(model.state_dict(), 'model_params.pkl')  # 可以保存模型的参数供未来使用t1 = time.time()T = t1 - t0print('The training time took %.2f' % (T / 60) + ' mins.')tt0 = time.asctime(time.localtime(t0))tt1 = time.asctime(time.localtime(t1))print('The starting time was ', tt0)print('The finishing time was ', tt1)# for testmodel = model.eval()  # 转换成评估模式# model.load_state_dict(torch.load('model_params.pkl'))  # 读取参数# 注意这里用的是全集 模型的输出长度会比原数据少DAYS_FOR_TRAIN 填充使长度相等再作图dataset_x = dataset_x.reshape(-1, 1, DAYS_FOR_TRAIN)  # (seq_size, batch_size, feature_size)dataset_x = torch.from_numpy(dataset_x)pred_test = model(dataset_x)  # 全量训练集# 的模型输出 (seq_size, batch_size, output_size)pred_test = pred_test.view(-1).data.numpy()pred_test = np.concatenate((np.zeros(DAYS_FOR_TRAIN), pred_test))  # 填充0 使长度相同assert len(pred_test) == len(data_close)plt.plot(pred_test, 'r', label='prediction')plt.plot(data_close, 'b', label='real')plt.plot((train_size, train_size), (0, 1), 'g--')  # 分割线 左边是训练数据 右边是测试数据的输出plt.legend(loc='best')plt.savefig('result.png', format='png', dpi=200)plt.close()

2.4 小问题:为什么采用tanh函数,不能都用sigmoid函数吗

先放上两个函数的图形:

  1. Sigmoid函数比Tanh函数收敛饱和速度慢

  1. Sigmoid函数比Tanh函数值域范围更窄

  1. tanh的均值是0,Sigmoid均值在0.5左右,均值在0的数据显然更便于数据处理

  1. tanh的函数变化敏感区间更大

  1. 对两者求导,发现tanh对计算的压力更小,直接是1-原函数的平方,不需要指数操作

使用该问的图请标明出处,创作不易,希望收获你的赞赞

相关内容

热门资讯

原创 蒸... 在探讨蒸馒头时,我们常常会遇到一个选择:是使用碱水还是苏打水来发酵。今天,就让我们一同揭开这个谜团,...
扬州炒饭:颗颗分明的黄金蛋炒饭... 扬州炒饭以其颗颗分明、金黄诱人的独特魅力闻名遐迩,而 “炒散” 技巧则是成就美味的关键。本文深入挖掘...
荔枝“进京” 产销两旺 北京市民一家人现场品尝茂名荔枝。 孟夏时节,广东茂名的荔枝香跨越山海,飘进了京城。5月29日,“荔枝...
今天去吃了一个一直想尝的老挝餐... 今天去吃了一个一直想尝的老挝餐厅 叫Laos in Town,意外的好吃。味道上在云贵菜和泰餐之间,...
原创 这... 标题:这家卖馒头的店,买一次吃两天,都在排队等,每天卖出馒头上万个。 在繁忙的都市中,有一家不起眼...
原创 多... 在中国传统饮食文化中,"药食同源"的理念源远流长。苹果作为日常生活中最常见的水果之一,其营养价值早已...
原创 螺... 标题:螺蛳粉大比拼:谁才是最正宗的螺蛳粉,网友:最爱第三款! 在美食的世界里,每一种风味都承载着它...
煎饼面糊的制作秘诀:口感与美味... 煎饼作为广受欢迎的早餐食品,其面糊的调制是决定口感和品质的关键因素。本文将向你揭示煎饼面糊的调制秘诀...
新华视点|奢华礼粽卖“天价” ... 端午节临近,各式粽子成为消费市场上的热门之选。记者走访发现,其中不乏馅料名贵、包装奢华的粽子礼盒。一...
原创 这... 标题:这些食物都是用脚踩出来的,很多人都喜欢!网友:看了会有阴影! 在美食的世界里,有一种独特的制...
沧州再添殊荣:0317火锅鸡在... 在刚刚落幕的河北省品牌产品博览会上,沧州0317火锅鸡大放异彩,一举斩获金奖。作为沧州美食的杰出代表...
原创 这... 标题:这是生蚝最爽的吃法?据说只有身体倍棒的人才能吃得下! 在探索美食的海洋中,生蚝以其独特的鲜美...
原创 西... 标题:西红柿蛋汤家常最经典,想要营养美味不单调?那你就得这样做! 在繁忙的生活中,我们总是渴望一顿...
原创 西... 标题:西瓜还能做酱料?教你这个秘方,西瓜酱鲜美可口还下饭! 在炎炎夏日,没有什么比一碗清凉的西瓜汤...
原创 以... 芦笋黑椒牛柳是一道色香味俱全的家常菜,不仅营养丰富,而且制作简单,特别适合家庭日常餐桌。芦笋的清脆与...
迎端午 忆乡情 在高潭镇黄洲村,一张旧木桌,上面用三根木条钉个架,就成了当地村民包粽子的工作台。 高潭镇粽子产业走...
原创 那... 标题:那些很便宜但一吃就可以吃一天的东西,吃货们太需要了 在这个世界上,有一种食物,它既廉价又美味...
原创 洋... 最近,洋葱这种再普通不过的食材突然在网络上火了起来。作为一个吃了30年洋葱的老饕,我尝试过无数种洋葱...
原创 自... 标题:在家烹饪的美味秘籍:五种美食,每一口都是幸福的滋味 在这个快节奏的时代,我们总是在寻找那些能...