einsum

einsum,这是我开发深度学习模型时最喜欢的函数。这篇文章将从PyTorch角度出发,如何使用torch.einsum进行相关的矩阵计算操作。

einsum符号

如果你很难记住PyTorch\TensorFlow中所有中用于计算点积、外积、转置和矩阵向量或矩阵矩阵乘法的不同函数的名称或者使用方式。那么你就急需要einsum,einsum表示法是以一种优雅的方式来表达所有这些矩阵、向量或者张量的复杂计算操作。基本上使用的是一种特定于域的语言。这样做的好处不仅仅是不需要记忆或定期查找特定的库或者函数。一旦您理解并使用einsum,您将能够更快地编写更简洁、更高效的代码。目前主流深度学习框架中都包含了einsum的实现,比如PyTorch中的torch.einsum和TensorFlow中的tf.einsum。

假设我们要计算两个矩阵${\color{red}\mathbf{A}} \in \mathbb{R}^{I \times K}$和${\color{blue}\mathbf{B}}\in\mathbb{R}^{K \times J}$的相乘并对列求和,即得到一个向量${\color{green}\mathbf{c}}\in\mathbb{R}^{J}$为的每一列的和。用einsum符号,我们可以把它写成:

上述公式表明了$c$中的所有单个元素$c_i$是通过将$A$中列向量$A_i$和$B$中行向量$B_j$中的值相乘并求和计算出来的。注意,对于einsum符号,当我们隐式地对重复指标求和(本例中为$k$)和输出中未提及的指标求和(本例中为$i$)时,求和符号可以去掉。到目前为止还不错,但我们也可以用einsum来表示更多的基本运算。例如,计算两个向量${\color{red}\mathbf{a}},{\color{blue}\mathbf{b}}\in\mathbb{R}^I$的点积可以写成:

在深度学习实现过程中经常遇到的一个问题是对高阶张量中的向量进行变换。例如,假设有一个张量包含($N$,$T$,$S$),即长度为$T$,维度为$K$,样本个数为$N$。现在我们想对改张量的维度进行转换$Q$。令${\color{red}\mathcal{T}}\in\mathbb{R}^{N \times T \times K}$,${\color{blue}\mathbf{W}}\in\mathbb{R}^{K \times Q}$。则该计算用einsum符号可以表示为:

假设张量为4维,即${\color{red}\mathcal{T}}\in\mathbb{R}^{N \times T \times K \times M}$,除了上述将第3维维度转换为$Q$,还需要对第2维进行求和,并将第一维和最后一维进行转置,那么上述全部操作使用einsum符号可以表示为:

注意,通过交换$n$和$m$来实现张量收缩结果的转置 (${\color{green}C_{mqn}}$而不是${\color{green}C_{nqm}}$)。

实现

einsum在numpy、PyTorch和TensorFlow三个模块中都有对应的实现,在numpy中为np.einsum,在PyTorch中为torch.einsum,在TensorFlow中为tf.einsum。三个einsum函数使用方式都是相同的,即einsum(equation,operands),其中equation表示einsum计算的字符串表示,operands是张量序列(计算主体)。上面的例子都可以用方程串来表示。例如,我们的第一个例子${\color{green}c_j} = \sum_i\sum_k {\color{red}A_{ik}}{\color{blue}B_{kj}}$可以写成方程字符串“${\color{red}ik},{\color{blue}kj}$ -> $j$”。

不仅在numpy中,而且在PyTorch和TensorFlow中,einsum的伟大之处在于,它可以用于任意的神经网络结构的计算图中,我们可以通过它进行反向传播。对einsum的典型调用具有以下形式

其中$\square$是一个标识张量维度的字符的占位符。从这个方程串中我们可以推断出${\color{red}\text{arg1}}$和${\color{blue}\text{arg3}}$是2维矩阵,${\color{purple}\text{arg2}}$是一个3阶张量,这个einsum运算的结果是一个矩阵。注意,einsum使用的输入数量是可变的。在上面的例子中,einsum指定了对三个参数的操作,但是它也可以用于涉及一个、两个或三个以上参数的操作。Einsum最好通过学习示例来学习,因此让我们通过PyTorch中的一些Einsum示例,它们对应于许多深度学习模型中使用的库函数。

矩阵转置

1
2
3
4
5
6
7
8
9
10
11
import torch
a = torch.arange(6).reshape(2, 3)
print(a)
torch.einsum('ij->ji', [a])

#output:
tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[0, 3],
[1, 4],
[2, 5]])

求和

1
2
3
4
5
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])

#output:
tensor(15.)

列求和

1
2
3
4
5
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])

#output:
tensor([ 3., 5., 7.])

行求和

1
2
3
4
5
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->i', [a])

#output:
tensor([ 3., 12.])

矩阵-向量乘法

1
2
3
4
5
6
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])

#output:
tensor([ 5., 14.])

矩阵-矩阵乘法

1
2
3
4
5
6
7
a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
torch.einsum('ik,kj->ij', [a, b])

#output:
tensor([[ 25., 28., 31., 34., 37.],
[ 70., 82., 94., 106., 118.]])

点乘

向量

1
2
3
4
5
a = torch.arange(3)
b = torch.arange(3,6) # -- a vector of length 3 containing [3, 4, 5]
torch.einsum('i,i->', [a, b])
#output:
tensor(14.)

矩阵

1
2
3
4
5
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
#output:
tensor(145.)

矩阵乘法之Hadamard

对应的元素想乘

1
2
3
4
5
6
7
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])

#output:
tensor([[ 0., 7., 16.],
[ 27., 40., 55.]])

外积(OUTER PRODUCT)

1
2
3
4
5
6
7
8
a = torch.arange(3)
b = torch.arange(3,7) # -- a vector of length 4 containing [3, 4, 5, 6]
torch.einsum('i,j->ij', [a, b])

#output:
tensor([[ 0., 0., 0., 0.],
[ 3., 4., 5., 6.],
[ 6., 8., 10., 12.]])

batch矩阵相乘

1
2
3
4
5
6
7
8
9
10
11
12
13
a = torch.randn(3,2,5)
b = torch.randn(3,5,3)
torch.einsum('ijk,ikl->ijl', [a, b])

#output:
tensor([[[ 1.0886, 0.0214, 1.0690],
[ 2.0626, 3.2655, -0.1465]],

[[-6.9294, 0.7499, 1.2976],
[ 4.2226, -4.5774, -4.8947]],

[[-2.4289, -0.7804, 5.1385],
[ 0.8003, 2.9425, 1.7338]]])

张量压缩

batch矩阵乘法是张量压缩的一个特例。假设我们有两个张量,一个n阶张量${\color{red}\mathcal{A}}\in\mathbb{R}^{I_1\,\times\,\cdots\,\times\,I_n}$,一个m阶张量${\color{blue}\mathcal{B}}\in\mathbb{R}^{J_1\,\times\,\cdots\,\times\,I_m}$。以n=4, m=5为例,假设$I_2 = J_3$, $I_3=J_5$。我们可以将这两个张量在这两个维度中相乘(A是2和3,B是3和5),得到新的张量${\color{green}\mathcal{C}}\in\mathbb{R}^{I_1\,\times\,I_4\,\times\,J_1\,\times\,J_2\,\times\,J_4}$,如下:

1
2
3
4
5
6
a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape

#output:
torch.Size([2, 7, 11, 13, 17])

双线性变换

如前所述,einsum可以作用于两个以上的张量。使用这个的一个例子是双线性变换。

1
2
3
4
5
6
7
8
a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
torch.einsum('ik,jkl,il->ij', [a, b, c])

#output:
tensor([[ 3.8471, 4.7059, -3.0674, -3.2075, -5.2435],
[-3.5961, -5.2622, -4.1195, 5.5899, 0.4632]])

复杂的案例分析

attention

我们使用einsum实现attention机制,计算公式如下:

具体的实现如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Parameters
# -- [hidden_dimension]
bM, br, w = random_tensors([7], num=3, requires_grad=True)
# -- [hidden_dimension x hidden_dimension]
WY, Wh, Wr, Wt = random_tensors([7, 7], num=4, requires_grad=True)

# Single application of attention mechanism
def attention(Y, ht, rt1):
# -- [batch_size x hidden_dimension]
tmp = torch.einsum("ik,kl->il", [ht, Wh]) + torch.einsum("ik,kl->il", [rt1, Wr])
Mt = F.tanh(torch.einsum("ijk,kl->ijl", [Y, WY]) + tmp.unsqueeze(1).expand_as(Y) + bM)
# -- [batch_size x sequence_length]
at = F.softmax(torch.einsum("ijk,k->ij", [Mt, w]))
# -- [batch_size x hidden_dimension]
rt = torch.einsum("ijk,ij->ik", [Y, at]) + F.tanh(torch.einsum("ij,jk->ik", [rt1, Wt]) + br)
# -- [batch_size x hidden_dimension], [batch_size x sequence_dimension]
return rt, at

# Sampled dummy inputs
# -- [batch_size x sequence_length x hidden_dimension]
Y = random_tensors([3, 5, 7])
# -- [batch_size x hidden_dimension]
ht, rt1 = random_tensors([3, 7], num=2)

rt, at = attention(Y, ht, rt1)
at # -- print attention weights

输出结果为;

1
2
3
tensor([[ 0.1150,  0.0971,  0.5670,  0.1149,  0.1060],
[ 0.0496, 0.0470, 0.3465, 0.1513, 0.4057],
[ 0.0483, 0.5700, 0.0524, 0.2481, 0.0813]])

参考

  1. https://rockt.github.io/2018/04/30/einsum
-------------本文结束感谢您的阅读-------------
;