einsum
,这是我开发深度学习模型时最喜欢的函数。这篇文章将从PyTorch角度出发,如何使用torch.einsum
进行相关的矩阵计算操作。
einsum符号
如果你很难记住PyTorch\TensorFlow中所有中用于计算点积、外积、转置和矩阵向量或矩阵矩阵乘法的不同函数的名称或者使用方式。那么你就急需要einsum,einsum表示法是以一种优雅的方式来表达所有这些矩阵、向量或者张量的复杂计算操作。基本上使用的是一种特定于域的语言。这样做的好处不仅仅是不需要记忆或定期查找特定的库或者函数。一旦您理解并使用einsum,您将能够更快地编写更简洁、更高效的代码。目前主流深度学习框架中都包含了einsum的实现,比如PyTorch中的torch.einsum和TensorFlow中的tf.einsum。
假设我们要计算两个矩阵A∈RI×K和B∈RK×J的相乘并对列求和,即得到一个向量c∈RJ为的每一列的和。用einsum符号,我们可以把它写成:
cj=i∑k∑AikBkj=AikBkj
上述公式表明了c中的所有单个元素ci是通过将A中列向量Ai和B中行向量Bj中的值相乘并求和计算出来的。注意,对于einsum符号,当我们隐式地对重复指标求和(本例中为k)和输出中未提及的指标求和(本例中为i)时,求和符号可以去掉。到目前为止还不错,但我们也可以用einsum来表示更多的基本运算。例如,计算两个向量a,b∈RI的点积可以写成:
c=i∑aibi=aibi.
在深度学习实现过程中经常遇到的一个问题是对高阶张量中的向量进行变换。例如,假设有一个张量包含(N,T,S),即长度为T,维度为K,样本个数为N。现在我们想对改张量的维度进行转换Q。令T∈RN×T×K,W∈RK×Q。则该计算用einsum符号可以表示为:
Cntq=k∑TntkWkq=TntkWkq.
假设张量为4维,即T∈RN×T×K×M,除了上述将第3维维度转换为Q,还需要对第2维进行求和,并将第一维和最后一维进行转置,那么上述全部操作使用einsum符号可以表示为:
Cmqn=t∑k∑TntkmWkq=TntkmWkq.
注意,通过交换n和m来实现张量收缩结果的转置 (Cmqn而不是Cnqm)。
实现
einsum在numpy、PyTorch和TensorFlow三个模块中都有对应的实现,在numpy中为np.einsum,在PyTorch中为torch.einsum,在TensorFlow中为tf.einsum。三个einsum函数使用方式都是相同的,即einsum(equation,operands),其中equation表示einsum计算的字符串表示,operands是张量序列(计算主体)。上面的例子都可以用方程串来表示。例如,我们的第一个例子cj=∑i∑kAikBkj可以写成方程字符串“ik,kj -> j”。
不仅在numpy中,而且在PyTorch和TensorFlow中,einsum的伟大之处在于,它可以用于任意的神经网络结构的计算图中,我们可以通过它进行反向传播。对einsum的典型调用具有以下形式
result=einsum("□□,□□□,□□->□□",arg1,arg2,arg3)
其中□是一个标识张量维度的字符的占位符。从这个方程串中我们可以推断出arg1和arg3是2维矩阵,arg2是一个3阶张量,这个einsum运算的结果是一个矩阵。注意,einsum使用的输入数量是可变的。在上面的例子中,einsum指定了对三个参数的操作,但是它也可以用于涉及一个、两个或三个以上参数的操作。Einsum最好通过学习示例来学习,因此让我们通过PyTorch中的一些Einsum示例,它们对应于许多深度学习模型中使用的库函数。
矩阵转置
Bji=Aij
1 2 3 4 5 6 7 8 9 10 11
| import torch a = torch.arange(6).reshape(2, 3) print(a) torch.einsum('ij->ji', [a])
tensor([[0, 1, 2], [3, 4, 5]]) tensor([[0, 3], [1, 4], [2, 5]])
|
求和
b=i∑j∑Aij=Aij
1 2 3 4 5
| a = torch.arange(6).reshape(2, 3) torch.einsum('ij->', [a])
tensor(15.)
|
列求和
bj=i∑Aij=Aij
1 2 3 4 5
| a = torch.arange(6).reshape(2, 3) torch.einsum('ij->j', [a])
tensor([ 3., 5., 7.])
|
行求和
bi=j∑Aij=Aij
1 2 3 4 5
| a = torch.arange(6).reshape(2, 3) torch.einsum('ij->i', [a])
tensor([ 3., 12.])
|
矩阵-向量乘法
ci=k∑Aikbk=Aikbk
1 2 3 4 5 6
| a = torch.arange(6).reshape(2, 3) b = torch.arange(3) torch.einsum('ik,k->i', [a, b])
tensor([ 5., 14.])
|
矩阵-矩阵乘法
Cij=k∑AikBkj=AikBkj
1 2 3 4 5 6 7
| a = torch.arange(6).reshape(2, 3) b = torch.arange(15).reshape(3, 5) torch.einsum('ik,kj->ij', [a, b])
tensor([[ 25., 28., 31., 34., 37.], [ 70., 82., 94., 106., 118.]])
|
点乘
向量
c=i∑aibi=aibi
1 2 3 4 5
| a = torch.arange(3) b = torch.arange(3,6) torch.einsum('i,i->', [a, b])
tensor(14.)
|
矩阵
c=i∑j∑AijBij=AijBij
1 2 3 4 5
| a = torch.arange(6).reshape(2, 3) b = torch.arange(6,12).reshape(2, 3) torch.einsum('ij,ij->', [a, b])
tensor(145.)
|
矩阵乘法之Hadamard
对应的元素想乘
Cij=AijBij
1 2 3 4 5 6 7
| a = torch.arange(6).reshape(2, 3) b = torch.arange(6,12).reshape(2, 3) torch.einsum('ij,ij->ij', [a, b])
tensor([[ 0., 7., 16.], [ 27., 40., 55.]])
|
外积(OUTER PRODUCT)
Cij=aibj
1 2 3 4 5 6 7 8
| a = torch.arange(3) b = torch.arange(3,7) torch.einsum('i,j->ij', [a, b])
tensor([[ 0., 0., 0., 0.], [ 3., 4., 5., 6.], [ 6., 8., 10., 12.]])
|
batch矩阵相乘
Cijl=k∑AijkBikl=AijkBikl
1 2 3 4 5 6 7 8 9 10 11 12 13
| a = torch.randn(3,2,5) b = torch.randn(3,5,3) torch.einsum('ijk,ikl->ijl', [a, b])
tensor([[[ 1.0886, 0.0214, 1.0690], [ 2.0626, 3.2655, -0.1465]],
[[-6.9294, 0.7499, 1.2976], [ 4.2226, -4.5774, -4.8947]],
[[-2.4289, -0.7804, 5.1385], [ 0.8003, 2.9425, 1.7338]]])
|
张量压缩
batch矩阵乘法是张量压缩的一个特例。假设我们有两个张量,一个n阶张量A∈RI1×⋯×In,一个m阶张量B∈RJ1×⋯×Im。以n=4, m=5为例,假设I2=J3, I3=J5。我们可以将这两个张量在这两个维度中相乘(A是2和3,B是3和5),得到新的张量C∈RI1×I4×J1×J2×J4,如下:
Cpstuv=q∑r∑ApqrsBtuqvr=ApqrsBtuqvr
1 2 3 4 5 6
| a = torch.randn(2,3,5,7) b = torch.randn(11,13,3,17,5) torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape
torch.Size([2, 7, 11, 13, 17])
|
双线性变换
如前所述,einsum可以作用于两个以上的张量。使用这个的一个例子是双线性变换。
Dij=k∑l∑AikBjklCil=AikBjklCil
1 2 3 4 5 6 7 8
| a = torch.randn(2,3) b = torch.randn(5,3,7) c = torch.randn(2,7) torch.einsum('ik,jkl,il->ij', [a, b, c])
tensor([[ 3.8471, 4.7059, -3.0674, -3.2075, -5.2435], [-3.5961, -5.2622, -4.1195, 5.5899, 0.4632]])
|
复杂的案例分析
attention
我们使用einsum实现attention机制,计算公式如下:
Mtαtrt=tanh(WyY+(Whht+Wrrt−1)⊗eL)=softmax(wTMt)=YαtT+tanh(Wtrt−1)Mtαtrt∈Rk×L∈RL∈Rk
具体的实现如下所示:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
|
bM, br, w = random_tensors([7], num=3, requires_grad=True)
WY, Wh, Wr, Wt = random_tensors([7, 7], num=4, requires_grad=True)
def attention(Y, ht, rt1): tmp = torch.einsum("ik,kl->il", [ht, Wh]) + torch.einsum("ik,kl->il", [rt1, Wr]) Mt = F.tanh(torch.einsum("ijk,kl->ijl", [Y, WY]) + tmp.unsqueeze(1).expand_as(Y) + bM) at = F.softmax(torch.einsum("ijk,k->ij", [Mt, w])) rt = torch.einsum("ijk,ij->ik", [Y, at]) + F.tanh(torch.einsum("ij,jk->ik", [rt1, Wt]) + br) return rt, at
Y = random_tensors([3, 5, 7])
ht, rt1 = random_tensors([3, 7], num=2)
rt, at = attention(Y, ht, rt1) at
|
输出结果为;
1 2 3
| tensor([[ 0.1150, 0.0971, 0.5670, 0.1149, 0.1060], [ 0.0496, 0.0470, 0.3465, 0.1513, 0.4057], [ 0.0483, 0.5700, 0.0524, 0.2481, 0.0813]])
|
参考
- https://rockt.github.io/2018/04/30/einsum