attention
优劣
attention机制是对卷积局部性、位移不变性这两种特性的补充,通过全局的attention计算来实现较大的感受野。attention机制的优势是能够以较大的感受野来实现任务性能的提升,可以解决卷积的一些暂时不好解决的问题,但是它的劣势也比较明显,一来是计算量的问题,二来是不一定所有任务都需要attention,它也不是万能的东西。
内容

from torch import nn
class Attention(nn.Module):
def __init__(
self,
dim, #输入token的dim
num_heads=8, #多头注意力中head的个数
qkv_bias=False, #在生成qkv时是否使用偏置,默认否
attn_drop=0.,
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads #计算每一个head需要传入的dim
self.scale = head_dim**-0.5 #head_dim的-0.5次方,即1/根号d_k,即理论公式里的分母根号d_k
self.qkv = nn.Linear(dim, dim * 3,
bias=qkv_bias) #qkv是通过1个全连接层参数为dim和3dim进行初始化的,也可以使用3个全连接层参数为dim和dim进行初始化,二者没有区别,
self.attn_drop = nn.Dropout(attn_drop) #定义dp层 比率attn_drop
self.proj = nn.Linear(dim, dim) #再定义一个全连接层,是 将每一个head的结果进行拼接的时候乘的那个矩阵W^O
self.proj_drop = nn.Dropout(proj_drop) #定义dp层 比率proj_drop
def forward(self, x): #正向传播过程
#输入是[batch_size,
# num_patches+1, (base16模型的这个数是14*14)
# total_embed_dim(base16模型的这个数是768)]
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
#qkv->[batchsize, num_patches+1, 3*total_embed_dim]
#reshape->[batchsize, num_patches+1, 3, num_heads, embed_dim_per_head]
#permute->[3, batchsize, num_heads, num_patches+1, embed_dim_per_head]
q, k, v = qkv[0], qkv[1], qkv[2]
# make torchscript happy (cannot use tensor as tuple)
#q、k、v大小均[batchsize, num_heads, num_patches+1, embed_dim_per_head]
attn = (q @ k.transpose(-2, -1)) * self.scale
#现在的操作都是对每个head进行操作
#transpose是转置最后2个维度,@就是矩阵乘法的意思
#q [batchsize, num_heads, num_patches+1, embed_dim_per_head]
#k^T[batchsize, num_heads, embed_dim_per_head, num_patches+1]
#q*k^T=[batchsize, num_heads, num_patches+1, num_patches+1]
#self.scale=head_dim的-0.5次方
#至此完成了(Q*K^T)/根号d_k的操作
attn = attn.softmax(dim=-1)
#dim=-1表示在得到的结果的每一行上进行softmax处理,-1就是最后1个维度
#至此完成了softmax[(Q*K^T)/根号d_k]的操作
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
#@->[batchsize, num_heads, num_patches+1, embed_dim_per_head]
#这一步矩阵乘积就是加权求和
#transpose->[batchsize, num_patches+1, num_heads, embed_dim_per_head]
#reshape->[batchsize, num_patches+1, num_heads*embed_dim_per_head]即[batchsize, num_patches+1, total_embed_dim]
#reshape实际上就实现了concat拼接
x = self.proj(x)
#将上一步concat的结果通过1个线性映射,通常叫做W,此处用全连接层实现
x = self.proj_drop(x)
#dropout
#至此完成了softmax[(Q*K^T)/根号d_k]*V的操作
#一个head的attention的全部操作就实现了
return x