In NLP today, the Transformer has surpassed the RNN in encoding power, but its ability to model long-range dependencies is still limited. LSTM-based models introduced gating mechanisms and gradient clipping to capture long-range dependencies, and in practice the longest distance they can encode is reported to be around 200 tokens. Transformer-based models let words attend directly to one another via self-attention, which captures long-term dependencies much better, yet the fixed context length still imposes a limit. This post introduces Transformer-XL and then implements it from scratch with PyTorch.

The Vanilla Transformer

Recall that when BERT applies the Transformer it has a sequence length hyperparameter: both during training and at prediction time, every input it receives has a fixed length. What would be the ideal way to feed a corpus for training? Ideally you would feed an entire passage at once and extract features from the whole thing. In practice, though, a Transformer that large does not fit in memory. The usual workaround is to split the passage into segments. During training:

  • If an input segment is shorter than the sequence length, it is padded.
  • If an input segment is longer than the sequence length, it is truncated (a minimal sketch of both cases follows below).
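As a rough illustration (my own sketch, not part of the original post; fit_to_length and pad_id are made-up names), padding and truncation to a fixed sequence length look like this:

import torch

def fit_to_length(ids: torch.LongTensor, seq_len: int, pad_id: int = 0) -> torch.LongTensor:
    """Pad or truncate a 1-D tensor of token ids to exactly seq_len."""
    if ids.size(0) >= seq_len:
        return ids[:seq_len]                      # truncate: everything past seq_len is lost
    padding = ids.new_full((seq_len - ids.size(0),), pad_id)
    return torch.cat([ids, padding], dim=0)       # pad: fill the remainder with pad_id

print(fit_to_length(torch.arange(1, 6), 3))  # tensor([1, 2, 3])
print(fit_to_length(torch.arange(1, 6), 8))  # tensor([1, 2, 3, 4, 5, 0, 0, 0])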

This is clearly a stopgap, and it has two drawbacks:

  1. Cutting long sentences inevitably breaks their semantics, which hurts training.

  2. Segments are cut without regard to semantic boundaries, so while training on the current segment the model has no access to information from earlier segments; the context is severed.

How can these problems be solved? To model long-range dependencies, the authors proposed Transformer-XL (XL stands for "extra long").

Transformer-XL

Let's first think about how we would solve these problems of the Transformer ourselves.

Readers familiar with NLP will probably think of the RNN. To retain the history of a sequence, the RNN uses a recurrence mechanism: when computing the state at the current step, it takes the previous step's state as an input. For the Transformer, then, couldn't we solve the problem by feeding the hidden states of the previous segment into the computation of the current segment's hidden states?

Is it really that simple? Essentially, yes, although Transformer-XL adds some clever design on top. Let's see how Transformer-XL introduces this recurrence mechanism to solve the problems above.

In the vanilla Transformer (as illustrated in the paper's figure), the corpus is fed as follows. During training, the corpus is split into shorter, manageable segments, the model is trained on each segment separately, and all context from earlier segments is ignored. At evaluation time, the model consumes a segment of the same length as in training at every step; the segment is then shifted one position to the right, processed again from scratch, and a single prediction is made at the last position.

Transformer-XL (also illustrated in the paper) takes a different strategy: during training, the hidden-state sequence computed for the previous segment is frozen and cached, and reused while the model processes the next segment. At evaluation time, the representations of preceding segments can likewise be reused instead of being recomputed from scratch, which makes evaluation much faster.

In summary, compared with the Transformer, the improvements are:

  1. Segment-level recurrence: it extends the length of text the Transformer can handle and fixes context fragmentation (in the vanilla Transformer, inputs longer than the fixed sequence length are truncated, so the text at the cut point loses its connection to the surrounding context). It effectively acts as a sliding window whose size is the sequence length.

  2. Relative positional encoding: with absolute positional encoding, the same token position in different segments can receive the same encoding and cannot be distinguished; Transformer-XL instead derives positional encodings from relative distances.

The Recurrence Mechanism

The key question is how to feed the previous segment's hidden states from the layer below into the computation of the current segment's hidden state at the current layer. Transformer-XL's answer is very simple: concatenate them along the sequence-length dimension. From the paper (for segment $\tau+1$ and layer $n$):

\tilde{h}_{\tau+1}^{n-1} = \left[\mathrm{SG}(h_{\tau}^{n-1}) \circ h_{\tau+1}^{n-1}\right]
q_{\tau+1}^{n},\; k_{\tau+1}^{n},\; v_{\tau+1}^{n} = h_{\tau+1}^{n-1} W_q^{\top},\; \tilde{h}_{\tau+1}^{n-1} W_k^{\top},\; \tilde{h}_{\tau+1}^{n-1} W_v^{\top}
h_{\tau+1}^{n} = \text{Transformer-Layer}\left(q_{\tau+1}^{n}, k_{\tau+1}^{n}, v_{\tau+1}^{n}\right)

where:

  • $h_{\tau}^n$ is an $L \times d$ matrix denoting the hidden states of the $n$-th layer for the $\tau$-th input segment; $L$ is the sequence length and $d$ the embedding dimension.

  • SG(·) stands for stop-gradient. This is crucial: gradients are not propagated into the cached states, which avoids the problems that arise when backpropagating through long histories as in RNNs.

As the formula shows, the difference from the vanilla Transformer lies in the inputs to the keys K and values V of each layer: Transformer-XL brings in the previous segment's hidden states from the layer below, concatenates them with the current segment, and computes the new K and V from the result. A minimal sketch of this concatenation follows below.
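A minimal sketch of the recurrence (my own illustration, not the post's final implementation; shapes and variable names are made up). Detaching the cached states plays the role of SG(·), and only K and V see the memory:

import torch
import torch.nn as nn

L, M, d = 4, 6, 8                       # current length, memory length, model dim
h_prev = torch.randn(M, d)              # hidden states of the previous segment (layer n-1)
h_cur  = torch.randn(L, d)              # hidden states of the current segment (layer n-1)

W_q, W_k, W_v = (nn.Linear(d, d, bias=False) for _ in range(3))

h_tilde = torch.cat([h_prev.detach(), h_cur], dim=0)  # SG(.) + concatenation, shape (M+L, d)
q = W_q(h_cur)                                        # queries come from the current segment only
k, v = W_k(h_tilde), W_v(h_tilde)                     # keys/values also see the cached memory

attn = torch.softmax(q @ k.t() / d ** 0.5, dim=-1)    # (L, M+L)
out = attn @ v                                        # (L, d)
print(out.shape)                                      # torch.Size([4, 8])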

More concretely (referring to the illustration in the original post):

The recurrence mechanism (essentially a sliding window) reuses previous states. In the illustration, the orange part is the initial memory and the green parts are segments. For example, to compute position 1 of segment 1 in Layer 1 (seg1_1), the previous layer's memory (mem in the embedding layer) and seg1_1 are needed; when computing seg2_1 in Layer 1, mem has become seg1 of the embedding layer. In other words, the window slides one segment at a time. The arrows indicate where the value being computed attends to: orange arrows come from the previous layer's mem, green arrows from the corresponding positions in the previous layer.

When computing attention (say for seg1_1 in Layer 1: Q is seg1_1 of the embedding layer, and K = V is mem + seg1_1 of the embedding layer), we first compute attn_weight = $softmax(QK^T)$, the weight each V contributes to the output for seg1_1 in Layer 1; then attn_weight $* V$ gives the weighted sum of the values. Note that during backpropagation the mem part receives no gradient updates. This also makes clear how this recurrence differs from the RNN's: an RNN recurs within the same layer ($\text{step}_t$ uses the memory of $\text{step}_{t-1}$ and earlier), whereas Transformer-XL recurs across layers ($\text{layer}_t$ uses the memory of $\text{layer}_{t-1}$ and earlier).

Relative Positional Encodings

How does the vanilla Transformer represent position information in the input? It uses a POS function of the position $i$ and the dimension $d$, so the same absolute position in different input segments gets the same representation. In the vanilla Transformer the segments are processed independently, so this is not a problem; but Transformer-XL brings in information from the previous segment, so words at the same position $i$ of different segments must be distinguishable.
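For reference, the standard sinusoidal encoding used by the vanilla Transformer (a fact from the original Transformer paper, not something specific to this post) is:

PE_{(pos,\,2i)} = \sin\!\left(pos / 10000^{2i/d_{model}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(pos / 10000^{2i/d_{model}}\right)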

Transformer-XL therefore introduces relative positional encodings: attention is computed from the relative distance between words rather than from their absolute positions as in the vanilla Transformer.

In the vanilla Transformer, the attention score between query $q_i$ and key $k_j$ is computed as

\mathbf{A}_{i, j}^{\mathrm{abs}} = (\mathbf{W}_q(\mathbf{E}_{x_i} + \mathbf{U}_i))^{\top} (\mathbf{W}_k(\mathbf{E}_{x_j} + \mathbf{U}_j))

Expanding this gives:

\mathbf{A}_{i,j}^{\mathrm{abs}} = \underbrace{\mathbf{E}_{x_i}^{\top}\mathbf{W}_q^{\top}\mathbf{W}_k\mathbf{E}_{x_j}}_{(a)} + \underbrace{\mathbf{E}_{x_i}^{\top}\mathbf{W}_q^{\top}\mathbf{W}_k\mathbf{U}_j}_{(b)} + \underbrace{\mathbf{U}_i^{\top}\mathbf{W}_q^{\top}\mathbf{W}_k\mathbf{E}_{x_j}}_{(c)} + \underbrace{\mathbf{U}_i^{\top}\mathbf{W}_q^{\top}\mathbf{W}_k\mathbf{U}_j}_{(d)}

where:

  • $\mathbf{E}_{x_{i}}$ is the embedding of word $i$,
  • $\mathbf{E}_{x_{j}}$ is the embedding of word $j$,
  • $U_i$ and $U_j$ are the (absolute) position vectors.

Transformer-XL rewrites this attention computation in terms of relative positions, and does so not only in the first layer but in every layer.

Compared with the absolute-position formula, there are three main changes (the resulting formula is given after this list):

  1. In terms (b) and (d), the absolute position vector $U_j$ is replaced by the relative position vector $R_{i-j}$; as in the vanilla Transformer, this is a fixed sinusoidal encoding that is not learned.

  2. In term (c), the query's $U_i^{\top}W_q^{\top}$ is replaced by a learnable parameter vector $u$: once positions are relative, the query's absolute position $i$ no longer matters, so the same vector can be used for every $i$. Likewise, in term (d) the query's $U_i^{\top}W_q^{\top}$ is replaced by another learnable parameter vector $v$.

  3. The key's weight matrix $W_k$ is split into $W_{k,E}$ and $W_{k,R}$, producing content-based and location-based key vectors respectively.
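Putting the three changes together, the relative attention score from the paper is:

\mathbf{A}_{i,j}^{\mathrm{rel}} = \underbrace{\mathbf{E}_{x_i}^{\top}\mathbf{W}_q^{\top}\mathbf{W}_{k,E}\mathbf{E}_{x_j}}_{(a)} + \underbrace{\mathbf{E}_{x_i}^{\top}\mathbf{W}_q^{\top}\mathbf{W}_{k,R}\mathbf{R}_{i-j}}_{(b)} + \underbrace{u^{\top}\mathbf{W}_{k,E}\mathbf{E}_{x_j}}_{(c)} + \underbrace{v^{\top}\mathbf{W}_{k,R}\mathbf{R}_{i-j}}_{(d)}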

The four terms correspond, respectively, to:

  • content-based addressing, i.e. the raw score without any positional encoding;
  • a content-dependent positional bias, i.e. a positional offset relative to the current content;
  • a global content bias that measures the importance of the key;
  • a global positional bias that adjusts importance according to the distance between query and key.

In short, relative positional encoding replaces the absolute position encodings $U_i$ and $U_j$ in the attention score with the relative encoding $R_{i-j}$, and learns the global vectors $u$ and $v$ to adjust the score for different distances and different contents.

The Full Computation

Combining the two innovations, the overall computation of an $N$-layer Transformer-XL with a single attention head is (following the paper):

\tilde{h}_{\tau}^{n-1} = \left[\mathrm{SG}(h_{\tau-1}^{n-1}) \circ h_{\tau}^{n-1}\right]
q_{\tau}^{n},\; k_{\tau}^{n},\; v_{\tau}^{n} = h_{\tau}^{n-1} {W_q^{n}}^{\top},\; \tilde{h}_{\tau}^{n-1} {W_{k,E}^{n}}^{\top},\; \tilde{h}_{\tau}^{n-1} {W_v^{n}}^{\top}
A_{\tau,i,j}^{n} = {q_{\tau,i}^{n}}^{\top} k_{\tau,j}^{n} + {q_{\tau,i}^{n}}^{\top} W_{k,R}^{n} R_{i-j} + u^{\top} k_{\tau,j}^{n} + v^{\top} W_{k,R}^{n} R_{i-j}
a_{\tau}^{n} = \text{Masked-Softmax}(A_{\tau}^{n})\, v_{\tau}^{n}
o_{\tau}^{n} = \text{LayerNorm}(\text{Linear}(a_{\tau}^{n}) + h_{\tau}^{n-1})
h_{\tau}^{n} = \text{Positionwise-Feed-Forward}(o_{\tau}^{n})

Here $\tau$ indexes the segment, $n$ indexes the layer, and $h_\tau^0 := E_{s_\tau}$ is defined as the word-embedding sequence of segment $\tau$. Note that when computing the matrix $A$ we need $W_{k,R}^n R_{i-j}$ for every value of $i-j$; computed naively this takes $O(\text{length}^2)$ time, but since $i-j$ only ranges over 0 to length, we can precompute these vectors once and simply look them up when building $A$.

Concretely, let $M$ and $L$ be the lengths of the memory and of the current segment, so $i-j$ ranges over 0 to M + L - 1. Each row of the matrix $\mathbf{Q}$ below corresponds to one possible value of $i-j$, i.e. $Q_k = W_{k, R} R_{M+L-1-k}$:

\mathbf{Q} :=\left[ \begin{array}{c}{\mathbf{R}_{M+L-1}^{\top}} \\ {\mathbf{R}_{M+L-2}^{\top}} \\ {\vdots} \\ {\mathbf{R}_{1}^{\top}} \\ {\mathbf{R}_{0}^{\top}}\end{array}\right] \mathbf{W}_{k, R}^{\top}=\left[ \begin{array}{c}{\left[\mathbf{W}_{k, R} \mathbf{R}_{M+L-1}\right]^{\top}} \\ {\vdots} \\ {\vdots} \\ {\left[\mathbf{W}_{k, R} \mathbf{R}_{1}\right]^{\top}} \\ {\left[\mathbf{W}_{k, R} \mathbf{R}_{0}\right]^{\top}}\end{array}\right] \in \mathbb{R}^{(M+L) \times d}

For term (b) above, i.e. $q_i^{\top} W_{k,R}R_{i-j}$, collecting all possible values gives a matrix $B$ of shape $L \times (M + L)$.

This matrix $B$ is exactly the (b) term of the attention that we ultimately need.

We further define $\tilde{B} := \mathbf{q}\,\mathbf{Q}^{\top}$, i.e. $\tilde{B}_{i,k} = q_i^{\top} W_{k,R} R_{M+L-1-k}$, which can be computed with a single matrix multiplication.

Each row of the needed matrix $B$ is simply a left shift of the corresponding row of $\tilde{B}$, so it suffices to compute $\tilde{B}$ directly via matrix multiplication. If $R_{i-j}$ has dimension $d_R$, $q_i$ has dimension $d_q$, and $W_{k,R}$ is a $d_q \times d_R$ matrix, computing $B$ directly costs about $2 \times d_q \times d_R \times L \times (M+L)$ operations, whereas computing $\tilde{B}$ costs $L \times d_q \times (M + L) + d_q \times d_R \times (M + L)$; the latter is clearly much cheaper.

Let's illustrate the shift with a small 2-D example:

x = torch.linspace(1, 12, 12).view(3,4)
print(x)

#output:
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]])
zero_pad = torch.zeros((x.size(0), 1))
x_padded = torch.cat([zero_pad, x], dim=1)
x_padded = x_padded.view(x.size(1) + 1, x.size(0))
x = x_padded[1:].view_as(x)
print(x)

#output:
tensor([[ 3., 4., 0., 5.],
[ 6., 7., 8., 0.],
[ 9., 10., 11., 12.]])

Overall, Transformer-XL makes a set of targeted adjustments to the Transformer. According to the paper, it learns dependencies roughly 80% longer than RNNs and roughly 450% longer than vanilla Transformers, performs better on both short and long sequences, and is more than 1,800 times faster than the vanilla Transformer at evaluation time.

Experiments

Next we will implement Transformer-XL from scratch with PyTorch. The best way to truly understand a model is to build it yourself.

Overview

Since Transformer-XL builds on the Transformer, let's briefly recap the original architecture. A Transformer is a stack of MultiHeadAttention layers combined with feed-forward layers, residual connections and layer normalization.

A MultiHeadAttention layer consists of several attention heads. Each attention head applies a linear transformation to its input and uses keys and queries to compute attention over the values.

Attention alone cannot capture positional information, so the Transformer adds embeddings that encode each input position to the word embeddings. A minimal sketch of this addition follows.
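A minimal sketch (my own, with illustrative names) of how the vanilla Transformer combines token and absolute position embeddings:

import torch
import torch.nn as nn

vocab_size, max_len, d_model = 100, 512, 32
tok_emb = nn.Embedding(vocab_size, d_model)
pos_emb = nn.Embedding(max_len, d_model)                 # learned absolute positions (sinusoids also work)

ids = torch.randint(vocab_size, (7,))                    # a toy sequence of 7 token ids
x = tok_emb(ids) + pos_emb(torch.arange(ids.size(0)))    # (7, d_model): word + position information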

Now let's turn to Transformer-XL. To get a feel for the overall data flow, here is a simplified version of its forward pass:

def forward(self, input_):
    hidden_states = []
    pos_embs = self.position_embedding(input_)  # positional embeddings
    word_embs = self.word_embedding(input_)     # word embeddings

    layer_out = word_embs
    for mem, layer in zip(memory, self.layers):
        layer_out = layer(layer_out, pos_embs, mem)
        hidden_states.append(layer_out)
    logits = self.output_projection(self.drop(layer_out))
    new_memory = self.update_memory(memory, hidden_states)
    return logits, new_memory

Here:

  • memory: this is what makes Transformer-XL special. Handling the memory correctly is one of the keys to making Transformer-XL work.

  • layer: the core of Transformer-XL. It is essentially a MultiheadAttention layer, but with a few crucial changes such as relative positional encoding.

Next we implement each part in detail.

A Single Attention Head

We start by implementing a single attention head inside a MultiHeadAttention layer. Take the first layer as an example and assume its input is a batch of word embeddings of shape (seq=7, batch_size=3, embedding_dim=32). Note that Transformer-XL does not add positional embeddings to the input.

seq, batch_size, embedding_dim = 7, 3, 32
word_embs = torch.rand(seq, batch_size, embedding_dim)

In Transformer-XL we also need the cached outputs of the previous segment. For the first layer, the previous segment's output is defined to be its word embeddings.

Assume the previous sequence length is prev_seq=6:

prev_seq = 6
memory = torch.rand(prev_seq, batch_size, embedding_dim) # hidden states from the previous sequence

Each attention head takes keys, queries and values as input and performs the following steps:

  1. Apply a separate linear transformation to the keys, the queries and the values.

  2. Compute attention scores over the values.

  3. For each query, compute the attention-weighted sum of the values.

  4. Apply a residual connection and layer normalization.

We start with the linear transformations.

inner_dim = 17 # this will be the internal dimension
linear_k = nn.Linear(embedding_dim, inner_dim)
linear_v = nn.Linear(embedding_dim, inner_dim)
linear_q = nn.Linear(embedding_dim, inner_dim)

As the Transformer-XL equations show, the keys and values differ from those of a vanilla Transformer: the memory is concatenated with the input along the sequence-length dimension and the result is fed to the key/value projections. The queries are left unchanged, because each query represents a word we want to predict.

word_embs_w_memory = torch.cat([memory, word_embs], dim=0)
k_tfmd = linear_k(word_embs_w_memory)
v_tfmd = linear_v(word_embs_w_memory)
q_tfmd = linear_q(word_embs) # No memory for the queries

Next, we compute scaled dot-product attention just as in the vanilla Transformer: the attention score is the dot product between query and key vectors. To keep the scores from growing too large as the vector dimension increases, we divide the raw score by the square root of the embedding size.

\text{Attention}(Q,K,V) = \mathbf{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

We will write this in einsum notation; if you are not familiar with einsum, it is worth going through a tutorial first. In short, einsum spells out the shapes of the inputs and the output, one letter per dimension. Below, the inputs have shapes (i, b, d) and (j, b, d) and the output has shape (i, j, b), where identical letters denote identical sizes; products are taken along matching dimensions and any dimension that does not appear in the output is summed over.

content_attn = torch.einsum("ibd,jbd->ijb", q_tfmd, k_tfmd) / (embedding_dim ** 0.5) # scale

Note that we have not applied the softmax yet, because the relative positional attention still has to be added.

Relative Positional Encoding

A key point of Transformer-XL is relative positional encoding: instead of an absolute position embedding per token, it computes an embedding that represents the distance between any two tokens.

The attention score between query $q_i$ and key $k_j$ again decomposes into the four terms (a)-(d) shown earlier.

Here $E_{x}$ is the word embedding of $x$ and the $W$'s are projection matrices. Term (a) is the content-based attention, which we already computed above. Terms (b) and (d) are based on the relative positional embeddings and depend on the distance between $q_i$ and $k_j$. $u$ and $v$ are learned global bias terms, one for content and one for position.

Now let's implement terms (b) to (d). We start with the global content bias (term (c)), since it is the easiest to compute.

u = torch.rand(17).expand_as(q_tfmd)
content_attn = content_attn + torch.einsum("ibd,jbd->ijb", u, k_tfmd) / (embedding_dim ** 0.5)

Next we compute the relative positional embeddings that are needed. For these, Transformer-XL uses fixed sinusoidal embeddings.

pos_idxs = torch.arange(seq + prev_seq - 1, -1, -1.0, dtype=torch.float)
pos_idxs
#output: tensor([12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1., 0.])
import matplotlib.pyplot as plt

inv_freq = 1 / (10000 ** (torch.arange(0.0, embedding_dim, 2.0) / embedding_dim))
sinusoid_inp = torch.einsum("i,j->ij", pos_idxs, inv_freq)
plt.plot(sinusoid_inp[0, :].detach().numpy())
plt.plot(sinusoid_inp[6, :].detach().numpy())

pos_idxs = torch.arange(seq + prev_seq - 1, -1, -1.0, dtype=torch.float)
inv_freq = 1 / (10000 ** (torch.arange(0.0, embedding_dim, 2.0) / embedding_dim))
sinusoid_inp = torch.einsum("i,j->ij", pos_idxs, inv_freq)

relative_positional_embeddings = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)[:,None,:]
relative_positional_embeddings.shape
#output:torch.Size([13, 1, 32])

Putting this together into a module:

class PositionalEmbedding(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.d = d
        inv_freq = 1 / (10000 ** (torch.arange(0.0, d, 2.0) / d))
        # register_buffer tells pytorch that this tensor is part of the model
        # this means that it will be saved in the state_dict and moved to the GPU
        # along with the model
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, positions: torch.LongTensor,  # (seq, )
                ):
        # outer product
        sinusoid_inp = torch.einsum("i,j->ij", positions.float(), self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
        return pos_emb[:, None, :]
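A quick check of the module we just wrote (this usage snippet is mine, not the post's):

pos_embedding = PositionalEmbedding(32)
positions = torch.arange(12, -1, -1)       # 13 positions, counting down from 12 to 0
print(pos_embedding(positions).shape)      # torch.Size([13, 1, 32])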

We also need to project the positional embeddings to the key/value dimension.

linear_p = nn.Linear(embedding_dim, inner_dim)
pos_tfmd = linear_p(relative_positional_embeddings)

Now we add the positional term (with its bias) to the attention computation:

v = torch.rand(17) # positional bias
pos_attn = torch.einsum("ibd,jd->ijb", q_tfmd + v, pos_tfmd[:,0,:]) / (embedding_dim ** 0.5) # scale
pos_attn.shape
#output: torch.Size([7, 13, 3])

Because a relative positional embedding is needed for every key-query pair, the naive implementation above is O(n^2) in compute. Fortunately the authors propose a trick: compute the attention once and then shift it for the different query positions, which brings this down to O(n) (see the shift derivation above).

zero_pad = torch.zeros((seq, 1, batch_size), dtype=torch.float)
# this padding + shifting efficiently computes the attention for all relative positions
pos_attn = (torch.cat([zero_pad, pos_attn], dim=1)
            .view(seq + prev_seq + 1, seq, batch_size)[1:]
            .view_as(pos_attn))

The total attention score is then:

raw_attn = content_attn + pos_attn

For language modelling we must prevent the model from looking at the words it is supposed to predict. In the Transformer this is done with a mask: the attention scores of future positions are set to -inf before the softmax, so their weights become 0 and the words the model should not see are hidden.

mask = torch.triu(
    torch.ones((seq, seq + prev_seq)),
    diagonal=1 + prev_seq,
).byte()[..., None]
raw_attn = raw_attn.masked_fill(mask, -float('inf'))

Next, compute the weighted sum of the values:

attn = torch.softmax(raw_attn, dim=1)
attn_weighted_sum = torch.einsum("ijb,jbd->ibd", attn, v_tfmd)
attn_weighted_sum.shape
#output: torch.Size([7, 3, 17])

Finally, we project attn_weighted_sum back to the original input dimension and apply the residual connection and layer normalization:

linear_out = nn.Linear(inner_dim, embedding_dim)
layer_norm = nn.LayerNorm(embedding_dim)
output = layer_norm(word_embs + linear_out(attn_weighted_sum))
output.shape
#output: torch.Size([7, 3, 32])

The MultiHeadAttention Module

Combining the code above and adding dropout, we obtain a complete MultiHeadAttention module.

from typing import *

class MultiHeadAttention(nn.Module):
    def __init__(self, d_input: int, d_inner: int, n_heads: int=4,
                 dropout: float=0.1, dropouta: float=0.):
        super().__init__()
        self.d_input = d_input
        self.d_inner = d_inner
        self.n_heads = n_heads

        # this layer applies the linear transformation required
        # for the keys and values for all heads at once for efficiency
        self.linear_kv = nn.Linear(
            d_input,
            (d_inner * n_heads * 2),  # 2 is for keys and values
            bias=False,  # we don't apply bias, making this a simple matrix multiplication
        )
        # for queries (will not be concatenated with memorized states so separate)
        self.linear_q = nn.Linear(
            d_input, d_inner * n_heads,
            bias=False
        )
        # for positional embeddings
        self.linear_p = nn.Linear(
            d_input, d_inner * n_heads,
            bias=False
        )
        self.scale = 1 / (d_inner ** 0.5)  # for scaled dot product attention
        self.dropa = nn.Dropout(dropouta)
        # we will use this to project back to the input dimension
        self.lout = nn.Linear(self.d_inner * self.n_heads, self.d_input, bias=False)
        self.norm = nn.LayerNorm(self.d_input)
        self.dropo = nn.Dropout(dropout)

    def _rel_shift(self, x):
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        return (torch.cat([zero_pad, x], dim=1)
                .view(x.size(1) + 1, x.size(0), *x.size()[2:])[1:]
                .view_as(x))

    def forward(self, input_: torch.FloatTensor,  # (cur_seq, b, d_in)
                pos_embs: torch.FloatTensor,  # (cur_seq + prev_seq, d_in)
                memory: torch.FloatTensor,  # (prev_seq, b, d_in)
                u: torch.FloatTensor,  # (H, d)
                v: torch.FloatTensor,  # (H, d)
                mask: Optional[torch.FloatTensor]=None,
                ):
        """
        pos_embs: we pass the positional embeddings in separately
        because we need to handle relative positions
        input shape: (seq, bs, self.d_input)
        pos_embs shape: (seq + prev_seq, bs, self.d_input)
        output shape: (seq, bs, self.d_input)
        """
        cur_seq = input_.shape[0]  # sequence length of current segment
        prev_seq = memory.shape[0]  # sequence length of previous segment
        H, d = self.n_heads, self.d_inner
        # concatenate recurrent memory across the sequence dimension
        input_with_memory = torch.cat([memory, input_], dim=0)

        # we will use the following symbols to represent the shape of the tensors
        # cs: current sequence length, b: batch, H: number of heads
        # d: inner dimension, ps: previous sequence length
        # The key and value are now conditioned on the preceding context
        k_tfmd, v_tfmd = \
            torch.chunk(self.linear_kv(input_with_memory), 2, dim=-1)  # (cs + ps, b, H * d)
        q_tfmd = self.linear_q(input_)  # (cs, b, H * d)

        # apply scaled dot product attention
        # look at the following dimensions carefully, since this is the key operation
        # in the Transformer/Transformer XL architecture

        _, bs, _ = q_tfmd.shape
        assert bs == k_tfmd.shape[1]

        # content-based attention term ((a) + (c) in the paper)
        # this is the standard attention term in the original Transformer, except without positional embeddings
        # which are handled separately in the Transformer XL (see below)
        # here, i corresponds to the number of queries = number of current inputs/targets (seq-wise)
        # j corresponds to the number of key/values = number of vectors that we can use to compute the
        # vector for each query
        content_attn = torch.einsum("ibhd,jbhd->ijbh", (
            (q_tfmd.view(cur_seq, bs, H, d) +  # (a)
             u),  # (c): u represents the global (independent of the query)
                  # bias towards certain key/values = words
                  # Note: maybe this could be a per-attention head parameter?
            k_tfmd.view(cur_seq + prev_seq, bs, H, d)  # There is no positional information to be found here
        ))  # (cs, cs + ps, b, H)

        # position-based attention term ((b) + (d) in the paper)
        # this attention is solely based on the position of the key/values
        # (i.e. it does not take the content of the key/values into account)
        p_tfmd = self.linear_p(pos_embs)  # (cs + ps, b, H * d)
        position_attn = torch.einsum("ibhd,jhd->ijbh", (
            (q_tfmd.view(cur_seq, bs, H, d) +  # (b)
             v),  # (d): v represents the global (independent of the query)
                  # bias towards certain positions
            p_tfmd.view(cur_seq + prev_seq, H, d)  # Notice there is no content information
                                                   # regarding keys and values here!
        ))  # (cs, cs + ps, b, H)

        # Compute positional attention efficiently
        position_attn = self._rel_shift(position_attn)

        # the attention is the sum of content-based and position-based attention
        attn = content_attn + position_attn

        if mask is not None and mask.any().item():
            attn = attn.masked_fill(
                mask[..., None], -float('inf'))
        attn = torch.softmax(attn * self.scale,  # rescale to prevent values from exploding
                             dim=1)  # normalize across the value sequence dimension
        attn = self.dropa(attn)

        attn_weighted_values = (torch.einsum("ijbh,jbhd->ibhd",
                                             (attn,  # (cs, cs + ps, b, H)
                                              v_tfmd.view(cur_seq + prev_seq, bs, H, d),  # (cs + ps, b, H, d)
                                              ))  # (cs, b, H, d)
                                .contiguous()  # we need to change the memory layout to make `view` work
                                .view(cur_seq, bs, H * d))  # (cs, b, H * d)

        # Project back to input dimension and add residual connection
        output = input_ + self.dropo(self.lout(attn_weighted_values))
        output = self.norm(output)
        return output

We check the module with random inputs:

mha = MultiHeadAttention(32, 17, n_heads=4)
inpt = torch.rand(7, 3, 32)
pos = torch.rand(13, 32)
mem = torch.rand(6, 3, 32)
u, v = torch.rand(4, 17), torch.rand(4, 17)
x1 = mha(inpt, pos, mem, u, v)
x1.shape
#output: torch.Size([7, 3, 32])
x1[0]
#output:
tensor([[ 0.6264, 0.3405, -1.9065, 0.1543, 0.0389, -1.6033, 1.4415, 0.4983,
0.7548, 1.0990, -1.1783, -1.3847, -1.7358, 1.4651, 1.0633, 0.2168,
-0.3323, 1.1270, 0.1614, 1.0170, 1.0459, -0.7286, 0.5064, -1.4765,
0.0448, -1.2500, 0.3132, -0.8007, 0.4089, 0.7325, -1.2740, 0.6147],
[ 0.9541, 0.3682, -0.8096, 0.1357, -0.9159, -1.4382, -1.3385, 0.8269,
0.2721, -0.4982, 1.3105, -0.0236, -1.0547, -1.3076, 1.8884, -0.2891,
1.5231, 0.5507, -0.6423, 0.4412, 1.3656, 0.7858, -0.9425, -0.3198,
-0.3162, -0.0086, 1.5257, -1.3216, 1.4492, -0.1750, -0.1669, -1.8291],
[ 1.0132, 0.7205, -0.4221, 0.2952, -1.4117, -0.6182, -1.7520, -1.7426,
-0.4648, -0.2122, 2.0889, -1.3544, -0.1611, -1.0696, 1.3492, -1.0179,
1.2820, 0.8990, 0.7411, 0.8052, -0.5322, 0.6277, -0.2733, -1.0738,
-0.8435, 1.5357, 0.8260, -0.3422, -0.6204, 1.0091, 0.1011, 0.6182]],
grad_fn=<SelectBackward>)

Decoder

Besides the MultiHeadAttention layer, the decoder block also needs a position-wise feed-forward network (FFN).

class PositionwiseFF(nn.Module):
    def __init__(self, d_input, d_inner, dropout):
        super().__init__()
        self.d_input = d_input
        self.d_inner = d_inner
        self.dropout = dropout
        self.ff = nn.Sequential(
            nn.Linear(d_input, d_inner), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_input),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_input)

    def forward(self, input_: torch.FloatTensor,  # (cur_seq, bs, d_input)
                ) -> torch.FloatTensor:  # (cur_seq, bs, d_input)
        ff_out = self.ff(input_)
        output = self.layer_norm(input_ + ff_out)
        return output

The decoder block is then:

class DecoderBlock(nn.Module):
    def __init__(self, n_heads, d_input,
                 d_head_inner, d_ff_inner,
                 dropout, dropouta=0.):
        super().__init__()
        self.mha = MultiHeadAttention(d_input, d_head_inner, n_heads=n_heads,
                                      dropout=dropout, dropouta=dropouta)
        self.ff = PositionwiseFF(d_input, d_ff_inner, dropout)

    def forward(self, input_: torch.FloatTensor,  # (cur_seq, bs, d_input)
                pos_embs: torch.FloatTensor,  # (cur_seq + prev_seq, d_input),
                u: torch.FloatTensor,  # (H, d_input),
                v: torch.FloatTensor,  # (H, d_input),
                mask=None,
                mems=None,
                ):
        return self.ff(self.mha(input_, pos_embs, mems, u, v, mask=mask))

With these modules in place we can build the complete Transformer-XL model.

One common language-modelling trick we have not covered yet is tying the input embedding matrix E to the output projection matrix P. Remember that a language model predicts the next token in the sequence, so its output dimension is $\mathbb{R}^{|V|}$, where $|V|$ is the vocabulary size. If we constrain the penultimate layer's output to have the same dimension $d$ as the embeddings, the embedding matrix $E$ has shape $\mathbb{R}^{|V| \times d}$ and the output projection matrix $P$ has shape $\mathbb{R}^{d \times |V|}$.

P=ETP = E^T,可以提高性能,同时大大减少模型的总参数(从而减少内存使用量!)

import torch.nn.functional as F

class StandardWordEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim,
                 div_val=1, sample_softmax=False):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.scale = embedding_dim ** 0.5

    def forward(self, input_: torch.LongTensor):
        return self.embedding(input_) * self.scale

The complete Transformer-XL is then:

class TransformerXL(nn.Module):
    def __init__(self, num_embeddings, n_layers, n_heads,
                 d_model, d_head_inner, d_ff_inner,
                 dropout=0.1, dropouta=0.,
                 seq_len: int=0, mem_len: int=0):
        super().__init__()
        self.n_layers, self.n_heads, self.d_model, self.d_head_inner, self.d_ff_inner = \
            n_layers, n_heads, d_model, d_head_inner, d_ff_inner
        # Embedding layers
        self.word_embs = StandardWordEmbedding(num_embeddings, d_model)
        self.pos_embs = PositionalEmbedding(d_model)

        # Core transformer
        self.drop = nn.Dropout(dropout)
        self.layers = nn.ModuleList([DecoderBlock(n_heads, d_model, d_head_inner=d_head_inner,
                                                  d_ff_inner=d_ff_inner,
                                                  dropout=dropout, dropouta=dropouta)
                                     for _ in range(n_layers)])

        # tie weights
        self.output_projection = nn.Linear(d_model, num_embeddings)
        self.output_projection.weight = self.word_embs.embedding.weight
        self.loss_fn = nn.CrossEntropyLoss()

        self.seq_len, self.mem_len = seq_len, mem_len

        # u and v are global parameters: maybe changing these to per-head parameters
        # might help performance?
        self.u, self.v = (nn.Parameter(torch.Tensor(self.n_heads, self.d_head_inner)),
                          nn.Parameter(torch.Tensor(self.n_heads, self.d_head_inner)))

    def init_memory(self, device=torch.device("cpu")) -> torch.FloatTensor:
        return [torch.empty(0, dtype=torch.float).to(device) for _ in range(self.n_layers + 1)]

    def update_memory(self,
                      previous_memory: List[torch.FloatTensor],
                      hidden_states: List[torch.FloatTensor],
                      ):
        assert len(hidden_states) == len(previous_memory)
        mem_len, seq_len = previous_memory[0].size(0), hidden_states[0].size(0)

        # For the updated memory, we use the most recent `self.mem_len`
        # states, including the previous memory
        # In other words, if `seq_len` < `self.mem_len` some of the previous memory
        # will carry over to the next memory
        with torch.no_grad():
            new_memory = []
            end_idx = mem_len + seq_len
            beg_idx = max(0, end_idx - self.mem_len)
            for m, h in zip(previous_memory, hidden_states):
                cat = torch.cat([m, h], dim=0)  # (mem_len + seq_len, bs, d)
                new_memory.append(cat[beg_idx:end_idx].detach())  # (self.mem_len, bs, d)
        return new_memory

    def reset_length(self, seq_len, ext_len, mem_len):
        self.seq_len = seq_len
        self.mem_len = mem_len

    def forward(self, idxs: torch.LongTensor,  # (cs, bs)
                target: torch.LongTensor,  # (cs, bs)
                memory: Optional[List[torch.FloatTensor]]=None,
                ) -> Dict[str, torch.Tensor]:
        if memory is None:
            memory: List[torch.FloatTensor] = self.init_memory(idxs.device)
        assert len(memory) == len(self.layers) + 1
        cur_seq, bs = idxs.size()
        prev_seq = memory[0].size(0)

        # Construct attention mask
        dec_attn_mask = torch.triu(
            torch.ones((cur_seq, cur_seq + prev_seq)),
            diagonal=1 + prev_seq,
        ).byte()[..., None].to(idxs.device)

        word_embs = self.drop(self.word_embs(idxs))
        pos_idxs = torch.arange(cur_seq + prev_seq - 1, -1, -1.0, dtype=torch.float).to(word_embs.device)
        pos_embs = self.drop(self.pos_embs(pos_idxs))

        # Main part of forward pass
        hidden_states = [word_embs]
        layer_out = word_embs
        for mem, layer in zip(memory, self.layers):
            layer_out = layer(layer_out, pos_embs, self.u, self.v,
                              mask=dec_attn_mask, mems=mem)
            hidden_states.append(layer_out)

        logits = self.output_projection(self.drop(layer_out))
        loss = self.loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))

        # Update memory
        # Ensure the memory is treated as a constant
        # and we do not back propagate through them
        new_memory = self.update_memory(memory, hidden_states)
        return {"loss": loss, "logits": logits, "memory": new_memory}

Again, a quick check with random inputs:

transformer = TransformerXL(1000, 4, 3, 32, 17, 71, mem_len=5)
idxs = torch.randint(1000, (5, 9))
tgts = torch.randint(1000, (5, 9))
transformer(idxs, tgts)

Data Loading

Data loading for Transformer-XL is similar to that of RNN-based language models, but quite different from standard data loading.

Suppose we feed the model sequences of 4 words at a time. Remember that Transformer-XL is stateful: the computation of each mini-batch carries over to the next. With a batch size of 1 this is easy to handle: we simply chop the input into chunks and feed them to the model in order.

What happens if the batch size is 2? We cannot split a sentence across batch positions, otherwise we would break the dependencies between segments.

The correct way to process a corpus with batch size 2 is to give each batch position its own contiguous stream of text.

Based on this, we first split the corpus into batch_size streams of equal length and then feed each stream to the model chunk by chunk. Let's walk through an example. Suppose the batch size is 4 and our whole corpus is:

pytorch is an amazing deep learning framework that makes nlp really easy

We want each batch to contain, at every position, the continuation of the same stream as in the previous batch. In other words, if we fed the model one word at a time, we would iterate over this sentence like so:

Batch 1: pytorch amazing framework nlp
Batch 2: is deep that really
Batch 3: an learning makes easy

Note that this means you reconstruct the original sentence by reading top-to-bottom then left-to-right, not left-to-right then top-to-bottom. In practice, the word sequences in each batch have length bptt (backpropagation through time), since that is the maximum length over which gradients propagate along the sequence. For example, with a bptt length of 2, a batch of shape (batch_size, bptt) looks like:

Batch 1: pytorch amazing framework nlp
is deep that really
Batch 2: an learning makes easy

We can implement this in a data loader as follows:

from torch.utils import data
import math

class LMDataLoader(data.DataLoader):
    def __init__(self, data: torch.LongTensor, batch_size: int, bptt: int,
                 device=torch.device("cpu")):
        self.batch_size = batch_size
        self.bptt = bptt
        self.n_steps = data.size(0) // batch_size

        # we reshape the data here so that we can index
        # efficiently into it while training
        self.data = (data[:self.n_steps * batch_size]  # trim off any elements that don't fit cleanly
                     .view(batch_size, self.n_steps)
                     .transpose(0, 1)
                     .contiguous().to(device)  # put on device as contiguous tensor
                     )

    def __iter__(self):
        for batch_start_idx in range(0, self.data.size(0) - 1, self.bptt):
            batch_end_idx = min(batch_start_idx + self.bptt, self.data.size(0) - 1)
            # TODO: What is `self.ext_len` in the original code?
            batch_data = self.data[batch_start_idx:batch_end_idx]
            target = self.data[batch_start_idx + 1:batch_end_idx + 1]
            # we generate the sequence length as well for loss calculation later
            yield batch_data, target, batch_end_idx - batch_start_idx

    def __len__(self):
        return math.ceil(self.data.size(0) / self.bptt)

A quick test of the loader:


test_corpus = torch.arange(1000)
BS = 16
BPTT = 10
test_corpus[:BPTT]
#output: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
loader = LMDataLoader(test_corpus, BS, BPTT)
b1, *_ = next(iter(loader))
b1.shape
#output: torch.Size([10, 16])

Full Code

To download the dataset, create a script named download_data.sh with the following contents:

#!/bin/bash
echo "- Downloading Penn Treebank (PTB)"
wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
tar -xzf simple-examples.tgz
mkdir -p penn
cd penn
mv ../simple-examples/data/ptb.train.txt train.txt
mv ../simple-examples/data/ptb.test.txt test.txt
mv ../simple-examples/data/ptb.valid.txt valid.txt
cd ..
echo "- Downloading Penn Treebank (Character)"
mkdir -p pennchar
cd pennchar
mv ../simple-examples/data/ptb.char.train.txt train.txt
mv ../simple-examples/data/ptb.char.test.txt test.txt
mv ../simple-examples/data/ptb.char.valid.txt valid.txt
cd ..
rm -rf simple-examples/

Run sh download_data.sh to download the dataset automatically.

Then create a file named vocabulary.py with the following contents:

from collections import Counter, OrderedDict
import os
import torch
class Vocab(object):
"""Borrowed from the Transformer XL repository"""
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
delimiter=None, vocab_file=None):
self.counter = Counter()
self.special = special
self.min_freq = min_freq
self.max_size = max_size
self.lower_case = lower_case
self.delimiter = delimiter
self.vocab_file = vocab_file

def tokenize(self, line, add_eos=False, add_double_eos=False):
line = line.strip()
# convert to lower case
if self.lower_case:
line = line.lower()

# empty delimiter '' will evaluate False
if self.delimiter == '':
symbols = line
else:
symbols = line.split(self.delimiter)

if add_double_eos: # lm1b
return ['<S>'] + symbols + ['<S>']
elif add_eos:
return symbols + ['<eos>']
else:
return symbols
def count_file(self, path, verbose=False, add_eos=False):
if verbose: print('counting file {} ...'.format(path))
assert os.path.exists(path)
sents = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
symbols = self.tokenize(line, add_eos=add_eos)
self.counter.update(symbols)
sents.append(symbols)

return sents
def count_sents(self, sents, verbose=False):
"""
sents : a list of sentences, each a list of tokenized symbols
"""
if verbose: print('counting {} sents ...'.format(len(sents)))
for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
self.counter.update(symbols)
def _build_from_file(self, vocab_file):
self.idx2sym = []
self.sym2idx = OrderedDict()

with open(vocab_file, 'r', encoding='utf-8') as f:
for line in f:
symb = line.strip().split()[0]
self.add_symbol(symb)
self.unk_idx = self.sym2idx['<UNK>']
def build_vocab(self):
if self.vocab_file:
print('building vocab from {}'.format(self.vocab_file))
self._build_from_file(self.vocab_file)
print('final vocab size {}'.format(len(self)))
else:
print('building vocab with min_freq={}, max_size={}'.format(
self.min_freq, self.max_size))
self.idx2sym = []
self.sym2idx = OrderedDict()
for sym in self.special:
self.add_special(sym)
for sym, cnt in self.counter.most_common(self.max_size):
if cnt < self.min_freq: break
self.add_symbol(sym)
print('final vocab size {} from {} unique tokens'.format(
len(self), len(self.counter)))
def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
add_double_eos=False):
if verbose: print('encoding file {} ...'.format(path))
assert os.path.exists(path)
encoded = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
symbols = self.tokenize(line, add_eos=add_eos,
add_double_eos=add_double_eos)
encoded.append(self.convert_to_tensor(symbols))
if ordered:
encoded = torch.cat(encoded)
return encoded
def encode_sents(self, sents, ordered=False, verbose=False):
if verbose: print('encoding {} sents ...'.format(len(sents)))
encoded = []
for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
encoded.append(self.convert_to_tensor(symbols))
if ordered:
encoded = torch.cat(encoded)
return encoded
def add_special(self, sym):
if sym not in self.sym2idx:
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1
setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
def add_symbol(self, sym):
if sym not in self.sym2idx:
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1
def get_sym(self, idx):
assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
return self.idx2sym[idx]
def get_idx(self, sym):
if sym in self.sym2idx:
return self.sym2idx[sym]
else:
# print('encounter unk {}'.format(sym))
assert '<eos>' not in sym
assert hasattr(self, 'unk_idx')
return self.sym2idx.get(sym, self.unk_idx)
def get_symbols(self, indices):
return [self.get_sym(idx) for idx in indices]
def get_indices(self, symbols):
return [self.get_idx(sym) for sym in symbols]
def convert_to_tensor(self, symbols):
return torch.LongTensor(self.get_indices(symbols))
def convert_to_sent(self, indices, exclude=None):
if exclude is None:
return ' '.join([self.get_sym(idx) for idx in indices])
else:
return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

def __len__(self):
return len(self.idx2sym)

This script handles the corpus preprocessing. Next, create a file named transformer_xl.py with the following contents:

import sys
import torch
import torch.nn as nn
from typing import *
from torch.utils import data
from pathlib import Path
from vocabulary import Vocab
import torch.optim as optim
import math
import time
import os
from tqdm import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt

class PositionalEmbedding(nn.Module):
def __init__(self, d):
super().__init__()
self.d = d
inv_freq = 1 / (10000 ** (torch.arange(0.0, d, 2.0) / d))
# register buffer tells pytorch that this tensor is part of the modle
# this means that it will be saved in the state_dict and moved to the GPU
# along with the model
self.register_buffer("inv_freq", inv_freq)

def forward(self, positions: torch.LongTensor):
# outer product
sinusoid_inp = torch.einsum("i,j->ij", positions.float(), self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
return pos_emb[:, None, :]

class MultiHeadAttention(nn.Module):
def __init__(self, d_input: int, d_inner: int, n_heads: int = 4,
dropout: float = 0.1, dropouta: float = 0.):
super().__init__()
self.d_input = d_input
self.d_inner = d_inner
self.n_heads = n_heads
# this layer applies the linear transformation required
# for the keys and values for all heads at once for efficiency
self.linear_kv = nn.Linear(
d_input,
(d_inner * n_heads * 2), # 2 is for keys and values
bias=False, # we don't apply bias, making this a simple matrix multiplication
)
# for queries (will not be concatenated with memorized states so separate)
self.linear_q = nn.Linear(
d_input, d_inner * n_heads,
bias=False
)
# for positional embeddings
self.linear_p = nn.Linear(
d_input, d_inner * n_heads,
bias=False
)
self.scale = 1 / (d_inner ** 0.5) # for scaled dot product attention
self.dropa = nn.Dropout(dropouta)
# we will use this to project back to the input dimension
self.lout = nn.Linear(self.d_inner * self.n_heads, self.d_input, bias=False)
self.norm = nn.LayerNorm(self.d_input)
self.dropo = nn.Dropout(dropout)

def _rel_shift(self, x):
zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
device=x.device, dtype=x.dtype)
return (torch.cat([zero_pad, x], dim=1)
.view(x.size(1) + 1, x.size(0), *x.size()[2:])[1:]
.view_as(x))

def forward(self, input_: torch.FloatTensor, # (cur_seq, b, d_in)
pos_embs: torch.FloatTensor, # (cur_seq + prev_seq, d_in)
memory: torch.FloatTensor, # (prev_seq, b, d_in)
u: torch.FloatTensor, # (H, d)
v: torch.FloatTensor, # (H, d)
mask: Optional[torch.FloatTensor] = None,
):
"""
pos_embs: we pass the positional embeddings in separately
because we need to handle relative positions
input shape: (seq, bs, self.d_input)
pos_embs shape: (seq + prev_seq, bs, self.d_input)
output shape: (seq, bs, self.d_input)
"""
cur_seq = input_.shape[0] # sequence length of current segment
prev_seq = memory.shape[0] # sequence length of previous segment
H, d = self.n_heads, self.d_inner
input_with_memory = torch.cat([memory, input_], dim=0) # concatenate recurrent memory
# across sequence dimension

# we will use the following symbols to represent the shape of the tensors
# cs: current sequence length, b: batch, H: number of heads
# d: inner dimension, ps: previous sequence length
# The key and value are now conditioned on the preceding context
k_tfmd, v_tfmd = \
torch.chunk(self.linear_kv(input_with_memory), 2, dim=-1) # (cs + ps, b, H * d)
q_tfmd = self.linear_q(input_) # (cs, b, H * d)

# apply scaled dot product attention
# look at the following dimensions carefully, since this is the key operation
# in the Transformer/Transformer XL architecture

_, bs, _ = q_tfmd.shape
assert bs == k_tfmd.shape[1]
# content-based attention term ((a) + (c) in the paper)
# this is the standard attention term in the original Transformer, except without positional embeddings
# which are handled separately in the Transformer XL (see below)
# here, i corresponds to the number of queries = number of current inputs/targets (seq-wise)
# j corresponds to the number of key/values = number of vectors that we can use to compute the
# vector for each query
content_attn = torch.einsum("ibhd,jbhd->ijbh", (
(q_tfmd.view(cur_seq, bs, H, d) + # (a)
u), # (c): u represents the global (independent of the query)
# bias towards certain key/values = words
# Note: maybe this could be a per-attention head parameter?
k_tfmd.view(cur_seq + prev_seq, bs, H, d) # There is no positional information to be found here
)) # (cs, cs + ps, b, H)

# position-based attention term ((b) + (d) in the paper)
# this attention is solely based on the position of the key/values
# (i.e. it does not take the content of the key/values into account)
p_tfmd = self.linear_p(pos_embs) # (cs + ps, b, H * d)
position_attn = torch.einsum("ibhd,jhd->ijbh", (
(q_tfmd.view(cur_seq, bs, H, d) + # (b)
v), # (d): v represents the global (independent of the query)
# bias towards certain positions
p_tfmd.view(cur_seq + prev_seq, H, d) # Notice there is not content information
# regarding keys and values here!
)) # (cs, cs + ps, b, H)

# Compute positional attention efficiently
position_attn = self._rel_shift(position_attn)

# the attention is the sum of content-based and position-based attention
attn = content_attn + position_attn

if mask is not None and mask.any().item():
attn = attn.masked_fill(
mask[..., None], -float('inf'))
attn = torch.softmax(attn * self.scale, # rescale to prevent values from exploding
dim=1) # normalize across the value sequence dimension
attn = self.dropa(attn)

attn_weighted_values = (torch.einsum("ijbh,jbhd->ibhd",
(attn, # (cs, cs + ps, b, H)
v_tfmd.view(cur_seq + prev_seq, bs, H, d), # (cs + ps, b, H, d)
)) # (cs, b, H, d)
.contiguous() # we need to change the memory layout to make `view` work
.view(cur_seq, bs, H * d)) # (cs, b, H * d)

# Project back to input dimension and add residual connection
output = input_ + self.dropo(self.lout(attn_weighted_values))
output = self.norm(output)
return output

class PositionwiseFF(nn.Module):
def __init__(self, d_input, d_inner, dropout):
super().__init__()

self.d_input = d_input
self.d_inner = d_inner
self.dropout = dropout
self.ff = nn.Sequential(
nn.Linear(d_input, d_inner), nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(d_inner, d_input),
nn.Dropout(dropout),
)
self.layer_norm = nn.LayerNorm(d_input)

def forward(self, input_: torch.FloatTensor, # (cur_seq, bs, d_input)
) -> torch.FloatTensor: # (cur_seq, bs, d_input)
ff_out = self.ff(input_)
output = self.layer_norm(input_ + ff_out)
return output

class DecoderBlock(nn.Module):
def __init__(self, n_heads, d_input,
d_head_inner, d_ff_inner,
dropout, dropouta=0.):
super().__init__()
self.mha = MultiHeadAttention(d_input, d_head_inner, n_heads=n_heads,
dropout=dropout, dropouta=dropouta)
self.ff = PositionwiseFF(d_input, d_ff_inner, dropout)

def forward(self, input_: torch.FloatTensor, # (cur_seq, bs, d_input)
pos_embs: torch.FloatTensor, # (cur_seq + prev_seq, d_input),
u: torch.FloatTensor, # (H, d_input),
v: torch.FloatTensor, # (H, d_input),
mask=None,
mems=None,
):
return self.ff(self.mha(input_, pos_embs, mems, u, v, mask=mask))

class StandardWordEmbedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
self.scale = embedding_dim ** 0.5

def forward(self, input_: torch.LongTensor):
return self.embedding(input_) * self.scale

class TransformerXL(nn.Module):
def __init__(self, num_embeddings, n_layers, n_heads,
d_model, d_head_inner, d_ff_inner,
dropout=0.1, dropouta=0.,
seq_len: int = 0, mem_len: int = 0):
super().__init__()
self.n_layers, self.n_heads, self.d_model, self.d_head_inner, self.d_ff_inner = \
n_layers, n_heads, d_model, d_head_inner, d_ff_inner
# Embedding layers
self.word_embs = StandardWordEmbedding(num_embeddings, d_model)
self.pos_embs = PositionalEmbedding(d_model)
# Core transformer
self.drop = nn.Dropout(dropout)
self.layers = nn.ModuleList([DecoderBlock(n_heads, d_model, d_head_inner=d_head_inner,
d_ff_inner=d_ff_inner,
dropout=dropout, dropouta=dropouta)
for _ in range(n_layers)])

# tie weights
self.output_projection = nn.Linear(d_model, num_embeddings)
self.output_projection.weight = self.word_embs.embedding.weight
self.loss_fn = nn.CrossEntropyLoss()

self.seq_len, self.mem_len = seq_len, mem_len

# u and v are global parameters: maybe changing these to per-head parameters
# might help performance?
self.u, self.v = (nn.Parameter(torch.Tensor(self.n_heads, self.d_head_inner)),
nn.Parameter(torch.Tensor(self.n_heads, self.d_head_inner)))

def init_memory(self, device=torch.device("cpu")) -> torch.FloatTensor:
return [torch.empty(0, dtype=torch.float).to(device) for _ in range(self.n_layers + 1)]

def update_memory(self,
previous_memory: List[torch.FloatTensor],
hidden_states: List[torch.FloatTensor],
):
assert len(hidden_states) == len(previous_memory)
mem_len, seq_len = previous_memory[0].size(0), hidden_states[0].size(0)

# For the updated memory, we use the most recent `self.mem_len`
# states, including the previous memory
# In other words, if `seq_len` < `self.mem_len` some of the previous memory
# will carry over to the next memory
with torch.no_grad():
new_memory = []
end_idx = mem_len + seq_len
beg_idx = max(0, end_idx - self.mem_len)
for m, h in zip(previous_memory, hidden_states):
cat = torch.cat([m, h], dim=0) # (mem_len + seq_len, bs, d)
new_memory.append(cat[beg_idx:end_idx].detach()) # (self.mem_len, bs, d)
return new_memory

def reset_length(self, seq_len, ext_len, mem_len):
self.seq_len = seq_len
self.mem_len = mem_len

def forward(self, idxs: torch.LongTensor, # (cs, bs)
target: torch.LongTensor, # (cs, bs)
memory: Optional[List[torch.FloatTensor]] = None,
) -> Dict[str, torch.Tensor]:
if memory is None:
memory: List[torch.FloatTensor] = self.init_memory(idxs.device)
assert len(memory) == len(self.layers) + 1
cur_seq, bs = idxs.size()
prev_seq = memory[0].size(0)

# Construct attention mask
dec_attn_mask = torch.triu(
torch.ones((cur_seq, cur_seq + prev_seq)),
diagonal=1 + prev_seq,
).byte()[..., None].to(idxs.device)

word_embs = self.drop(self.word_embs(idxs))
pos_idxs = torch.arange(cur_seq + prev_seq - 1, -1, -1.0, dtype=torch.float).to(word_embs.device)
pos_embs = self.drop(self.pos_embs(pos_idxs))

# Main part of forward pass
hidden_states = [word_embs]
layer_out = word_embs
for mem, layer in zip(memory, self.layers):
layer_out = layer(layer_out, pos_embs, self.u, self.v,
mask=dec_attn_mask, mems=mem)
hidden_states.append(layer_out)

logits = self.output_projection(self.drop(layer_out))
loss = self.loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))

# Update memory
# Ensure the memory is treated as a constant
# and we do not back propagate through them
new_memory = self.update_memory(memory, hidden_states)
return {"loss": loss, "logits": logits, "memory": new_memory}

class Config(dict):
def __init__(self, **kwargs):
super().__init__(**kwargs)
for k, v in kwargs.items():
setattr(self, k, v)

def set(self, key, val):
self[key] = val
setattr(self, key, val)

def update(self, dct):
for k, v in dct.items():
self.set(k, v)

class LMDataLoader(data.DataLoader):
def __init__(self, data: torch.LongTensor, batch_size: int, bptt: int,
device=torch.device("cpu")):
self.batch_size = batch_size
self.bptt = bptt
self.n_steps = data.size(0) // batch_size
# we reshape the data here so that we can index
# efficiently into it while training
self.data = (data[:self.n_steps * batch_size] # trim off any elements that don't fit cleanly
.view(batch_size, self.n_steps) #
.transpose(0, 1) #
.contiguous().to(device) # put on device as contiguous tensor
)
def __iter__(self):
for batch_start_idx in range(0, self.data.size(0) - 1, self.bptt):
batch_end_idx = min(batch_start_idx + self.bptt, self.data.size(0) - 1)
# TODO: What is `self.ext_len` in the original code?
batch_data = self.data[batch_start_idx:batch_end_idx]
target = self.data[batch_start_idx +1:batch_end_idx +1]
# we generate the sequence length as well for loss calculation later
yield batch_data, target, batch_end_idx - batch_start_idx

def __len__(self):
return math.ceil(self.data.size(0) / self.bptt)

def init_weight(weight):
nn.init.normal_(weight, 0.0, 0.02)

def init_bias(bias):
nn.init.constant_(bias, 0.0)

# Borrowed from the transformer XL repo
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
init_weight(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
init_weight(m.weight)
elif classname.find('LayerNorm') != -1:
if hasattr(m, 'weight'):
nn.init.normal_(m.weight, 1.0, 0.02)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
else:
if hasattr(m, 'u'):
init_weight(m.u)
if hasattr(m, 'v'):
init_weight(m.v)

def train_epoch(
epoch: int,
model: nn.Module, train_loader: data.DataLoader,
val_loader: data.DataLoader,
optimizer: optim.Optimizer,
scheduler,
train_step_start=0.,
):
# Turn on training mode which enables dropout.
model.train()
mems = None
train_step = train_step_start
train_loss = 0
log_start_time = time.time()
best_val_loss = float("inf")

pbar = tqdm(train_loader, total=min(config.max_step - train_step_start, len(train_loader)))
for batch_idx, (data, target, seq_len) in enumerate(pbar):
model.zero_grad()
out_dict = model(data, target, memory=mems)
loss, mems = out_dict["loss"], out_dict["memory"]

loss.backward()
train_loss += loss.item()
loss_change.append(loss.item())
torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
optimizer.step()

# step-wise learning rate annealing
train_step += 1
# linear warmup stage
if train_step < config.warmup_step:
curr_lr = config.lr * train_step / config.warmup_step
optimizer.param_groups[0]['lr'] = curr_lr
else:
scheduler.step(train_step)

if train_step % config.log_interval == 0:
cur_loss = train_loss / config.log_interval
elapsed = time.time() - log_start_time
log_str = '| epoch {:3d} step {:>8d} | lr {:.3g} ' \
'| loss {:5.2f}'.format(
epoch, train_step, optimizer.param_groups[0]['lr'], cur_loss)
log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
pbar.set_description(log_str)
train_loss = 0
log_start_time = time.time()

if train_step % config.eval_interval == 0:
val_loss = evaluate(model, val_loader)
val_loss_change.append(val_loss)
eval_start_time = time.time()

if train_step == config.max_step:
return train_step
return train_step

def train(model, train_loader, valid_loader):
optimizer = optim.Adam(model.parameters(), lr=config.lr)
total_steps = min(config.max_step, len(train_loader) * config.epochs)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
total_steps, eta_min=config.min_lr)
train_step_start = 0
for epoch in range(config.epochs):
if train_step_start >= config.max_step:
break
train_step_start = train_epoch(
epoch,
model,
train_iter,
valid_iter,
optimizer,
scheduler,
train_step_start,
)

def evaluate(model: nn.Module, val_loader: data.DataLoader):
# Turn on evaluation mode which disables dropout.
model.eval()
model.reset_length(config.eval_bptt,
0, config.eval_mem_len+config.train_bptt-config.eval_bptt)

# Evaluation
total_len, total_loss = 0, 0.
with torch.no_grad():
mems = None
for i, (data, target, seq_len) in enumerate(val_loader):
out_dict = model(data, target, memory=mems)
loss, mems = out_dict["loss"], out_dict["memory"]
total_loss += seq_len * loss.float().item()
total_len += seq_len

# Switch back to the training mode
model.reset_length(config.train_bptt, 0, config.mem_len)
model.train()
return total_loss / total_len

def evaluate_final(model, val_loader):
model.eval()
total_len, total_loss = 0, 0.
start_time = time.time()

model.reset_length(config.eval_bptt, 0, config.eval_mem_len + config.train_bptt - config.eval_bptt)

with torch.no_grad():
mems = None
for i, (data, target, seq_len) in enumerate(val_loader):
out_dict = model(data, target, memory=mems)
loss, mems = out_dict["loss"], out_dict["memory"]
total_loss += seq_len * loss.item()
total_len += seq_len
total_time = time.time() - start_time

model.reset_length(config.train_bptt, 0, config.mem_len)
loss_val = total_loss / total_len
return {"loss": loss_val, "ppl": math.exp(loss_val)}

if __name__ == "__main__":
TESTING = True
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
# We will use prime numbers to ensure our implementation
# is actually correct
config = Config(
seed=101,
debug=False,
warmup_step=0,
# Check default params
min_lr=0.,
dropouta=0.,
clip=0.25,
log_interval=200,
eval_interval=100,
)

if TESTING:
config.update(dict(
debug=True,
lr=0.00025,
bs=8,
epochs=2,
max_step=10000, # shorten for testing
n_layers=4,
n_heads=3,
d_model=32,
d_head_inner=17,
d_ff_inner=71,
dropout=0.1,
train_bptt=33,
eval_bptt=41,
mem_len=41,
eval_mem_len=63,
))
else:
config.update(dict(
lr=0.0025,
bs=64,
epochs=4,
max_step=400000,
n_layers=12,
n_heads=8,
d_model=512,
d_head_inner=64,
d_ff_inner=2048,
dropout=0.1,
train_bptt=512,
eval_bptt=128,
mem_len=512,
eval_mem_len=2100,
))

DATASET = "penn"
DATA_DIR = Path(".") / DATASET

vocab = Vocab(special=["<eos>"], lower_case=True)
vocab.count_file(DATA_DIR / "train.txt")
vocab.count_file(DATA_DIR / "valid.txt")
vocab.count_file(DATA_DIR / "test.txt")
vocab.build_vocab()
train_dataset = vocab.encode_file(DATA_DIR / "train.txt", ordered=True, add_eos=True)
valid_dataset = vocab.encode_file(DATA_DIR / "valid.txt", ordered=True, add_eos=True)
test_dataset = vocab.encode_file(DATA_DIR / "test.txt", ordered=True, add_eos=True)

train_iter = LMDataLoader(train_dataset, config.bs, config.train_bptt, device=device)
valid_iter = LMDataLoader(valid_dataset, config.bs, config.eval_bptt, device=device)
test_iter = LMDataLoader(test_dataset, config.bs, config.eval_bptt, device=device)

loss_change = []
val_loss_change = []

transformer_xl = TransformerXL(
num_embeddings=len(vocab), n_layers=config.n_layers,
n_heads=config.n_heads, d_model=config.d_model,
d_head_inner=config.d_head_inner,
d_ff_inner=config.d_ff_inner,
dropout=config.dropout,
dropouta=config.dropouta,
seq_len=config.train_bptt,
mem_len=config.mem_len,
)
if torch.cuda.is_available():
transformer_xl.cuda()
transformer_xl.apply(weights_init)
train(
transformer_xl,
train_iter,
valid_iter,
)
eval_result = evaluate_final(transformer_xl, valid_iter)
print(eval_result)
plt.plot(loss_change)
plt.savefig('train_loss.png', bbox_inches='tight')
plt.close()
plt.plot(val_loss_change)
plt.savefig('val_loss.png', bbox_inches='tight')
plt.close()

Run python transformer_xl.py; you should see output similar to:

{'loss': 6.048809673745043, 'ppl': 423.6084975087664}

Training loss curve (saved to train_loss.png).

Validation loss curve (saved to val_loss.png).

References

  1. Dai et al. (2019), Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context