news 2026/7/1 20:33:34

sglang 大模型推理框架支持的EAGLE 1,2,3

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
sglang 大模型推理框架支持的EAGLE 1,2,3

文章目录

      • EAGLE 系列模型的演进与核心机制
      • 关键参数与训练逻辑
      • 思考

参考来源:https://docs.sglang.com.cn/backend/speculative_decoding.html
https://github.com/SafeAILab/EAGLE
EAGLE3 https://arxiv.org/pdf/2503.01840

EAGLE 系列模型的演进与核心机制

EAGLE 基础架构
草稿模型通过特征序列和 token 序列预测下一个特征向量,基于原始 LLM 的最后一个隐藏状态生成候选。采样后的 token 与原始序列以树状结构扩展,分支因子由speculative_eagle_topk控制,确保上下文连贯性。扩展后的树结构重新作为输入迭代生成。

EAGLE-2 的优化
引入动态分支评估机制,草稿模型主动评估扩展分支的可能性,提前终止低概率分支的扩展。扩展阶段结束后,通过重排序筛选前speculative_num_draft_tokens个节点作为最终草稿 token,减少冗余计算。

--speculative-token-map参数设置为true以启用高频 token 优化功能。该参数通常在模型推理或训练配置文件中进行设置。

EAGLE-3 的改进
移除特征预测目标,整合低层与中间层特征提升表示能力。采用 on-policy 训练方式,使模型在推理阶段的行为与训练目标更一致,进一步优化生成质量与效率。

关键参数与训练逻辑

  • speculative_eagle_topk:控制每步扩展的分支数量,影响生成多样性与计算开销。
  • speculative_num_draft_tokens:决定保留的候选 token 数量,平衡生成速度与准确性。
  • On-policy 训练:通过对齐训练与推理阶段的策略,减少分布偏移问题。

  • https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py

核心代码部分

def_prepare_decoder_attention_mask(self,attention_mask,input_shape,inputs_embeds,past_key_values_length):# create causal mask# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]combined_attention_mask=Noneifinput_shape[-1]>1:combined_attention_mask=_make_causal_mask(input_shape,inputs_embeds.dtype,device=inputs_embeds.device,past_key_values_length=past_key_values_length,)ifattention_maskisnotNone:# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]expanded_attn_mask=_expand_mask(attention_mask,inputs_embeds.dtype,tgt_len=input_shape[-1]).to(inputs_embeds.device)combined_attention_mask=(expanded_attn_maskifcombined_attention_maskisNoneelseexpanded_attn_mask+combined_attention_mask)returncombined_attention_mask@torch.no_grad()defdataprepare(self,input_ids,attention_mask,loss_mask):device=input_ids.device outs=self.target_model(input_ids=input_ids,attention_mask=attention_mask)hidden_states0=outs.hidden_states[0]hidden_states1=outs.hidden_states[1]hidden_states2=outs.hidden_states[2]hidden_states=torch.cat((hidden_states0,hidden_states1,hidden_states2),dim=-1)# hidden_states=torch.cat((hidden_states0,hidden_states1),dim=-1)target=outs.logits target=padding(target,left=False)input_ids=padding(input_ids,left=False)iftargetisnotNone:target=target.to(device)loss_mask=loss_mask[...,None]loss_mask=loss_mask.to(device)returnhidden_states,target,loss_mask,input_idsdefforward(self,# hidden_states,input_ids,attention_mask:Optional[torch.Tensor]=None,position_ids:Optional[torch.LongTensor]=None,past_key_values:Optional[List[torch.FloatTensor]]=None,use_cache:Optional[bool]=None,output_attentions:Optional[bool]=None,output_hidden_states:Optional[bool]=None,loss_mask:Optional[torch.Tensor]=None,):hidden_states,target,loss_mask,input_ids=self.dataprepare(input_ids,attention_mask,loss_mask)batch_size,seq_length,_=hidden_states.shape seq_length_with_past=seq_length past_key_values_length=0# with torch.no_grad():# inputs_embeds = self.embed_tokens(input_ids)# inputs_embeds = inputs_embeds.detach()ifself.trainingandself.gradient_checkpointingandnothidden_states.requires_grad:hidden_states.requires_grad=Truehidden_states=self.fc(hidden_states)ifpast_key_valuesisnotNone:past_key_values_length=past_key_values[0][0].shape[2]seq_length_with_past=seq_length_with_past+past_key_values_lengthifposition_idsisNone:device=hidden_states.device position_ids=torch.arange(past_key_values_length,seq_length+past_key_values_length,dtype=torch.long,device=device)position_ids=position_ids.unsqueeze(0).view(-1,seq_length)else:position_ids=position_ids.view(-1,seq_length).long()ifattention_maskisNone:attention_mask=torch.ones((batch_size,seq_length_with_past),dtype=torch.bool,device=hidden_states.device)attention_mask=self._prepare_decoder_attention_mask(attention_mask,(batch_size,seq_length),hidden_states,past_key_values_length)ifself.gradient_checkpointingandself.training:ifuse_cache:use_cache=Falseplosses=[]vlosses=[]acces=[]cache_hidden=[[],[]]foridxinrange(self.length):last=idx==self.length-1inputs_embeds=self.embed_tokens(input_ids)ifself.trainingandself.gradient_checkpointingandnotinputs_embeds.requires_grad:inputs_embeds.requires_grad=Trueinputs_embeds=inputs_embeds.to(hidden_states.dtype)ifself.gradient_checkpointingandself.training:defcreate_custom_forward(module):defcustom_forward(*inputs):# None for past_key_valuereturnmodule(*inputs,None,output_attentions)returncustom_forward layer_outputs,cache_hidden=torch.utils.checkpoint.checkpoint(create_custom_forward(self.midlayer),inputs_embeds,hidden_states,cache_hidden,attention_mask,position_ids,)else:layer_outputs,cache_hidden=self.midlayer(input_emb=inputs_embeds,hidden_states=hidden_states,cache_hidden=cache_hidden,attention_mask=attention_mask,position_ids=position_ids,past_key_value=None,output_attentions=output_attentions,use_cache=True,)hidden_states_out=layer_outputs[0]# cache_hidden.append(layer_outputs[1])# kv_cahce = layer_outputs[-1]withtorch.no_grad():# hidden_states_target = padding(hidden_states, left=False)target_head=target target_max_token=target_head.argmax(-1)# Move d2t to the same device as target_max_tokenself.t2d=self.t2d.to(target_max_token.device)target_mask=self.t2d[target_max_token]target_mask=target_mask[...,None].int()position_mask=target_mask*loss_mask target_head=target_head[...,self.t2d]target_head=target_head.float()target_p=nn.Softmax(dim=2)(target_head)target_p=target_p.detach()hidden_states=hidden_states_out hidden_states_out=self.norm(hidden_states_out)logits=self.lm_head(hidden_states_out)logits=logits.float()out_logp=nn.LogSoftmax(dim=2)(logits)plogp=target_p*out_logp loss=-torch.sum(position_mask*plogp,2).mean()plosses.append(loss)withtorch.no_grad():acces.append(((logits.argmax(-1)==target_p.argmax(-1))*position_mask.squeeze(-1)).sum().item()/(loss_mask.sum().item()+1e-6))ifnotlast:input_ids=padding(input_ids,left=False)target=padding(target,left=False)loss_mask=padding(loss_mask,left=False)returnplosses,vlosses,acces

思考

》 FASTMTP与EAGLE3相比,谁更快一些?

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/30 17:59:34

汇编语言全接触-26.启动画面

上一章我们学习了位图的使用.在这一章我们要用上帝赋予我们的创造力来融会贯通上一章我们学到的知识.那就是研究如何用位图来创建启动画面. 你可以在这里下载示范: the example. 理论首先,我们先要搞清楚什么是启动画面.举个简单的例子:我们启动某些作的专业一点的程序时(比如N…

作者头像 李华
网站建设 2026/6/25 22:37:58

随机抽奖算法实现与对比:聚焦洗牌算法(Fisher-Yates)

期末课程设计中,我和团队成员共同完成了 “随机抽奖算法实现与比较” 的课题。本次设计的核心目标是模拟实际抽奖场景,从指定号码范围(min_num 到 max_num)中抽取 k 个不重复的中奖号码,并通过实现四种不同算法&#x…

作者头像 李华
网站建设 2026/6/29 12:27:03

【Hadoop+Spark+python毕设】物联网网络安全威胁数据分析系统、计算机毕业设计、包括数据爬取、数据分析、数据可视化、Hadoop、实战教学

🎓 作者:计算机毕设小月哥 | 软件开发专家 🖥️ 简介:8年计算机软件程序开发经验。精通Java、Python、微信小程序、安卓、大数据、PHP、.NET|C#、Golang等技术栈。 🛠️ 专业服务 🛠️ 需求定制化开发源码提…

作者头像 李华
网站建设 2026/6/30 19:56:55

Springboot连锁药店进销存业务系统98i85(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。

系统程序文件列表项目功能:员工,供应商,药品信息,药品采购,进货出库,药品销售,退货入库,药品报损,药品销毁开题报告内容基于SpringBoot的连锁药店进销存业务系统开题报告一、选题背景与意义1.1 行业现状与痛点随着医疗行业的快速发展和人们对健康需求的日益增加&…

作者头像 李华
网站建设 2026/7/1 6:51:19

智能测试指标动态权重分配研究

随着人工智能与机器学习技术在软件测试领域的深度渗透,传统静态权重分配模式已难以适应瞬息万变的测试环境。本文基于2025年行业实践数据,提出以动态权重分配为核心的新型测试评估体系,通过构建具备自适应能力的指标权重矩阵,有效…

作者头像 李华
网站建设 2026/6/30 21:19:14

std::promise 重难点

std::promise 重难点全拆解 std::promise 是 C11 异步编程的核心组件,但其难点不在于语法本身,而在于状态管理、生命周期控制、异常传递等“隐性规则”——踩中任何一个都可能导致程序崩溃或逻辑异常。本文用“专业底层逻辑通俗比喻分步实操”的方式&…

作者头像 李华