# Vision Transformer
# 关于ViT
Transformer
自2017
年06
月由谷歌团队在论文Attention Is All You Need
中提出后,给自然语言处理领域带去了深远的影响,其并行化处理不定长序列的能力及自注意力机制表现亮眼。根据以往的惯例,一个新的机器学习方法往往先在NLP
领域带来突破,然后逐渐被应用到计算机视觉领域。时间来到2020
年10
月,同样是谷歌团队提出了将Transformer
应用到视觉任务的方法,Vision Transformer(ViT)
。
论文AN IMAGE IS WORTH 16X16 WORDS:
TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE (opens new window)
关于对Transformer
的介绍可以参考Transformer 介绍
(opens new window)。
将Transformer
应用于视觉任务的一种想法是将图像每个像素都flatten
,得到一个表示图像的序列,作为模型的输入。但对使用自注意力模块的transformer
来说,这种方法随着图像分辨率的变大,计算复杂度也变得很高,因为scaled dot self attention
计算时640*640
的图像,序列长度transformer
所能处理的序列长度。
在ViT
中,作者是将输入图像等分成大小为16X16
的patch
,然后通过image embedding
将输入从NCHW
转换成(N, hidden_dim, (n_h * n_w))
, n_h
和n_w
是H//patch_size
和W//patch_size
的大小,flatten
后得到长度为
Image Embedding
后得到的结果shape
为[N, n_h*n_w, hidden_dim]
,作者将ViT
用于分类任务,同BERT
的思路,会在输入序列前插入一个cls_token
用来输出每个图像所属的类别。处理后,输入给encoder
的张量shape
为[N, n_h*n_w+1, hidden_dim]
。
ViT编码器的输出是的shape
为:(N, L, hidden_dim)
, L
是序列的长度,在这里为197=14*14+1
,得到编码器的输出后,只取序列的首元素,shape
为(N, hidden_dim)
作为分类器的输入。从这里会发现,这种方式舍弃了编码器处理得到的大部分信息,只使用了cls_token
部分。
这里理解,cls_token
有点像完形填空中的单词补全,这里做图像图类,待补充的词元组是类别,而这也正是我们关心的部分。至于编码器提取的序列其他信息,因为没有使用就直接舍弃了。假如说,拿编码器输出序列除cls_token
外的部分,再接一个分类器,整个分割效果会不会更好呢?
# 代码分析
输入数据x的shape, NCHW
以(1,3,224,224)为例,
- Image Embedding
处理输入数据,NCHW
变成(N, (n_h * n_w), hidden_dim)的张量,
n_h/n_w
是除以patch_size
后得到的图像的大小。
def _process_input(self, x: torch.Tensor) -> torch.Tensor:
n, c, h, w = x.shape
p = self.patch_size
n_h = h // p
n_w = w // p
x = self.conv_proj(x) # 卷积层,将`NCHW`变成`N,hidden_dim,n_h*n_w`,conv的stride=patch_size
x = x.reshape(n, self.hidden_dim, n_h * n_w)
x = x.permute(0, 2, 1)
return x
- cls_token
再将_process_input
处理后的数据与self.cls_token
合并,得到shape
为N, n_h*n_w+1, hidden_dim
的序列作为编码器的输入。
batch_class_token = self.class_token.expand(n, -1, -1) # (1,1,hidden_dim) ->(N,1,hidden_dim)
x = torch.cat([batch_class_token, x], dim=1) # `N, n_h*n_w+1, hidden_dim`
- position_embedding
编码器是标准的多头注意力Transoformer
,在torchvision
提供的模型中位置嵌入使用的可学习的参数,位置参数直接和输入数据相加
pos_embedding = torch.nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
input = input + pos_embedding
- MSA
多头注意力的实现使用了torch.nn.MultiheadAttention
模块,
CLASStorch.nn.MultiheadAttention(embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None, vdim=None,
batch_first=False,
device=None, dtype=None)
embed_dim
是模型的维度,num_heads
是头数kdim/vdim
是Query
权重和Value
权重的维度,按transformer
论文的介绍Query.weight.shape=(embed_dim, kdim)
,但是在pytorch
目前实现的MultiheadAttention
中必须kdim==vdim==embed_dim
,否则计算将报错batch_first
控制支持的输入数据shape
,为True
是支持N,L,hidden_dim
维度的输入。
一个使用的例子:
embed_dim = 512
num_heads = 16
multihead_attn = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, kdim=512, vdim=512)
x = torch.randn(2,12, 512)
query, key, value = x, x, x
output = multihead_attn(query,key,value)
print(output[0].shape) # (1, 12, 512)
回到ViT
,经过编码器后得到的输出shape
为(N,L+1,hidden_dim)
,然后取输出序列cls_token
对应位置的数据,作为特征送入线性分类器即可得到分类结果。
linear_classifier = nn.Linear(hidden_dim, num_classes)
x = self.encoder(x)
# Classifier "token" as used by standard language architectures
x = x[:, 0] # shape(N, hidden_dim)
linear_classifier(x) # shape(N, num_class)
torchvision 在0.13版本实现了ViT
模型。