news 2026/2/6 5:14:19

【RL】op_compute_log_probs 计算过程

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【RL】op_compute_log_probs 计算过程
defloss_func(self,data:DataProto,output_tensor:torch.Tensor):""" loss func接口定义:data:DataProto,由train_step透传 output_tensor:torch.Tensor,model.forward()的输出Tensor""" response_mask=data.batch["response_mask"][:,1:].long()ref_log_probs=data.batch["ref_log_probs"]old_log_probs=data.batch["old_log_probs"]advantages=data.batch["advantages"]log_probs=self.strategy.op_compute_log_probs(logits=output_tensor,input_ids=data.batch["input_ids"],attention_mask=data.batch["response_mask"])ratio=(log_probs-old_log_probs).exp()pg_clip_low=self.pipeline_config.pg_clip_lowifself.pipeline_config.use_pg_clip_rangeelseself.pipeline_config.pg_clip pg_clip_high=self.pipeline_config.pg_clip_highifself.pipeline_config.use_pg_clip_rangeelseself.pipeline_config.pg_clip surr1=ratio*advantages surr2=ratio.clamp(1-pg_clip_low,1+pg_clip_high)*advantages pg_loss=-torch.min(surr1,surr2)ifself.pipeline_config.dual_clip_loss:dual_clip_loss=-torch.max(-pg_loss,(1+self.pipeline_config.pg_clip*2)*advantages)pg_loss=torch.where(advantages<0,dual_clip_loss,pg_loss)pg_loss=agg_loss(loss_mat=pg_loss,loss_mask=response_mask,loss_agg_mode=self.pipeline_config.loss_agg_mode)kl_loss=compute_approx_kl(log_probs=log_probs,log_probs_base=ref_log_probs,action_mask=response_mask,kl_penalty="k3")kl_loss=agg_loss(loss_mat=kl_loss,loss_mask=response_mask,loss_agg_mode=self.pipeline_config.loss_agg_mode)approxkl=compute_approx_kl(log_probs=log_probs,log_probs_base=old_log_probs,action_mask=response_mask,kl_penalty="mse")policykl=compute_approx_kl(log_probs=log_probs,log_probs_base=old_log_probs,action_mask=response_mask,kl_penalty="kl")clipped_low=(ratio<1-pg_clip_low).float()clipped_high=(ratio>1+pg_clip_high).float()clipped=(clipped_low+clipped_high).float()ifself.pipeline_config.use_kl_loss:total_loss=pg_loss+kl_loss*self.pipeline_config.kl_loss_coefelse:total_loss=pg_lossifself.pipeline_config.entropy_loss_coef>0:entropy=self.strategy.op_compute_entropy(logits=output_tensor,attention_mask=data.batch["response_mask"])entropy_loss=agg_loss(loss_mat=entropy,loss_mask=response_mask,loss_agg_mode=self.pipeline_config.loss_agg_mode,)total_loss=total_loss-entropy_loss*self.pipeline_config.entropy_loss_coef pg_metrics={"actor/ppo_ratio_high_clipfrac":clipped_high.mean().detach().item(),"actor/ppo_ratio_low_clipfrac":clipped_low.mean().detach().item(),"actor/ppo_ratio_clipfrac":clipped.mean().detach().item(),"actor/ratio_mean":masked_mean(ratio,response_mask,dim=-1).mean().detach().item(),"actor/ratio_max":torch.max(ratio*response_mask).detach().item(),"actor/ratio_min":torch.min(ratio*response_mask+(1-response_mask)*1e10).detach().item(),"actor/clipfrac":agg_loss(loss_mat=torch.lt(surr2,surr1).float(),loss_mask=response_mask,loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(),"actor/pg_loss":pg_loss.detach().item(),"actor/kl_loss":kl_loss.detach().item(),"actor/total_loss":total_loss.detach().item(),"actor/approxkl":agg_loss(loss_mat=approxkl,loss_mask=response_mask,loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(),"actor/policykl":agg_loss(loss_mat=policykl,loss_mask=response_mask,loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(),}returntotal_loss,pg_metrics

你提出了一个非常好的问题,这个问题涉及到模型输入、输出和损失计算之间如何对齐的关键细节。

data.batch["response_mask"][:, 1:].long()这么做的原因是为了response_mask与模型的输出logits以及需要计算损失的目标log_probs在序列长度维度上对齐

让我们一步步来分解这个原因。

1. 语言模型的输入和输出

假设我们有一个输入序列input_ids

[BOS, "你好", "吗", "?", EOS]

  • BOS: Begin of Sentence token (起始符)
  • EOS: End of Sentence token (结束符)

当这个序列被送入一个自回归语言模型(比如 GPT)时,模型的工作方式是预测下一个 token

  • 输入[BOS],模型预测"你好"
  • 输入[BOS, "你好"],模型预测"吗"
  • 输入[BOS, "你好", "吗"],模型预测"?"
  • 输入[BOS, "你好", "吗", "?"],模型预测EOS

因此,如果输入序列的长度是L,那么模型输出的logits(预测每个位置的下一个 token 的概率分布)的序列长度通常也是L

但是,logits在位置i的输出,是基于input_ids在位置0i的输入得到的,它是用来预测input_ids在位置i+1的 token 的。

看下图示:

input_ids: [ token_0, token_1, token_2, token_3 ] (长度 L=4) / / / / logits: [ logit_0, logit_1, logit_2, logit_3 ] (长度 L=4) | | | | 预测目标: (token_1) (token_2) (token_3) (token_4)

2. 计算损失时的对齐问题

我们要计算的损失,是模型预测的log_probs和真实 token 之间的交叉熵。

  • logit_0用来预测token_1
  • logit_1用来预测token_2
  • logit_2用来预测token_3

我们注意到,logit的最后一个位置logit_3是用来预测token_4的,但我们的输入序列里没有token_4。因此,logit_3在计算损失时通常是被忽略的。

反过来看,token_0(通常是BOSPAD) 从来不作为被预测的目标,所以与它对应的logit也不需要。

所以,为了计算损失,我们需要:

  • logits: 取[logit_0, logit_1, logit_2],即logits[:, :-1, :]
  • 目标input_ids: 取[token_1, token_2, token_3],即input_ids[:, 1:]

两者现在长度都是L-1,完美对齐。

3.response_mask的对齐

现在回到你的问题response_mask

response_mask的原始形状与input_ids是一样的,长度为L。它标记了哪些 token 属于需要计算损失的 “response” 部分。

例如,对于输入[prompt_1, prompt_2, resp_1, resp_2, resp_3, EOS]
原始response_mask可能是[0, 0, 1, 1, 1, 1]

当我们计算损失时,我们用的是logits[:, :-1]input_ids[:, 1:]。这意味着我们实际上是在处理一个长度为L-1的序列。

为了让response_mask能够正确地屏蔽掉这个长度为L-1的序列中不需要计算损失的部分,它自身也需要被截断,以匹配新的序列长度。

我们截断的方式必须和目标input_ids的截断方式一样,即从第二个元素开始取,去掉第一个元素

  • 原始response_mask(长度 L):[mask_0, mask_1, mask_2, ..., mask_{L-1}]
  • 截断后的response_mask(长度 L-1):[mask_1, mask_2, ..., mask_{L-1}]
    • 这个操作就是response_mask[:, 1:]

这样,截断后的mask_1就会对应logit_0和 目标token_1mask_2对应logit_1和 目标token_2,以此类推,对齐关系就完全正确了。

总结

data.batch["response_mask"][:, 1:].long()这行代码的目的是:

  1. [:, 1:]:为了与用于计算损失的logits[:, :-1]和目标input_ids[:, 1:]在序列长度上保持对齐。它通过移除 mask 的第一个元素来实现这一点。
  2. .long(): 将 mask 的数据类型转换为长整型 (torch.int64)。这通常是为了兼容某些 PyTorch 操作,比如用 mask 进行索引或者作为某些损失函数的权重时,可能要求是整型而不是布尔型或浮点型。

这是一个在实现语言模型训练时非常常见且至关重要的细节,确保了计算的正确性。

好的,我们来详细分析你提供的这个特定版本的op_compute_log_probs函数,并用一个具体的例子来走一遍它的执行过程。

这个版本的实现比我之前给出的通用版本更简洁,它依赖一个名为log_probs_from_logits的辅助函数(HuggingFace TRL 库中常见)。我们假设log_probs_from_logits的功能就是我之前描述的第 2 步(Log-Softmax)和第 3 步(Gather)的结合。

函数代码分析

# 假设 log_probs_from_logits 的实现如下:deflog_probs_from_logits(logits,labels):log_probs=F.log_softmax(logits,dim=-1)returntorch.gather(log_probs,dim=-1,index=labels.unsqueeze(-1)).squeeze(-1)classYourClass:defop_compute_log_probs(self,logits:torch.Tensor,input_ids:torch.Tensor,attention_mask:torch.Tensor):""" logits: llm logits, 形状 [batch_size, seq_len, vocab_size] input_ids [[p, p, r, r, r, 0, 0]], 形状 [batch_size, seq_len] attention_mask(response_mask) [[0, 0, 1, 1, 1, 0, 0]], 形状 [batch_size, seq_len] """# 1. 准备 Labels,并处理无效 Tokenlabels:torch.Tensor=input_ids[:,1:].clone()# 将 mask 为 0 的位置的 label 设置为 0labels[attention_mask[:,1:]==0]=0# 2. 计算 Log Probs# 传入错位对齐的 logits 和处理过的 labelslog_probs=log_probs_from_logits(logits[:,:-1],labels)# 3. 应用 Mask# 将不属于 response 的位置的 log_probs 清零log_probs=log_probs*attention_mask[:,1:]returnlog_probs

核心逻辑点

  1. labels[attention_mask[:, 1:] == 0] = 0: 这是这个实现中最有趣和最关键的一步。它的目的是防止log_probs_from_logits访问到无效的 token ID
    • 在 PPO 训练中,input_ids可能包含PADtoken(其 ID 通常是 0)。
    • 如果一个PADtoken 出现在labels中,torch.gather会尝试去访问词汇表索引为 0 的位置。这本身没问题。
    • 但更重要的是,对于 prompt 部分和 padding 部分,我们根本不关心它们的log_probs,因为它们不会计入最终的损失。将这些位置的label统一设置为 0,可以简化计算。虽然gather仍然会为这些位置计算一个值(即词汇表中 token 0 的对数概率),但这没关系,因为第 3 步会把这些位置的log_probs全部清零。这是一个“先计算再丢弃”的策略。

举例说明执行过程

假设有以下微型配置:

  • batch_size = 1
  • seq_len = 7
  • vocab_size = 50000
  • PAD_TOKEN_ID = 0

输入:

  • input_ids:[[101, 102, 201, 202, 203, 0, 0]]
    • [p, p, r, r, r, pad, pad]
  • attention_mask(response_mask):[[0, 0, 1, 1, 1, 0, 0]]
  • logits: 一个由模型生成的[1, 7, 50000]的张量。

执行步骤:

第 1 步: 准备labels
  1. labels = input_ids[:, 1:].clone()

    • input_ids[:, 1:]得到[[102, 201, 202, 203, 0, 0]]
    • labels的值现在是[[102, 201, 202, 203, 0, 0]],形状[1, 6]
  2. 计算maskforlabels

    • attention_mask[:, 1:]得到[[0, 1, 1, 1, 0, 0]]
  3. labels[attention_mask[:, 1:] == 0] = 0

    • attention_mask[:, 1:] == 0会产生一个布尔掩码[[True, False, False, False, True, True]]
    • 这个掩码会选中labels中需要被修改的位置:
      • labels的第 0 个元素 (对应 prompt 部分)
      • labels的第 4 个元素 (对应第一个 pad)
      • labels的第 5 个元素 (对应第二个 pad)
    • labels被原地修改,修改后的值为:[[0, 201, 202, 203, 0, 0]]
      • 注意:原来的102变成了0

    至此,labels准备完毕,值为[[0, 201, 202, 203, 0, 0]]

第 2 步: 计算log_probs
  1. 准备logits

    • logits[:, :-1]得到一个[1, 6, 50000]的张量。
  2. 调用log_probs_from_logits(logits[:, :-1], labels)

    • log_probs_from_logits内部会:
      a. 对logits[:, :-1]在最后一个维度上做log_softmax
      b. 使用labels[[0, 201, 202, 203, 0, 0]]作为索引,通过torch.gatherlog_softmax的结果中提取值。
    • log_probs的计算结果(形状为[1, 6])会是:
      [[ logP(token=0 | p), // prompt 部分,计算了 pad token 的 log_prob logP(token=201 | p,p), // response 部分,正确 logP(token=202 | p,p,r), // response 部分,正确 logP(token=203 | p,p,r,r), // response 部分,正确 logP(token=0 | p,p,r,r,r), // padding 部分,计算了 pad token 的 log_prob logP(token=0 | p,p,r,r,r,0) // padding 部分,计算了 pad token 的 log_prob ]]
    • 假设计算出的值为[[ -3.2, -1.1, -0.8, -1.5, -4.5, -4.8 ]]
第 3 步: 应用 Mask
  1. log_probs = log_probs * attention_mask[:, 1:]
    • log_probs:[[ -3.2, -1.1, -0.8, -1.5, -4.5, -4.8 ]]
    • attention_mask[:, 1:]:[[0, 1, 1, 1, 0, 0]]
    • 两者进行逐元素相乘:
      [ -3.2, -1.1, -0.8, -1.5, -4.5, -4.8 ] * [ 0, 1, 1, 1, 0, 0 ] ---------------------------------------- = [ 0, -1.1, -0.8, -1.5, 0, 0 ]

最终返回结果

函数返回的log_probs张量为:[[ 0, -1.1, -0.8, -1.5, 0, 0 ]]

这个结果非常完美:

  • 只有 response 部分(r, r, r) 的log_probs被保留了下来。
  • prompt 部分和 padding 部分的log_probs都被清零了
  • 这个张量可以直接用于后续的损失计算(例如计算ratio = (log_probs - old_log_probs).exp()),而不需要担心无效位置的干扰。

这个实现方式非常高效和简洁,巧妙地利用了 PyTorch 的广播和掩码操作。

好的,我们来详细拆解这两行代码,它们是PyTorch中一个非常常见且强大的模式,用于从概率分布中根据标签提取特定的值。

这两行代码的核心是torch.gather函数。

torch.gather的工作原理

torch.gather就像一个高级的索引工具。它的作用是:沿着指定的维度(dim),根据index张量中的值,从输入张量(input)中收集元素。

它的签名是torch.gather(input, dim, index)

为了让它工作,index张量需要满足一些条件,最重要的是:

  • index的维度数量必须和input的维度数量相同。
  • 在所有dim的维度上,index的大小必须和input的大小相同(或者为 1,可以广播)。

结合你的代码进行分解

我们一步步来看:

# 假设我们有以下张量(以 batch_size=1, seq_len=3, vocab_size=5 为例)# log_probs: [1, 3, 5] 的张量,代表了3个位置上,每个词的对数概率log_probs=torch.tensor([[[-1.6,-2.1,-0.9,-3.0,-1.8],# 位置0的 log_probs[-0.5,-1.1,-2.5,-1.3,-4.0],# 位置1的 log_probs[-3.2,-1.9,-1.0,-2.2,-0.8]# 位置2的 log_probs]])# labels: [1, 3] 的张量,代表了3个位置上,正确的 token IDlabels=torch.tensor([[2,0,4]])

第 1 步:labels.unsqueeze(-1)
  • 目的: 增加一个维度,使labels的维度数量与log_probs相同,从而满足torch.gather的要求。
  • 输入labels:
    • 形状:[1, 3]
    • 值:[[2, 0, 4]]
  • 操作:unsqueeze(-1)在最后一个维度(维度索引为-1)上增加一个大小为 1 的新维度。
  • 输出index:
    • 形状:[1, 3, 1]
    • 值:
      [[[2], [0], [4]]]

现在,log_probs(3D) 和index(3D) 的维度数量相同了。


第 2 步:log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
  • input:log_probs(形状[1, 3, 5])
  • dim:-1(或2)。这意味着我们将在最后一个维度——**词汇表维度(vocab_size)**上进行收集。
  • index:labels.unsqueeze(-1)(形状[1, 3, 1])

gather的执行过程 (可以想象成一个 for 循环):

  1. gather会遍历index张量的所有位置。
  2. 对于index中的每个元素(batch_idx, seq_idx, 0),它会取出其中的值v = index[batch_idx, seq_idx, 0]
  3. 然后,它会在log_probs张量的对应位置(batch_idx, seq_idx, ...)上,沿着dim=-1收集索引为v的元素。
  4. 它将收集到的值放在输出张量的(batch_idx, seq_idx, 0)位置。

让我们手动走一遍:

  • 处理index[0, 0, 0]位置:

    • index[0, 0, 0]的值是2
    • gatherlog_probs[0, 0, :]位置,也就是[-1.6, -2.1, -0.9, -3.0, -1.8]
    • 它从这个向量中取出索引为2的元素,即-0.9
    • 输出张量的[0, 0, 0]位置被设置为-0.9
  • 处理index[0, 1, 0]位置:

    • index[0, 1, 0]的值是0
    • gatherlog_probs[0, 1, :]位置,也就是[-0.5, -1.1, -2.5, -1.3, -4.0]
    • 它从这个向量中取出索引为0的元素,即-0.5
    • 输出张量的[0, 1, 0]位置被设置为-0.5
  • 处理index[0, 2, 0]位置:

    • index[0, 2, 0]的值是4
    • gatherlog_probs[0, 2, :]位置,也就是[-3.2, -1.9, -1.0, -2.2, -0.8]
    • 它从这个向量中取出索引为4的元素,即-0.8
    • 输出张量的[0, 2, 0]位置被设置为-0.8

gather的输出log_probs_labels:

  • 形状:[1, 3, 1](与index的形状相同)
  • 值:
    [[[-0.9], [-0.5], [-0.8]]]

直观理解: 对于序列中的每个位置,我们都从完整的词汇表概率分布中,只挑选出了正确标签(label)对应的那个对数概率


第 3 步:.squeeze(-1)
  • 目的: 移除多余的、大小为 1 的维度,让张量更易于处理。
  • 输入log_probs_labels:
    • 形状:[1, 3, 1]
  • 操作:squeeze(-1)移除最后一个维度(因为它的大小是 1)。
  • 输出:
    • 形状:[1, 3]
    • 值:[[-0.9, -0.5, -0.8]]

最终结果

函数最终返回了一个[1, 3]的张量[[-0.9, -0.5, -0.8]]

这个张量的每个元素output[i, j]都代表了在批次i的序列位置j,模型赋予正确label的对数概率。这正是我们计算交叉熵损失或 PPO 损失时所需要的核心数值。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/2/4 10:07:21

从代码编译到服务上线:Open-AutoGLM生产级部署的7个关键步骤

第一章&#xff1a;Open-AutoGLM开源部署教程环境准备 在部署 Open-AutoGLM 之前&#xff0c;需确保本地或服务器环境已安装必要的依赖组件。推荐使用 Linux 系统&#xff08;如 Ubuntu 20.04&#xff09;进行部署。安装 Python 3.9 或更高版本配置虚拟环境以隔离依赖安装 Git …

作者头像 李华
网站建设 2026/2/4 1:04:52

Intel RealSense深度摄像头:Python开发者的5个核心技术突破

Intel RealSense深度摄像头&#xff1a;Python开发者的5个核心技术突破 【免费下载链接】librealsense Intel RealSense™ SDK 项目地址: https://gitcode.com/GitHub_Trending/li/librealsense Intel RealSense™ SDK为Python开发者打开了一扇通往深度感知世界的大门。…

作者头像 李华
网站建设 2026/2/3 22:25:47

如何快速实现高质量语音转换:Mangio-RVC-Fork终极使用指南

如何快速实现高质量语音转换&#xff1a;Mangio-RVC-Fork终极使用指南 【免费下载链接】Mangio-RVC-Fork *CREPEHYBRID TRAINING* A very experimental fork of the Retrieval-based-Voice-Conversion-WebUI repo that incorporates a variety of other f0 methods, along with…

作者头像 李华
网站建设 2026/2/3 9:46:52

免费开源矢量刺绣设计完整指南:InkStitch从入门到精通

免费开源矢量刺绣设计完整指南&#xff1a;InkStitch从入门到精通 【免费下载链接】inkstitch Ink/Stitch: an Inkscape extension for machine embroidery design 项目地址: https://gitcode.com/gh_mirrors/in/inkstitch 厌倦了商业刺绣软件的复杂操作和昂贵费用&…

作者头像 李华
网站建设 2026/2/4 7:54:03

突破传统边界:YYEVA动态MP4动效播放器全解析与实战指南

在内容创作日新月异的今天&#xff0c;你是否曾为静态MP4资源的局限性而困扰&#xff1f;YYEVA动态MP4动效播放器应运而生&#xff0c;彻底打破了传统视频资源的束缚&#xff0c;让MP4文件能够实时插入动态元素&#xff0c;为你的创意提供无限可能。 【免费下载链接】YYEVA YYE…

作者头像 李华
网站建设 2026/2/3 7:48:02

北航矩阵理论期末真题:快速获取与高效复习指南

北航矩阵理论期末真题&#xff1a;快速获取与高效复习指南 【免费下载链接】矩阵理论期末试卷北航资源下载分享 矩阵理论期末试卷&#xff08;北航&#xff09;资源下载 项目地址: https://gitcode.com/Open-source-documentation-tutorial/88e5f 想要顺利通过北航矩阵理…

作者头像 李华