核心思想
核心思想
ViT的训练其实就在做一件事情:把图片打成Patch,送入Transformer Encoder,然后拿对
应位置的向量,过一个简单的softmax多分类模型,去预测原始图片中描绘的物体类别即可。
ViT的目的不是让这个softmax分类模型强大,而是让这个分类模型的输入强大。这个输入就是
Transformer Encoder提炼出来的特征。分类模型越简单,对特征的要求就越高。
所以,为什么说Transformer开启了大一统模型的预训练大门呢?主要原因就在于它对特征的
提炼能力——这样我们就可以拿这个特征去做更多有趣的任务了。这也是ViT能成为后续多模态
backbone的主要原因。
网络结构
下图是原论文中给出的关于Vision Transformer(ViT)的模型框架。简单而言,模型由三个模块组
成:
(1)输入层处理模块:将图像转化为序列
主要包括图像块分割与嵌入(Patch Embedding),位置嵌入(Positional Embedding),分类
Token(CLS Token)插入。
(2)Transformer 编码器模块:提取全局特征的核心
主要包括多头自注意力(MHSA)子模块(捕捉全局依赖),多层感知机(MLP)子模块)增
强非线性表达), 残差连接与层归一化的作用。
(3)输出层分类模块:从特征到任务结果,分类头设计,结果输出。
手算Vision Transformer
Patch Embedding
功能:通过Patch Embedding操作,得到一维向量。
Patch图像块
ViT 利用等分窗口图片块的思想,将图像分成块,每个小块称作Patch,每个Patch块看作NLP
Transformer中的一个单词。
例如,假设原始图片尺寸大小为:224*224*3(H*W*C)。
每个Patch的尺寸设为16(P=16),则每个Patch下图片的大小为:16*16*3,Patch共有 (224/16)
x (224/16) = 14 x 14=196个。
Patch Embedding(Patch to Token)
Patch Embedding将每一个Patch的矩阵拉伸成为一个1维向量,从而获得近似词向量堆叠的效果。
如上图所示,每个Patch对应着一个token,将每个Patch展平,则得到输入矩阵X,其大小为(196,
768),其中16*16*3=768,也就是每个token是768维。通过这样的方式,我们成功将图像数据处
理成自然语言的向量表达方式。那么现在问题来了,对于图中每一个16*16*3的小方块,我要怎么
把它拉平成1*768维度的向量呢?比如说,我先把第一个channel拉成一个向量,然后再往后依次
接上第二个channel、第三个channel拉平的向量。但这种办法下,同一个pixel本来是三个channel
的值共同表达的,现在变成竖直的向量之后,这三个值的距离反而远了。基于这个原因,你可能会
想一些别的拉平方式,但归根究底它们都有一个共同的问题:太规则化,太主观。
ViT中最终采用CNN进行特征提取,具体方案如下:
采用768个16*16*3尺寸的卷积核,stride=16,padding=0。这样我们就能得到14*14*768大小的
特征图。如图所示,特征图中每一个1*1*768大小的子特征图,都是由卷积核对第一块patch做处理
而来,因此它就能表示第一块patch的token向量。
Patch Embedding之后,会经过 Class Embedding 和 Position Embedding 两个过程。
Class Embedding
功能:通过Class Embedding操作,得到类别向量。
Class Embedding 主要借鉴了BERT模型的用于文本分类时的思想,在每一个word vector之前增
加一个类别值,通常是加在向量的第一位。例如,Patch Embedding得到的196维的向量加上
Class Embedding 后,变成197维。Class Embedding 用于最后的类别输出,可参考BERT 的
class token,整个过程示意如下图:
Class Embedding是可以学习的参数,经过网络的不断训练,最终以输出向量的第一个维度的输
出来决定最后的输出类别。由于输入是 16x16个Patch,所以输出进行分类时是取16x16个Class
Embedding进行分类。
Position Embedding
功能:通过Position Embedding操作,得到位置向量。
图像切分重排后,失去了位置信息,并且Transformer的内部运算是空间信息无关的,所以需要把
位置信息编码重新传进网络。一句话来说是在原来的输入上,加上表示位置信息的向量。
Encoder编码器
Transformer Encoder是两个块的堆叠,然后再整体叠加 L 次。这两个块指的是:
(1)LayerNorm + Multi-Head Attention;
(2)LayerNorm + MLP;
这几个模块在我直接的系列文章中已经详细介绍了,这里的话就不过多的介绍了
MLP Head详解
上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,
输入的是[197, 768]输出的还是[197, 768]。在Transformer Encoder后其实还有一个Layer
Norm没有画出来(目的是:稳定分类头输入)。这里我们只是需要分类的信息,所以我们只需要提
取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着
我们通过MLP Head得到我们最终的分类结果。
正向传播代码实现
假设我们的最后是1000个分类的:
假设原始图像的大小是224*224*3的
输入端:
【1】经过卷积,然后拉平之后变成196×768大小的张量
【2】经过Class Embedding,变成了197×768大小的张量
【3】经过Position Embedding,变成了197×768大小的张量
编码器:
【1】经过编码器之后,大小不发生变化,仍然是197×768大小的张量
【2】提取CLS Token,大小是1×768
输出端:
【1】经过一个线性层,变成了1×1000的张量,每个数字代表着这个类别的概率
前向传播
# ============================== 1. 环境依赖导入 ============================== import torch import torch.nn as nn import torch.nn.functional as F # ============================== 2. ViT核心组件定义(修复版) ============================== class PatchEmbedding(nn.Module): """图像块嵌入:将图像分割为块并映射为Token序列""" def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768): super().__init__() self.num_patches = (img_size // patch_size) ** 2 # 14×14=196(Patch Token数量) self.conv = nn.Conv2d( in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size ) def forward(self, x): # 输入x:[batch_size, 3, 224, 224](PyTorch图像格式) print(f"PatchEmbedding输入形状: {x.shape}") # 卷积:[B, 3, 224, 224] → [B, 768, 14, 14](每个像素对应1个16×16图像块) x = self.conv(x) print(f"PatchEmbedding卷积后形状: {x.shape}") # 展平:[B, 768, 14, 14] → [B, 768, 196](196=14×14,按列展平) x = x.flatten(2) print(f"PatchEmbedding展平后形状: {x.shape}") # 转置:[B, 768, 196] → [B, 196, 768](适配Transformer的[B, N, D]输入格式) x = x.transpose(1, 2) print(f"PatchEmbedding输出形状: {x.shape}\n") return x class MultiHeadSelfAttention(nn.Module): """多头自注意力:捕捉Token间全局依赖关系""" def __init__(self, embed_dim=768, num_heads=12, dropout=0.1): super().__init__() self.num_heads = num_heads self.head_dim = embed_dim // num_heads # 768÷12=64(单个头的维度) assert self.head_dim * num_heads == embed_dim, "embed_dim需被num_heads整除" self.qkv = nn.Linear(embed_dim, embed_dim * 3) # 一次性生成Q/K/V self.proj = nn.Linear(embed_dim, embed_dim) # 多头结果融合 self.dropout = nn.Dropout(dropout) def forward(self, x): # 输入x:[B, 197, 768](196个Patch Token + 1个CLS Token) batch_size, num_tokens, embed_dim = x.shape print(f"MHSA输入形状: {x.shape}") # 1. 生成Q/K/V:[B, 197, 768] → [B, 197, 2304] → 拆分为3个[B, 197, 768] qkv = self.qkv(x).reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim) q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) # 每个形状:[B, 12, 197, 64] print(f"MHSA中Q/K/V形状(单头): {q.shape}") # 2. 计算注意力权重:Q×K^T / √d_k → [B, 12, 197, 197](每行是1个Token对所有Token的注意力) attn_weights = torch.matmul(q, k.transpose(-2, -1)) attn_weights = attn_weights / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)) print(f"MHSA注意力权重形状: {attn_weights.shape}") # 3. Softmax归一化 + Dropout:确保权重和为1,防止过拟合 attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = self.dropout(attn_weights) # 4. 注意力加权V:[B, 12, 197, 197] × [B, 12, 197, 64] → [B, 12, 197, 64] attn_output = torch.matmul(attn_weights, v) print(f"MHSA单头输出形状: {attn_output.shape}") # 5. 多头拼接:[B, 12, 197, 64] → [B, 197, 768](12×64=768,恢复原维度) attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, num_tokens, embed_dim) print(f"MHSA多头拼接后形状: {attn_output.shape}") # 6. 融合输出:[B, 197, 768] → [B, 197, 768](线性层调整特征) x = self.proj(attn_output) x = self.dropout(x) print(f"MHSA最终输出形状: {x.shape}\n") return x class TransformerEncoderLayer(nn.Module): """Transformer编码器层:LN→MHSA→残差 + LN→MLP→残差(Pre-LN结构)""" def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.1): super().__init__() self.norm1 = nn.LayerNorm(embed_dim) # MHSA前的归一化 self.norm2 = nn.LayerNorm(embed_dim) # MLP前的归一化 self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout) self.mlp = nn.Sequential( nn.Linear(embed_dim, int(embed_dim * mlp_ratio)), # 768→3072(维度扩张) nn.GELU(), # 激活函数(比ReLU更平滑) nn.Dropout(dropout), nn.Linear(int(embed_dim * mlp_ratio), embed_dim), # 3072→768(维度收缩) nn.Dropout(dropout) ) def forward(self, x): # 输入x:[B, 197, 768](来自上一层编码器或输入层) print(f"EncoderLayer输入形状: {x.shape}") # 1. MHSA分支:残差连接(x + MHSA(LN(x)))→ 保持维度不变 x = x + self.attn(self.norm1(x)) print(f"EncoderLayer MHSA+残差后形状: {x.shape}") # 2. MLP分支:残差连接(x + MLP(LN(x)))→ 保持维度不变 x = x + self.mlp(self.norm2(x)) print(f"EncoderLayer MLP+残差后形状: {x.shape}\n") return x class VisionTransformer(nn.Module): """完整ViT模型:PatchEmbedding→位置嵌入+CLS Token→编码器堆叠→分类头(修复pos_embed维度)""" def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_heads=12, num_layers=12, num_classes=1000, dropout=0.1): super().__init__() # 1. 图像块嵌入模块 self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim) num_patches = self.patch_embed.num_patches # 196(Patch Token数量) # 2. 修复:位置嵌入长度改为197(适配CLS Token + Patch Token) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) # [1, 197, 768] # 3. CLS Token(用于聚合全局特征) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # [1, 1, 768] self.dropout = nn.Dropout(dropout) # 4. 12层Transformer编码器堆叠 self.encoders = nn.ModuleList([ TransformerEncoderLayer(embed_dim, num_heads, mlp_ratio=4.0, dropout=dropout) for _ in range(num_layers) ]) # 5. 输出层分类头(LN+线性层) self.norm = nn.LayerNorm(embed_dim) self.head = nn.Linear(embed_dim, num_classes) def forward(self, x): batch_size = x.shape[0] # 输入x:[B, 3, 224, 224](原始RGB图像) print(f"ViT整体输入形状: {x.shape}\n") # -------------------------- 步骤1:图像→Patch Token序列 -------------------------- x = self.patch_embed(x) # 输出:[B, 196, 768] # -------------------------- 步骤2:插入CLS Token -------------------------- # 扩展CLS Token到batch维度:[1, 1, 768] → [B, 1, 768] cls_token = self.cls_token.expand(batch_size, -1, -1) # 拼接:[B, 1, 768] + [B, 196, 768] → [B, 197, 768] x = torch.cat((cls_token, x), dim=1) print(f"插入CLS Token后形状: {x.shape}\n") # -------------------------- 步骤3:添加位置嵌入(修复后可正常相加) -------------------------- # 现在x和pos_embed形状均为[B, 197, 768],可直接广播相加 x = x + self.pos_embed x = self.dropout(x) print(f"添加位置嵌入后形状: {x.shape}") print(f"Dropout后形状: {x.shape}\n") # -------------------------- 步骤4:12层编码器特征提取 -------------------------- for i, encoder in enumerate(self.encoders): print(f"=== 第{i+1}层编码器 ===") x = encoder(x) # 每层输入输出均为[B, 197, 768] # -------------------------- 步骤5:分类预测 -------------------------- # 提取CLS Token:[B, 197, 768] → [B, 768](仅用第0个Token的全局特征) cls_token_final = x[:, 0, :] print(f"\n提取CLS Token后形状: {cls_token_final.shape}") # LN归一化:[B, 768] → [B, 768](稳定分类头输入) cls_token_final = self.norm(cls_token_final) print(f"LN归一化后形状: {cls_token_final.shape}") # 分类头:[B, 768] → [B, 1000](1000为ImageNet-1K类别数) logits = self.head(cls_token_final) print(f"分类头输出形状: {logits.shape}") return logits # ============================== 3. 模型测试(可直接运行,无维度报错) ============================== if __name__ == "__main__": # 1. 设置随机种子(确保结果可复现) torch.manual_seed(42) # 2. 构造测试输入(batch_size=1,RGB图像:[B, C, H, W]) dummy_img = torch.randn(1, 3, 224, 224) # 随机生成符合格式的图像 # 3. 实例化ViT-Base模型(参数与论文一致) vit_model = VisionTransformer( img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_heads=12, num_layers=12, num_classes=1000 ) print(f"ViT-Base模型初始化完成(参数总数:{sum(p.numel() for p in vit_model.parameters()):,})\n") # 4. 前向传播测试(禁用梯度计算,仅推理) with torch.no_grad(): output = vit_model(dummy_img) # 5. 输出结果分析 probabilities = F.softmax(output, dim=1) # 转化为类别概率(和为1) pred_class = torch.argmax(probabilities, dim=1).item() # 预测概率最高的类别索引 print(f"\n=== 最终结果 ===") print(f"模型输出logits形状: {output.shape}") # [1, 1000] print(f"类别概率分布形状: {probabilities.shape}") # [1, 1000] print(f"预测概率最高的类别索引: {pred_class}") # 随机值(模型未训练)
ViT-Base模型初始化完成(参数总数:86,567,656)
ViT整体输入形状: torch.Size([1, 3, 224, 224])
PatchEmbedding输入形状: torch.Size([1, 3, 224, 224])
PatchEmbedding卷积后形状: torch.Size([1, 768, 14, 14])
PatchEmbedding展平后形状: torch.Size([1, 768, 196])
PatchEmbedding输出形状: torch.Size([1, 196, 768])
插入CLS Token后形状: torch.Size([1, 197, 768])
添加位置嵌入后形状: torch.Size([1, 197, 768])
Dropout后形状: torch.Size([1, 197, 768])
=== 第1层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第2层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第3层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第4层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第5层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第6层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第7层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第8层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第9层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第10层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第11层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
=== 第12层编码器 ===
EncoderLayer输入形状: torch.Size([1, 197, 768])
MHSA输入形状: torch.Size([1, 197, 768])
MHSA中Q/K/V形状(单头): torch.Size([1, 12, 197, 64])
MHSA注意力权重形状: torch.Size([1, 12, 197, 197])
MHSA单头输出形状: torch.Size([1, 12, 197, 64])
MHSA多头拼接后形状: torch.Size([1, 197, 768])
MHSA最终输出形状: torch.Size([1, 197, 768])
EncoderLayer MHSA+残差后形状: torch.Size([1, 197, 768])
EncoderLayer MLP+残差后形状: torch.Size([1, 197, 768])
提取CLS Token后形状: torch.Size([1, 768])
LN归一化后形状: torch.Size([1, 768])
分类头输出形状: torch.Size([1, 1000])
=== 最终结果 ===
模型输出logits形状: torch.Size([1, 1000])
类别概率分布形状: torch.Size([1, 1000])
预测概率最高的类别索引: 302