本文将基于PyTorch框架中从头开始构建LSTM,以更好地了解其内部工作原理。

LSTM的基础

LSTM是RNN中一个较为流行的网络模块。主要包括输入,输入门,输出门,遗忘门,激活函数,全连接层和输出。其结构如下:

上图中每一个大绿色块代表一个LSTMcell,可以看到中间的LSTMcell 里面有四个黄色小框,每一个小黄框代表一个前馈网络层,就是经典的神经网络结构,小黄框里面符号代表该层的激活函数,即1、2、4的激活函数时sigmoid,而3的激活函数时tanh。

整个LSTMCell的具体计算公式如下

更加详细的关于LSTM内容见于博文理解LSTM

接下来,我们结合框架实现来分析一些LSTM内部工作原理,首先我们看看Pytorch和Keras关于LSTM的使用的方式。

** Pytorch**

pytoch中的LSTM api使用方式如下:

1
class torch.nn.LSTM(*args, **kwargs)

参数列表

  • input_size:x的特征维度
  • hidden_size:隐藏层的特征维度
  • num_layers:lstm隐层的层数,默认为1
  • bias:False则bih=0和bhh=0. 默认为True
  • batch_first:True则输入输出的数据格式为 (batch, seq, feature)dropout:除最后一层,每一层的输出都进行
  • dropout,默认为: 0
  • bidirectional:True则为双向lstm默认为False输入:input, (h0, c0)输出:output, (hn,cn)

Keras

keras的LSTM api使用方式如下:

1
keras.layers.recurrent.LSTM(units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True,**)

核心参数

  • units:输出维度
  • input_dim:输入维度,当使用该层为模型首层时,应指定该值(或等价的指定input_shape)
  • return_sequences:布尔值,默认False,控制返回类型。若为True则返回整个序列,否则仅返回输出序列的最后一个输出
  • input_length:当输入序列的长度固定时,该参数为输入序列的长度。当需要在该层后连接Flatten层,然后又要连接Dense层时,需要指定该参数,否则全连接的输出无法计算出来。

大部分参数直接看说明就可以明白,而对于pytorch的hidden_size和keras的units可能不太直观理解,两者的大小都代表LSTMcell的输出大小,其实就是LSTMcell的神经元个数,也就是前面我们提到的每一个小黄框的4个前馈网络层的神经元个数,如下图所示:

Pytorch实现

接下来,我们使用Pytoch框架从0构建一个LSTM,具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor
import math

class NaiveLSTM(nn.Module):
"""Naive LSTM like nn.LSTM"""
def __init__(self, input_size, hidden_size):
super(NaiveLSTM, self).__init__()
self.input_size = input_size # 输入的大小,一般
self.hidden_size = hidden_size # 隐藏层输出大小,也就是单元个数

# input gate
self.w_ii = Parameter(Tensor(hidden_size, input_size))
self.w_hi = Parameter(Tensor(hidden_size, hidden_size))
self.b_ii = Parameter(Tensor(hidden_size, 1))
self.b_hi = Parameter(Tensor(hidden_size, 1))

# forget gate
self.w_if = Parameter(Tensor(hidden_size, input_size))
self.w_hf = Parameter(Tensor(hidden_size, hidden_size))
self.b_if = Parameter(Tensor(hidden_size, 1))
self.b_hf = Parameter(Tensor(hidden_size, 1))

# output gate
self.w_io = Parameter(Tensor(hidden_size, input_size))
self.w_ho = Parameter(Tensor(hidden_size, hidden_size))
self.b_io = Parameter(Tensor(hidden_size, 1))
self.b_ho = Parameter(Tensor(hidden_size, 1))

# cell
self.w_ig = Parameter(Tensor(hidden_size, input_size))
self.w_hg = Parameter(Tensor(hidden_size, hidden_size))
self.b_ig = Parameter(Tensor(hidden_size, 1))
self.b_hg = Parameter(Tensor(hidden_size, 1))
self.reset_weigths()

def reset_weigths(self):
"""reset weights
"""
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
init.uniform_(weight, -stdv, stdv)

def forward(self, inputs, state):
"""Forward
Args:
inputs: [1, 1, input_size]
state: ([1, 1, hidden_size], [1, 1, hidden_size])
"""
# seq_size, batch_size, _ = inputs.size()
if state is None:
h_t = torch.zeros(1, self.hidden_size).t()
c_t = torch.zeros(1, self.hidden_size).t()
else:
(h, c) = state
h_t = h.squeeze(0).t()
c_t = c.squeeze(0).t()
hidden_seq = []
seq_size = 1
for t in range(seq_size):
x = inputs[:, t, :].t()
# input gate
i = torch.sigmoid(self.w_ii @ x + self.b_ii + self.w_hi @ h_t +self.b_hi)
# forget gate
f = torch.sigmoid(self.w_if @ x + self.b_if + self.w_hf @ h_t +self.b_hf)
# cell
g = torch.tanh(self.w_ig @ x + self.b_ig + self.w_hg @ h_t + self.b_hg)
# output gate
o = torch.sigmoid(self.w_io @ x + self.b_io + self.w_ho @ h_t +self.b_ho)

c_next = f * c_t + i * g
h_next = o * torch.tanh(c_next)
c_next_t = c_next.t().unsqueeze(0)
h_next_t = h_next.t().unsqueeze(0)
hidden_seq.append(h_next_t)

hidden_seq = torch.cat(hidden_seq, dim=0)
return hidden_seq, (h_next_t, c_next_t)

def reset_weigths(model):
"""reset weights
"""
for weight in model.parameters():
init.constant_(weight, 0.5)

### test
inputs = torch.ones(1, 1, 10)
h0 = torch.ones(1, 1, 20)
c0 = torch.ones(1, 1, 20)
print(h0.shape, h0)
print(c0.shape, c0)
print(inputs.shape, inputs)

# test naive_lstm with input_size=10, hidden_size=20
naive_lstm = NaiveLSTM(10, 20)
reset_weigths(naive_lstm)

output1, (hn1, cn1) = naive_lstm(inputs, (h0, c0))

print(hn1.shape, cn1.shape, output1.shape)
print(hn1)
print(cn1)
print(output1)

# Use official lstm with input_size=10, hidden_size=20
lstm = nn.LSTM(10, 20)
reset_weigths(lstm)
output2, (hn2, cn2) = lstm(inputs, (h0, c0))
print(hn2.shape, cn2.shape, output2.shape)
print(hn2)
print(cn2)
print(output2)