"""Naive, from-scratch LSTM implementation compared against torch.nn.LSTM."""
# Imports: stdlib first, then third-party (torch). The original file had all
# imports collapsed onto a single line (invalid Python syntax).
import math

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter, init
class NaiveLSTM(nn.Module):
    """A from-scratch single-layer LSTM mirroring ``nn.LSTM``'s gate equations.

    Each gate keeps its own (hidden, input) / (hidden, hidden) weight matrices
    and (hidden, 1) biases, and computes ``W_i @ x + b_i + W_h @ h + b_h`` on
    column-vector state of shape (hidden_size, batch).
    """

    def __init__(self, input_size: int, hidden_size: int):
        super(NaiveLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Input gate parameters.
        self.w_ii = Parameter(Tensor(hidden_size, input_size))
        self.w_hi = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ii = Parameter(Tensor(hidden_size, 1))
        self.b_hi = Parameter(Tensor(hidden_size, 1))

        # Forget gate parameters.
        self.w_if = Parameter(Tensor(hidden_size, input_size))
        self.w_hf = Parameter(Tensor(hidden_size, hidden_size))
        self.b_if = Parameter(Tensor(hidden_size, 1))
        self.b_hf = Parameter(Tensor(hidden_size, 1))

        # Output gate parameters.
        self.w_io = Parameter(Tensor(hidden_size, input_size))
        self.w_ho = Parameter(Tensor(hidden_size, hidden_size))
        self.b_io = Parameter(Tensor(hidden_size, 1))
        self.b_ho = Parameter(Tensor(hidden_size, 1))

        # Cell (candidate) gate parameters.
        self.w_ig = Parameter(Tensor(hidden_size, input_size))
        self.w_hg = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ig = Parameter(Tensor(hidden_size, 1))
        self.b_hg = Parameter(Tensor(hidden_size, 1))

        self.reset_weigths()

    def reset_weigths(self):
        """Initialize every parameter uniformly in [-1/sqrt(H), 1/sqrt(H)].

        (Name kept as-is, typo included, to preserve the public interface.)
        """
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def forward(self, inputs, state):
        """Run the LSTM over a full sequence.

        Args:
            inputs: tensor of shape (batch, seq_len, input_size).
            state: optional ``(h, c)`` tuple, each of shape
                (1, batch, hidden_size). Zeros are used when ``None``
                (NOTE(review): the zero branch hard-codes batch size 1).

        Returns:
            ``(hidden_seq, (h_n, c_n))`` where hidden_seq has shape
            (seq_len, batch, hidden_size) and h_n/c_n are the final states
            of shape (1, batch, hidden_size).
        """
        if state is None:
            h_t = torch.zeros(1, self.hidden_size).t()
            c_t = torch.zeros(1, self.hidden_size).t()
        else:
            (h, c) = state
            h_t = h.squeeze(0).t()  # -> (hidden_size, batch)
            c_t = c.squeeze(0).t()

        hidden_seq = []
        # FIX/generalization: iterate the actual sequence length (the original
        # hard-coded seq_size = 1), and carry h_t/c_t across timesteps (the
        # original never updated them inside the loop, so sequences longer
        # than one step would have reused the initial state every time).
        seq_size = inputs.size(1)
        for t in range(seq_size):
            x = inputs[:, t, :].t()  # -> (input_size, batch)
            i = torch.sigmoid(self.w_ii @ x + self.b_ii + self.w_hi @ h_t + self.b_hi)
            f = torch.sigmoid(self.w_if @ x + self.b_if + self.w_hf @ h_t + self.b_hf)
            g = torch.tanh(self.w_ig @ x + self.b_ig + self.w_hg @ h_t + self.b_hg)
            o = torch.sigmoid(self.w_io @ x + self.b_io + self.w_ho @ h_t + self.b_ho)
            c_next = f * c_t + i * g
            h_next = o * torch.tanh(c_next)
            h_t, c_t = h_next, c_next  # propagate state to the next step
            c_next_t = c_next.t().unsqueeze(0)  # back to (1, batch, hidden)
            h_next_t = h_next.t().unsqueeze(0)
            hidden_seq.append(h_next_t)
        hidden_seq = torch.cat(hidden_seq, dim=0)
        return hidden_seq, (h_next_t, c_next_t)
def reset_weigths(model):
    """Overwrite every parameter of *model* with the constant 0.5.

    Used to make both the naive and the built-in LSTM deterministic so
    their outputs can be compared directly.
    """
    for _name, param in model.named_parameters():
        init.constant_(param, 0.5)
def main():
    """Demo: run NaiveLSTM and nn.LSTM with identical constant weights and
    print both results for side-by-side comparison."""
    inputs = torch.ones(1, 1, 10)
    h0 = torch.ones(1, 1, 20)
    c0 = torch.ones(1, 1, 20)
    print(h0.shape, h0)
    print(c0.shape, c0)
    print(inputs.shape, inputs)

    naive_lstm = NaiveLSTM(10, 20)
    reset_weigths(naive_lstm)
    output1, (hn1, cn1) = naive_lstm(inputs, (h0, c0))
    print(hn1.shape, cn1.shape, output1.shape)
    print(hn1)
    print(cn1)
    print(output1)

    # Reference: PyTorch's built-in LSTM with the same constant weights.
    lstm = nn.LSTM(10, 20)
    reset_weigths(lstm)
    output2, (hn2, cn2) = lstm(inputs, (h0, c0))
    print(hn2.shape, cn2.shape, output2.shape)
    print(hn2)
    print(cn2)
    print(output2)


# FIX: the demo previously ran unconditionally at import time; guard it so
# importing this module for NaiveLSTM has no side effects.
if __name__ == "__main__":
    main()