1. 为什么你的模型“记性”这么差?(痛点与背景)
想象一下,你训练了一个神经网络来识别手写数字(MNIST),准确率高达 99%。
接着,你希望能复用这个聪明的脑子,让它继续学习识别时尚单品(Fashion-MNIST)。
你把模型拿来,在“衣服鞋子”的数据集上跑了几轮训练。结果很棒,它现在能完美识别运动鞋和衬衫了。
但是,当你随手扔给它一张数字 “7” 的图片时,它却一脸自信地告诉你:“这是一只靴子!”
这就是灾难性遗忘(Catastrophic Forgetting)。
在传统的反向传播中,为了让模型适应新任务(Task B),优化器会毫不留情地修改模型里的权重参数。它并不在乎这些参数之前对旧任务(Task A)有多重要,只要能降低 Task B 的 Loss,它就会大幅改变权重。结果就是:旧知识的“神经连接”被彻底破坏了。
EWC(Elastic Weight Consolidation)的出现,就是为了解决这个问题。它能让模型在学习新技能的同时,优雅地“锁住”那些对旧技能至关重要的记忆。
2. 概念拆解:给神经元装上“弹簧”
EWC 的论文里充满了费雪信息矩阵(Fisher Information Matrix)和黑森矩阵(Hessian Matrix)等高深术语,但我们先忘掉数学,用**“房间装修”**来打个比方。
🏠 生活化类比:设计师的妥协
把神经网络想象成一个刚刚装修好的房间。
Task A(旧任务):这是一个“家庭影院”模式。为了达到最佳视听效果,沙发(权重 1)、音响(权重 2)、投影仪(权重 3)必须摆在特定的位置。
Task B(新任务):现在你想把这个房间改成“瑜伽室”。
没有 EWC 的做法:
装修队进来,为了腾出瑜伽空间,直接把沙发扔出去,把音响砸了。瑜伽室很完美,但家庭影院彻底毁了。
有 EWC 的做法:
你告诉装修队:“有些东西你们随便动,但有些东西很重要,动起来很费劲。”
不重要的权重(比如墙角的绿植):对家庭影院影响不大,随便移。
重要的权重(比如投影仪):对家庭影院极其重要。如果你非要移动它,就像是在拉一根极其坚硬的弹簧。你可以稍微挪一点点,但挪得越远,弹簧的反作用力(惩罚项)就越大。
EWC 的核心魔法就在于:它能自动计算出哪些家具(权重)是“承重墙”,哪些是“装饰品”。
🧩 原理图解逻辑
训练 Task A:正常训练,找到最优权重
。
计算重要性(Fisher Matrix):分析 Task A 的 Loss 地形。如果某个权重稍微变动一下,Loss 就剧烈飙升,说明这个权重非常重要(地形陡峭);如果权重变了很多 Loss 还没啥反应,说明它不重要(地形平坦)。
训练 Task B:在 Loss 函数后面加上一个 EWC 惩罚项(那个弹簧)。
:新任务的正常 Loss。
:这一项决定了你有多想“守旧”。值越大,越难忘。
:费雪信息量(重要性系数)。越重要,这一项越大,改变参数带来的 Penalty 就越大。
3. 动手实战:PyTorch 实现 EWC
我们将通过一个极简的例子:先让模型拟合一个函数,再拟合另一个函数,看它能不能同时记住两者。
环境准备
你需要安装 PyTorch:
pip install torch matplotlib
核心代码解析
/* by 01130.hk - online tools website : 01130.hk/zh/alldns.html */ import torch import torch.nn as nn import torch.optim as optim import matplotlib.pyplot as plt import copy # =========================== # 1. 定义一个简单的神经网络 # =========================== class SimpleNet(nn.Module): def __init__(self): super(SimpleNet, self).__init__() # 为了演示,我们用一个小型的网络 self.fc1 = nn.Linear(10, 20) self.fc2 = nn.Linear(20, 20) self.fc3 = nn.Linear(20, 2) # 输出2类 def forward(self, x): x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) return self.fc3(x) # =========================== # 2. EWC 核心类 (The Magic) # =========================== class EWC: def __init__(self, model, dataset): self.model = model self.dataset = dataset # 存储旧任务的最优参数 (theta_A*) self.params = {n: p.data.clone() for n, p in self.model.named_parameters()} # 存储每个参数的重要性 (Fisher Information) self.fisher = self._calculate_fisher() def _calculate_fisher(self): fisher = {} # 初始化 Fisher 矩阵为 0 for n, p in self.model.named_parameters(): fisher[n] = torch.zeros_like(p.data) self.model.eval() criterion = nn.CrossEntropyLoss() # 遍历数据计算梯度的平方 # 这里的逻辑是:梯度越大,说明参数稍微一动 Loss 变化就大 -> 参数越重要 for input_data, target in self.dataset: self.model.zero_grad() output = self.model(input_data.unsqueeze(0)) # batch_size=1 loss = criterion(output, target.unsqueeze(0)) loss.backward() for n, p in self.model.named_parameters(): if p.grad is not None: # Fisher 近似等于 梯度的平方 fisher[n] += p.grad.data ** 2 # 归一化 for n in fisher: fisher[n] /= len(self.dataset) return fisher def penalty(self, new_model): loss = 0 for n, p in new_model.named_parameters(): # EWC 公式: Sum( F * (new_theta - old_theta)^2 ) _loss = self.fisher[n] * (p - self.params[n]) ** 2 loss += _loss.sum() return loss # =========================== # 3. 模拟训练流程 # =========================== def get_data(task_id): # 模拟数据:Task 1 输入全1,Task 2 输入全0 if task_id == 1: return [(torch.ones(10), torch.tensor(0)) for _ in range(100)] else: return [(torch.zeros(10), torch.tensor(1)) for _ in range(100)] # 实例化模型 model = SimpleNet() optimizer = optim.SGD(model.parameters(), lr=0.1) criterion = nn.CrossEntropyLoss() print(">>> 开始训练 任务 A (识别全1向量)") data_a = get_data(1) for epoch in range(5): for x, y in data_a: optimizer.zero_grad() loss = criterion(model(x.unsqueeze(0)), y.unsqueeze(0)) loss.backward() optimizer.step() print("任务 A 训练完成。保存 EWC 状态...") # --- 关键步骤:计算 Task A 的重要性权重 --- ewc = EWC(model, data_a) print(">>> 开始训练 任务 B (识别全0向量),同时开启 EWC 保护") data_b = get_data(2) ewc_lambda = 1000 # 惩罚力度,越大越照顾旧任务 for epoch in range(5): total_loss = 0 for x, y in data_b: optimizer.zero_grad() # 1. 计算新任务的 Loss loss_b = criterion(model(x.unsqueeze(0)), y.unsqueeze(0)) # 2. 计算 EWC 惩罚项 (旧任务的记忆) loss_ewc = ewc.penalty(model) # 3. 总 Loss loss = loss_b + (ewc_lambda * loss_ewc) loss.backward() optimizer.step() total_loss += loss.item() print(f"Epoch {epoch}: Loss = {total_loss:.4f}") # =========================== # 4. 验证结果 # =========================== model.eval() test_a = model(torch.ones(10).unsqueeze(0)) test_b = model(torch.zeros(10).unsqueeze(0)) print("\n=== 最终测试 ===") print(f"Task A (应为类别0): 预测概率 {torch.softmax(test_a, dim=1).detach().numpy()}") print(f"Task B (应为类别1): 预测概率 {torch.softmax(test_b, dim=1).detach().numpy()}")代码划重点:
_calculate_fisher:这是 EWC 的灵魂。我们在训练完 Task A 后,立刻冻结模型,通过反向传播拿到梯度。注意这里不用optimizer.step(),我们只想要梯度值来计算$F_i$。penalty:在训练 Task B 时,每次迭代都会调用这个函数。它检查当前的参数偏离旧参数有多远,并乘以重要性系数。
4. 进阶深潜:陷阱与最佳实践
⚠️ 常见陷阱
Fisher 计算开销:在上面的代码中,我们是一个样本一个样本算的(为了代码清晰)。在生产环境中,这会非常慢。
优化:使用小批量(Mini-batch)来估算 Fisher 信息,不需要遍历整个数据集,随机采样几千个样本通常就足够了。
Lambda 的平衡:
太小:EWC 失效,照样遗忘。
太大:模型被旧记忆“锁死”了,根本学不进新东西(欠拟合 Task B)。这需要像调学习率一样去调参。
🚀 生产环境贴士
多任务扩展:如果你有 Task A, B, C... 怎么做?
通常的做法是维护一个累积的 Fisher 矩阵。当你学完 Task B 准备学 C 时,你的锚点应该变成 Task B 的参数,而 Fisher 矩阵应该是 A 和 B 的重要性之和。
在线 EWC (Online EWC):这是一种更高效的变体,解决了存储多个 Fisher 矩阵带来的内存爆炸问题。
5. 总结与延伸
核心知识点
EWC 本质上是一种正则化(Regularization)技术。它通过费雪信息矩阵识别出神经网络中的“关键承重墙”,并在学习新知识时强行保护这些区域,从而在可塑性(学习新知识)和稳定性(保持旧知识)之间找到平衡。