用Python代码透视ViT的CLS Token:从理论到可视化实践
在计算机视觉领域,Vision Transformer(ViT)彻底改变了我们处理图像的方式。与传统CNN不同,ViT将图像分割为小块(patch),然后像处理自然语言一样处理这些视觉"词汇"。但其中最令人困惑的元素之一莫过于那个神秘的CLS Token——它既不是图像的一部分,却又最终决定了整个分类结果。今天,我们将用PyTorch代码一步步拆解这个设计精妙的"观察者",看看它究竟如何从图像块中提炼出全局理解。
1. ViT与CLS Token基础:为什么需要这个特殊标记?
ViT模型的核心思想借鉴了自然语言处理中的Transformer架构。当我们把一张224×224的图像分割成16×16的小块时,会得到196个视觉token。但实际输入Transformer的序列长度却是197——多出来的那个就是CLS(classification)Token。
这个设计解决了视觉Transformer的一个关键问题:在自然语言处理中,我们可以用第一个或最后一个token作为整个句子的表示;但在图像处理中,没有一个patch天然适合代表整张图片。CLS Token就像一个"虚拟观察者",通过自注意力机制与所有图像块交互,最终聚合全局信息。
import torch from timm.models.vision_transformer import VisionTransformer # 初始化一个微型ViT模型 model = VisionTransformer( img_size=32, patch_size=8, embed_dim=128, depth=4, num_heads=4, num_classes=10 ) # 模拟一个batch的输入图像 (bs, 3, 32, 32) dummy_input = torch.randn(2, 3, 32, 32) output = model(dummy_input) print(f"输出形状: {output.shape}") # 应该是 (2, 10)2. 解剖ViT前向传播:CLS Token的生命周期
让我们深入模型内部,跟踪CLS Token的完整旅程。在timm库的实现中,关键步骤可以分为四个阶段:
- Patch嵌入:将图像分割并线性投影为patch embeddings
- CLS Token添加:将可学习的CLS Token拼接到patch序列首部
- 位置编码:为所有token添加空间位置信息
- Transformer编码:通过多层自注意力机制进行特征交互
# 手动拆解前向过程 class ManualViT(torch.nn.Module): def __init__(self): super().__init__() self.patch_embed = torch.nn.Linear(8*8*3, 128) # 模拟patch嵌入 self.cls_token = torch.nn.Parameter(torch.randn(1, 1, 128)) # 可学习的CLS Token self.pos_embed = torch.nn.Parameter(torch.randn(1, 16+1, 128)) # 位置编码 def forward(self, x): # 步骤1: 图像分块并嵌入 patches = x.unfold(2, 8, 8).unfold(3, 8, 8) # (bs, 3, 4, 4, 8, 8) patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(2, 16, 8*8*3) patch_embeddings = self.patch_embed(patches) # (2, 16, 128) # 步骤2: 添加CLS Token cls_tokens = self.cls_token.expand(2, -1, -1) # 复制到batch size embeddings = torch.cat((cls_tokens, patch_embeddings), dim=1) # (2, 17, 128) # 步骤3: 添加位置编码 embeddings += self.pos_embed return embeddings3. 可视化注意力:CLS Token看到了什么?
理解CLS Token最直观的方式是观察它的注意力模式。我们可以提取Transformer层中的注意力权重,看看CLS Token更关注图像的哪些区域。
import matplotlib.pyplot as plt import numpy as np def visualize_attention(model, image): # 获取注意力权重 with torch.no_grad(): outputs = model.forward_features(image.unsqueeze(0)) attn = model.blocks[0].attn.get_attention_map() # 第一层的注意力图 # CLS Token对其他patch的注意力 cls_attention = attn[0, :, 0, 1:].mean(0) # 平均所有注意力头 # 调整形状为网格 side = int(np.sqrt(cls_attention.shape[-1])) cls_attention = cls_attention.reshape(side, side) # 可视化 plt.figure(figsize=(10, 5)) plt.subplot(1, 2, 1) plt.imshow(image.permute(1, 2, 0)) plt.title("原始图像") plt.subplot(1, 2, 2) plt.imshow(cls_attention, cmap='hot') plt.title("CLS Token注意力热图") plt.colorbar() plt.show() # 使用真实图像测试 from PIL import Image img = Image.open("cat.jpg").resize((224, 224)) img_tensor = torch.tensor(np.array(img)).float().permute(2, 0, 1) / 255.0 visualize_attention(model, img_tensor)4. CLS Token的替代方案:均值池化对比实验
ViT论文中提到,除了使用CLS Token,也可以采用全局平均池化(GAP)的方式,即对所有patch token取平均作为图像表示。让我们比较这两种方法的实际效果:
| 特征提取方式 | 参数量 | ImageNet Top-1准确率 | 训练稳定性 |
|---|---|---|---|
| CLS Token | 不变 | 78.8% | 较高 |
| 全局平均池化 | 不变 | 78.3% | 稍低 |
虽然两种方法最终性能相近,但CLS Token有几个潜在优势:
- 更明确的语义角色:专门负责分类任务
- 更稳定的训练:不受特定patch噪声影响
- 可解释性:可以分析注意力模式
# 比较两种分类头实现 class ViTWithGAP(VisionTransformer): def forward_head(self, x): # 使用全局平均池化代替CLS Token return self.head(x[:, 1:].mean(1)) # 排除CLS Token取平均 gap_model = ViTWithGAP(img_size=224, patch_size=16) cls_model = VisionTransformer(img_size=224, patch_size=16) # 测试两种模型 dummy_input = torch.randn(1, 3, 224, 224) print(f"GAP输出: {gap_model(dummy_input).shape}") print(f"CLS输出: {cls_model(dummy_input).shape}")5. 进阶实验:从头训练观察CLS Token演化
最有趣的部分莫过于观察CLS Token在训练过程中如何变化。我们可以设计一个实验,定期保存模型并提取CLS Token的注意力模式:
from torchvision.datasets import CIFAR10 from torch.utils.data import DataLoader # 准备数据集 train_data = CIFAR10(root='./data', train=True, download=True) train_loader = DataLoader(train_data, batch_size=32, shuffle=True) # 训练循环 def train_model(model, epochs=10): optimizer = torch.optim.Adam(model.parameters()) criterion = torch.nn.CrossEntropyLoss() for epoch in range(epochs): for images, labels in train_loader: optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 每轮保存注意力可视化 visualize_attention(model, images[0]) print(f"Epoch {epoch+1} 完成") # 初始化并训练微型ViT tiny_vit = VisionTransformer( img_size=32, patch_size=8, embed_dim=128, depth=4, num_heads=4, num_classes=10 ) train_model(tiny_vit)通过这个实验,你会发现CLS Token的注意力模式从最初的随机分布逐渐聚焦到图像的关键区域。这种演化直观展示了模型如何学习"关注"对分类最重要的视觉特征。
6. CLS Token的扩展应用:超越分类任务
虽然CLS Token最初是为分类设计,但研究者们已经探索了它在其他视觉任务中的潜力:
- 目标检测:将CLS Token作为区域提议的基准
- 图像分割:用CLS Token注意力指导像素级预测
- 多模态学习:对齐视觉CLS Token和文本[CLS] Token
# 多任务学习示例 class MultiTaskViT(VisionTransformer): def __init__(self): super().__init__(img_size=224, patch_size=16) self.detection_head = torch.nn.Linear(self.embed_dim, 4) # 检测头 self.segmentation_head = torch.nn.Linear(self.embed_dim, 224*224) # 分割头 def forward(self, x): x = self.forward_features(x) cls_token = x[:, 0] # 分类任务 classification = self.head(cls_token) # 检测任务 (使用CLS Token作为ROI基准) detection = self.detection_head(cls_token) # 分割任务 (将CLS Token信息广播到所有patch) segmentation = self.segmentation_head(x) return classification, detection, segmentation这种设计展示了CLS Token作为"视觉语义枢纽"的潜力,它能够将全局信息有效地分发到不同的任务分支。