# Vision Transformer

# 关于ViT

Transformer201706月由谷歌团队在论文Attention Is All You Need中提出后,给自然语言处理领域带去了深远的影响,其并行化处理不定长序列的能力及自注意力机制表现亮眼。根据以往的惯例,一个新的机器学习方法往往先在NLP领域带来突破,然后逐渐被应用到计算机视觉领域。时间来到202010月,同样是谷歌团队提出了将Transformer应用到视觉任务的方法,Vision Transformer(ViT)

论文AN IMAGE IS WORTH 16X16 WORDS:
TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE (opens new window)

关于对Transformer的介绍可以参考Transformer 介绍
(opens new window)

Transformer应用于视觉任务的一种想法是将图像每个像素都flatten,得到一个表示图像的序列,作为模型的输入。但对使用自注意力模块的transformer来说,这种方法随着图像分辨率的变大,计算复杂度也变得很高,因为scaled dot self attention计算时的复杂度是序列长度的平方。譬如对于640*640的图像,序列长度将达到409600,这远远超出当前transformer所能处理的序列长度。

ViT中,作者是将输入图像等分成大小为16X16patch,然后通过image embedding将输入从NCHW转换成(N, hidden_dim, (n_h * n_w)), n_hn_wH//patch_sizeW//patch_size的大小,flatten后得到长度为的输入序列。

Image Embedding后得到的结果shape[N, n_h*n_w, hidden_dim],作者将ViT用于分类任务,同BERT的思路,会在输入序列前插入一个cls_token用来输出每个图像所属的类别。处理后,输入给encoder的张量shape[N, n_h*n_w+1, hidden_dim]

ViT编码器的输出是的shape为:(N, L, hidden_dim), L是序列的长度,在这里为197=14*14+1,得到编码器的输出后,只取序列的首元素,shape(N, hidden_dim)作为分类器的输入。从这里会发现,这种方式舍弃了编码器处理得到的大部分信息,只使用了cls_token部分。

这里理解,cls_token有点像完形填空中的单词补全,这里做图像图类,待补充的词元组是类别,而这也正是我们关心的部分。至于编码器提取的序列其他信息,因为没有使用就直接舍弃了。假如说,拿编码器输出序列除cls_token外的部分,再接一个分类器,整个分割效果会不会更好呢?

# 代码分析

输入数据x的shape, NCHW以(1,3,224,224)为例,

  • Image Embedding

处理输入数据,NCHW变成(N, (n_h * n_w), hidden_dim)的张量,

n_h/n_w是除以patch_size后得到的图像的大小。

def _process_input(self, x: torch.Tensor) -> torch.Tensor:
  n, c, h, w = x.shape
  p = self.patch_size
  n_h = h // p
  n_w = w // p
  x = self.conv_proj(x) # 卷积层,将`NCHW`变成`N,hidden_dim,n_h*n_w`,conv的stride=patch_size
  x = x.reshape(n, self.hidden_dim, n_h * n_w)
  x = x.permute(0, 2, 1)
  return x
  • cls_token
    再将_process_input处理后的数据与self.cls_token合并,得到shapeN, n_h*n_w+1, hidden_dim的序列作为编码器的输入。
batch_class_token = self.class_token.expand(n, -1, -1) # (1,1,hidden_dim) ->(N,1,hidden_dim)
x = torch.cat([batch_class_token, x], dim=1) # `N, n_h*n_w+1, hidden_dim`
  • position_embedding

编码器是标准的多头注意力Transoformer,在torchvision提供的模型中位置嵌入使用的可学习的参数,位置参数直接和输入数据相加

pos_embedding = torch.nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT

input = input + pos_embedding
  • MSA

多头注意力的实现使用了torch.nn.MultiheadAttention模块,

CLASStorch.nn.MultiheadAttention(embed_dim, 
                                 num_heads, 
                                 dropout=0.0, 
                                 bias=True, 
                                 add_bias_kv=False, 
                                 add_zero_attn=False, 
                                 kdim=None, vdim=None, 
                                 batch_first=False, 
                                 device=None, dtype=None)
  • embed_dim是模型的维度,num_heads是头数
  • kdim/vdimQuery权重和Value权重的维度,按transformer论文的介绍Query.weight.shape=(embed_dim, kdim),但是在pytorch目前实现的MultiheadAttention中必须kdim==vdim==embed_dim,否则计算将报错
  • batch_first控制支持的输入数据shape,为True是支持N,L,hidden_dim维度的输入。
MultiHead(Q,K,V)=Concat(head1,...,headh)Woheadi=Attention(QWiQ,KWiK,VWiV)

一个使用的例子:

embed_dim = 512
num_heads = 16
multihead_attn = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, kdim=512, vdim=512)

x = torch.randn(2,12, 512)

query, key, value = x, x, x
output = multihead_attn(query,key,value)
print(output[0].shape) # (1, 12, 512)

回到ViT,经过编码器后得到的输出shape(N,L+1,hidden_dim),然后取输出序列cls_token对应位置的数据,作为特征送入线性分类器即可得到分类结果。

linear_classifier = nn.Linear(hidden_dim, num_classes)

x = self.encoder(x)
# Classifier "token" as used by standard language architectures
x = x[:, 0] # shape(N, hidden_dim)
linear_classifier(x) # shape(N, num_class)

torchvision 在0.13版本实现了ViT模型。