Set Transformer:3步掌握置换不变注意力机制的代码实现
1. 为什么我们需要处理集合数据?
在机器学习领域,我们经常遇到需要处理集合数据的场景。想象一下,你面前有一堆散落的乐高积木——这些积木没有固定的排列顺序,但它们的组合方式决定了最终能搭建出什么模型。这就是集合数据的典型特征:元素之间没有顺序关系,但整体具有特定含义。
传统神经网络在处理这类数据时面临两个核心挑战:
- 置换不变性:无论积木的排列顺序如何改变,只要组合相同,最终搭建的模型应该相同
- 可变集合大小:积木数量可以任意增减,网络需要适应不同大小的输入
常见应用场景包括:
- 点云处理(自动驾驶中的物体识别)
- 多实例学习(医疗影像分析)
- 分子性质预测(化学结构分析)
- 推荐系统(用户行为集合建模)
传统RNN虽然能处理变长输入,但对顺序敏感;CNN需要固定尺寸输入。Set Transformer通过注意力机制完美解决了这两个问题。
2. Set Transformer的核心创新
2.1 置换不变注意力机制
Set Transformer的核心是Set Attention Block (SAB),它通过自注意力机制让集合中的每个元素都能与其他元素交互:
class SetAttentionBlock(nn.Module): def __init__(self, dim, heads=8): super().__init__() self.attention = nn.MultiheadAttention(dim, heads) self.norm1 = nn.LayerNorm(dim) self.ffn = nn.Sequential( nn.Linear(dim, dim*4), nn.ReLU(), nn.Linear(dim*4, dim) ) self.norm2 = nn.LayerNorm(dim) def forward(self, x): # x: [set_size, batch_size, dim] attn_out = self.attention(x, x, x)[0] x = self.norm1(x + attn_out) ffn_out = self.ffn(x) x = self.norm2(x + ffn_out) return x关键特性:
- 无论输入顺序如何变化,输出保持不变(置换不变性)
- 可以处理任意大小的输入集合
- 通过注意力权重显式建模元素间关系
2.2 诱导注意力降低计算复杂度
原始自注意力复杂度为O(n²),对于大集合不实用。Set Transformer提出Induced Set Attention Block (ISAB),引入m个诱导点(通常m≪n):
class InducedSetAttentionBlock(nn.Module): def __init__(self, dim, num_inds, heads=8): super().__init__() self.induced_points = nn.Parameter(torch.randn(num_inds, dim)) self.mab1 = MAB(dim, dim, dim, heads) # MAB是基础注意力模块 self.mab2 = MAB(dim, dim, dim, heads) def forward(self, x): # x: [set_size, batch_size, dim] h = self.mab1(self.induced_points, x) # 诱导点与输入交互 return self.mab2(x, h) # 输入与处理后的诱导点交互复杂度从O(n²)降到O(nm),其中m是诱导点数量,通常远小于n。
2.3 完整架构设计
典型的Set Transformer包含编码器和解码器:
编码器架构对比:
| 组件 | 传统Pooling方法 | Set Transformer |
|---|---|---|
| 元素处理 | 独立MLP | 通过SAB/ISAB交互 |
| 聚合方式 | 简单平均/最大池化 | 注意力池化 |
| 复杂度 | O(n) | O(nm) |
| 关系建模 | 无显式建模 | 显式注意力权重 |
解码器使用Pooling by Multihead Attention (PMA),比普通池化更能保留集合的关键信息:
class PMA(nn.Module): def __init__(self, dim, num_seeds, heads=8): super().__init__() self.seeds = nn.Parameter(torch.randn(num_seeds, dim)) self.mab = MAB(dim, dim, dim, heads) def forward(self, x): return self.mab(self.seeds, x)3. 实战:点云分类任务
让我们用PyTorch实现一个完整的点云分类模型。假设输入是n个3D点坐标的集合,输出是类别标签。
3.1 数据预处理
from torch_geometric.datasets import ModelNet from torch_geometric.loader import DataLoader # 加载ModelNet10数据集 train_dataset = ModelNet(root='data/ModelNet10', name='10', train=True) test_dataset = ModelNet(root='data/ModelNet10', name='10', train=False) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)3.2 模型实现
import torch import torch.nn as nn import torch.nn.functional as F class SetTransformer(nn.Module): def __init__(self, input_dim=3, hidden_dim=128, output_dim=10, num_heads=4, num_inds=32, num_blocks=4): super().__init__() # 输入嵌入层 self.embed = nn.Linear(input_dim, hidden_dim) # 编码器:堆叠ISAB块 self.encoder = nn.Sequential(*[ InducedSetAttentionBlock(hidden_dim, num_inds, num_heads) for _ in range(num_blocks) ]) # 解码器:PMA + 线性层 self.decoder = nn.Sequential( PMA(hidden_dim, num_seeds=1, heads=num_heads), nn.Linear(hidden_dim, output_dim) ) def forward(self, x): # x: [batch_size, set_size, input_dim] x = x.transpose(0, 1) # [set_size, batch_size, input_dim] x = self.embed(x) x = self.encoder(x) x = self.decoder(x) return x.squeeze(0) # [batch_size, output_dim]3.3 训练与可视化
# 初始化模型和优化器 model = SetTransformer() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) criterion = nn.CrossEntropyLoss() # 训练循环 for epoch in range(100): model.train() for data in train_loader: optimizer.zero_grad() out = model(data.pos.reshape(data.batch_size, -1, 3)) loss = criterion(out, data.y) loss.backward() optimizer.step() # 验证 model.eval() correct = 0 for data in test_loader: pred = model(data.pos.reshape(data.batch_size, -1, 3)).argmax(dim=1) correct += (pred == data.y).sum().item() acc = correct / len(test_dataset) print(f'Epoch {epoch}, Test Acc: {acc:.4f}')注意力可视化技巧:
import matplotlib.pyplot as plt def visualize_attention(model, sample): # 注册hook获取注意力权重 attention_maps = [] def hook(module, input, output): attention_maps.append(output[1].detach()) # 输出是(output, attention_weights) # 为每个注意力层注册hook handles = [] for block in model.encoder: handles.append(block.mab1.attention.register_forward_hook(hook)) handles.append(block.mab2.attention.register_forward_hook(hook)) # 前向传播 with torch.no_grad(): model(sample) # 移除hook for handle in handles: handle.remove() # 可视化第一个注意力头的权重 plt.figure(figsize=(12, 8)) for i, attn in enumerate(attention_maps[:4]): # 只看前4个注意力图 plt.subplot(2, 2, i+1) plt.imshow(attn[0, 0].cpu().numpy()) # 第一个样本,第一个注意力头 plt.colorbar() plt.show()4. 进阶技巧与优化建议
在实际项目中应用Set Transformer时,以下几点经验值得注意:
诱导点数量的选择:
- 小集合(n<100):可以直接使用SAB
- 中等集合(100<n<1000):ISAB,m=32-64
- 大集合(n>1000):考虑分层注意力或采样策略
处理高维特征:
# 当输入特征维度较高时 self.embed = nn.Sequential( nn.Linear(input_dim, hidden_dim*2), nn.ReLU(), nn.Linear(hidden_dim*2, hidden_dim) )正则化策略:
- 注意力dropout(防止过拟合)
- 层归一化的位置(Pre-LN vs Post-LN)
- 标签平滑(Label Smoothing)
与其他架构的结合:
- 对于局部结构明显的集合(如分子图),可以结合图卷积
- 对于时序集合数据,可以加入轻量级LSTM层
部署优化:
# 使用TorchScript提高推理速度 scripted_model = torch.jit.script(model) scripted_model.save('set_transformer.pt')
Set Transformer在多个基准测试中表现出色,例如在点云分类任务上,使用相同参数量的情况下,相比传统PointNet方法可以获得2-3%的准确率提升。更重要的是,注意力权重提供了可解释性——我们可以直观地看到哪些元素对最终决策贡献更大。