news 2026/3/16 4:59:16

如何解释 GPT2-Small

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
如何解释 GPT2-Small

原文:towardsdatascience.com/how-to-interpret-gpt2-small-76e0536a588a?source=collection_archive---------8-----------------------#2024-03-22

预测重复令牌的机制可解释性

https://medium.com/@vanillaxiangshuyang?source=post_page---byline--76e0536a588a--------------------------------https://towardsdatascience.com/?source=post_page---byline--76e0536a588a-------------------------------- Shuyang Xiang

·发表于Towards Data Science ·阅读时长 7 分钟·2024 年 3 月 22 日

大规模语言模型的发展,特别是 ChatGPT,令那些曾经试验过它的人,包括我自己,感到惊讶于它卓越的语言能力和执行各种任务的能力。然而,许多研究人员,包括我自己,虽然对它的能力感到惊叹,但也发现自己感到困惑。尽管我们了解模型的架构以及权重的具体数值,但我们仍然难以理解为何某一特定输入序列会导致特定的输出序列。

在这篇博客文章中,我将尝试通过机制可解释性来揭开 GPT2-small 的神秘面纱,研究一个简单的案例:重复令牌的预测。

机制可解释性

传统的数学工具在解释机器学习模型时,并不完全适用于语言模型。

以 SHAP 为例,它是一个有助于解释机器学习模型的工具。它擅长确定哪些特征显著影响了优质葡萄酒的预测。然而,重要的是要记住,语言模型是在令牌级别进行预测的,而 SHAP 值通常是在特征级别计算的,这使得它们可能不适用于令牌。

此外,语言模型(LLM)有大量的参数和输入,形成了一个高维空间。即便在低维空间中,计算 SHAP 值也非常昂贵,而在 LLM 的高维空间中,这一成本则更为庞大。

尽管容忍了高昂的计算成本,SHAP 提供的解释可能显得肤浅。例如,知道“potter”这个词由于之前提到“Harry”而最影响输出预测,但这并没有提供太多的洞见。这让我们无法确定模型的哪一部分或具体机制对这种预测负责。

机制可解释性提供了一种不同的方法。它不仅仅识别模型预测的重要特征或输入。相反,它揭示了模型的底层机制或推理过程,帮助我们理解模型是如何做出预测或决策的。

GPT2-Small 对重复标记的预测

我们将使用 GPT2-small 进行一个简单的任务:预测一系列重复的标记。我们将使用的库是TransformerLens,该库旨在提供 GPT-2 风格语言模型的机制可解释性。

gpt2_small:HookedTransformer=HookedTransformer.from_pretrained("gpt2-small")

我们使用上面的代码加载 GPT2-Small 模型,并对由特定函数生成的序列进行标记预测。该序列包含两个相同的标记序列,后面跟着 bos_token。例如,当 seq_len 为 3 时,序列为“ABCDABCD” + bos_token。为清晰起见,我们将从序列开始到 seq_len 的部分称为前半部分,将剩余部分(不包括 bos_token)称为后半部分。

defgenerate_repeated_tokens(model:HookedTransformer,seq_len:int,batch:int=1)->Int[Tensor,"batch full_seq_len"]:''' Generates a sequence of repeated random tokens Outputs are: rep_tokens: [batch, 1+2*seq_len] '''bos_token=(t.ones(batch,1)*model.tokenizer.bos_token_id).long()# generate bos token for each batchrep_tokens_half=t.randint(0,model.cfg.d_vocab,(batch,seq_len),dtype=t.int64)rep_tokens=t.cat([bos_token,rep_tokens_half,rep_tokens_half],dim=-1).to(device)returnrep_tokens

当我们允许模型在生成的标记上运行时,发现一个有趣的现象:模型在序列的后半部分表现明显优于前半部分。这是通过正确标记的对数概率来衡量的。具体来说,前半部分的性能为-13.898,而后半部分的性能为-0.644。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/0aecbb799f225fbde4fee8e95c671884.png

作者图片:正确标记的对数概率

我们还可以计算预测准确率,定义为正确预测的标记(与生成的标记相同的标记)与总标记数的比率。前半部分序列的准确率为 0.0,这不足为奇,因为我们使用的是没有实际意义的随机标记。与此同时,后半部分的准确率为 0.93,明显优于前半部分。

归纳电路

寻找归纳头

上述观察可能是由归纳电路的存在所解释的。归纳电路扫描序列中的当前标记的前序实例,识别出它之前跟随的标记,并预测相同的序列将会重复。例如,如果它遇到一个‘A’,它会扫描嵌入空间中与‘A’相似的前一个‘A’或标记,识别出随后的标记‘B’,然后预测‘A’之后的标记将是‘B’或与‘B’非常相似的标记。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ee98d87987179fb90146a1571297455d.png

作者图片:归纳电路

这个预测过程可以分为两个步骤:

  1. 识别之前相同(或相似)的 token。序列后半部分的每个 token 应该“关注”前面seq_len位置的 token。例如,如果seq_len为 3,那么位置为 4 的 ‘A’ 应该关注位置为 1 的 ‘A’。我们可以将执行此任务的注意力头称为“引导头”。

  2. 识别下一个 token ‘B’。这是从前一个 token(例如,‘A’)复制信息到下一个 token(例如,‘B’)的过程。当 ‘A’ 再次出现时,这些信息将用于“重现”‘B’。我们可以将执行此任务的注意力头称为“前一个 token 头”。

这两个头构成了一个完整的引导回路。请注意,有时“引导头”一词也用来描述整个“引导回路”。关于引导回路的更多介绍,我强烈推荐阅读这篇文章 In-context learning and induction head,这是一篇杰作!

现在,让我们在 GPT2-small 中识别注意力头和前一个头。

以下代码用于查找引导头。首先,我们使用 30 个批次运行模型。然后,我们计算带有seq_len偏移量的注意力模式矩阵对角线的平均值。此方法使我们能够衡量当前 token 对seq_len之前出现的 token 的关注程度。

definduction_score_hook(pattern:Float[Tensor,"batch head_index dest_pos source_pos"],hook:HookPoint,):''' Calculates the induction score, and stores it in the [layer, head] position of the `induction_score_store` tensor. '''induction_stripe=pattern.diagonal(dim1=-2,dim2=-1,offset=1-seq_len)# src_pos, des_pos, one position right from seq_leninduction_score=einops.reduce(induction_stripe,"batch head_index position -> head_index","mean")induction_score_store[hook.layer(),:]=induction_score seq_len=50batch=30rep_tokens_30=generate_repeated_tokens(gpt2_small,seq_len,batch)induction_score_store=t.zeros((gpt2_small.cfg.n_layers,gpt2_small.cfg.n_heads),device=gpt2_small.cfg.device)rep_tokens_30,return_type=None,pattern_hook_names_filter,induction_score_hook)])

现在,让我们检查引导得分。我们会注意到一些头,如第 5 层第 5 头,具有高达 0.91 的引导得分。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/8a8ae188ca0dd0f8b9fd56e99b55491c.png

作者提供的图片:引导头得分

我们还可以显示该头的注意力模式。你会注意到有一条清晰的对角线,直到seq_len的偏移量。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/cdf95b0714b8e8577791fd751fc7c31a.png

作者提供的图片:第 5 层,第 5 头的注意力模式

类似地,我们可以识别前一个 token 头。例如,第 4 层第 11 头展示了对前一个 token 的强烈模式。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/59d5f94ef7520b604a4296d27da7c44e.png

作者提供的图片:前一个 token 头得分

MLP 层如何归因?

让我们考虑这个问题:MLP 层有影响吗?我们知道 GPT2-Small 包含了注意力和 MLP 层。为了解决这个问题,我提出使用消融技术。

消融,顾名思义,是系统地去除模型中的某些组件,并观察性能变化的过程。

我们将用序列前半部分的输出替换序列后半部分的 MLP 层输出,并观察这如何影响最终的损失函数。我们将使用以下代码计算替换 MLP 层输出后的损失与原始后半序列损失之间的差异。

defpatch_residual_component(residual_component,hook,pos,cache,):residual_component[0,pos,:]=cache[hook.name][pos-seq_len,:]returnresidual_component ablation_scores=t.zeros((gpt2_small.cfg.n_layers,seq_len),device=gpt2_small.cfg.device)gpt2_small.reset_hooks()logits=gpt2_small(rep_tokens,return_type="logits")loss_no_ablation=cross_entropy_loss(logits[:,seq_len:max_len],rep_tokens[:,seq_len:max_len])forlayerintqdm(range(gpt2_small.cfg.n_layers)):forpositioninrange(seq_len,max_len):hook_fn=functools.partial(patch_residual_component,pos=position,cache=rep_cache)ablated_logits=gpt2_small.run_with_hooks(rep_tokens,fwd_hooks=[(utils.get_act_name("mlp_out",layer),hook_fn)])loss=cross_entropy_loss(ablated_logits[:,seq_len:max_len],rep_tokens[:,seq_len:max_len])ablation_scores[layer,position-seq_len]=loss-loss_no_ablation

我们得出了一个令人惊讶的结果:除了第一个标记外,消融实验并没有产生显著的 logit 差异。这表明,在重复标记的情况下,MLP 层可能没有显著的贡献。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/29f50de7de46432cb97d79853870184d.png

作者提供的图片:消融前后 MLP 层的损失差异

一个归纳电路

鉴于 MLP 层对最终预测没有显著贡献,我们可以手动构建一个归纳电路,使用第 5 层的头部 5 和第 4 层的头部 11。回想一下,这些是归纳头和前一个标记头。我们通过以下代码来实现:

defK_comp_full_circuit(model:HookedTransformer,prev_token_layer_index:int,ind_layer_index:int,prev_token_head_index:int,ind_head_index:int)->FactoredMatrix:''' Returns a (vocab, vocab)-size FactoredMatrix, with the first dimension being the query side and the second dimension being the key side (going via the previous token head) '''W_E=gpt2_small.W_E W_Q=gpt2_small.W_Q[ind_layer_index,ind_head_index]W_K=model.W_K[ind_layer_index,ind_head_index]W_O=model.W_O[prev_token_layer_index,prev_token_head_index]W_V=model.W_V[prev_token_layer_index,prev_token_head_index]Q=W_E @ W_Q K=W_E @ W_V @ W_O @ W_KreturnFactoredMatrix(Q,K.T)

计算该电路的 top 1 准确率得到了 0.2283 的值。这对于一个仅由两个头部构建的电路来说相当不错!

有关详细的实现,请查看我的notebook。非常感谢 Neel Nanda,他开发了这个精彩的TransformerLen,这是一个用于大语言模型机制可解释性的伟大工具!

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/3/14 4:31:51

NCM转换终极指南:快速解密网易云音乐音频文件

NCM转换终极指南:快速解密网易云音乐音频文件 【免费下载链接】NCMconverter NCMconverter将ncm文件转换为mp3或者flac文件 项目地址: https://gitcode.com/gh_mirrors/nc/NCMconverter 你是否曾经为网易云音乐下载的ncm格式文件而烦恼?明明是自己…

作者头像 李华
网站建设 2026/3/13 14:31:02

RePKG:Wallpaper Engine资源处理工具完全指南

RePKG:Wallpaper Engine资源处理工具完全指南 【免费下载链接】repkg Wallpaper engine PKG extractor/TEX to image converter 项目地址: https://gitcode.com/gh_mirrors/re/repkg 基础认知:解决资源处理的核心问题 为什么需要专业的资源处理工…

作者头像 李华
网站建设 2026/3/15 16:13:44

电话号码定位查询终极指南:快速实现地理位置精准定位

电话号码定位查询终极指南:快速实现地理位置精准定位 【免费下载链接】location-to-phone-number This a project to search a location of a specified phone number, and locate the map to the phone number location. 项目地址: https://gitcode.com/gh_mirro…

作者头像 李华
网站建设 2026/3/13 21:12:43

实时数据处理中的背压机制与应对策略

实时数据处理中的背压机制与应对策略 关键词:实时数据处理、背压机制、数据流瓶颈、系统稳定性、流量控制、反压策略、流处理框架 摘要:在实时数据处理场景中(如物联网传感器数据流、电商大促订单洪流、直播弹幕互动),…

作者头像 李华
网站建设 2026/3/14 7:44:33

Unity游戏自动翻译插件:XUnity.AutoTranslator完全指南

Unity游戏自动翻译插件:XUnity.AutoTranslator完全指南 【免费下载链接】XUnity.AutoTranslator 项目地址: https://gitcode.com/gh_mirrors/xu/XUnity.AutoTranslator Unity游戏开发者们,你是否曾为游戏的多语言本地化而头疼?想要让…

作者头像 李华