news 2026/5/9 3:42:29

从零开始写Qwen3(五-其四)FlashAttention 差异汇编分析

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
从零开始写Qwen3(五-其四)FlashAttention 差异汇编分析

从零开始写Qwen3目录

概述

经过前文的提速,耗时已经从官方的214%降低到112%,本文将从汇编角度猜测一下差距的原因

概述

使用上一节的输入参数,设置为BM=BN=64,和torch相同,分析汇编指令
torch的指令统计如下

triton实现的指令统计如下


HMMA 是 Half Matrix Multiply Accumulation的意思,这是FlashAttn的核心指令,使用张量核进行矩阵乘法加速,对比两个统计,发现不管是指令数量还是实际执行次数都是一样的,差别可能在共享内存加载部分

指令执行次数分析

单条汇编执行次数常见有这么几种数字:

  • 2048
  • 15360

2048是无循环的执行次数,15360是执行循环的次数
2048 = 2 ⏟ B × 16 ⏟ H × 1024 64 ⏟ Q × 4 ⏟ n w a r p s 2048 = \underbrace{2}_{B} \times \underbrace{16}_{H}\times \underbrace{\frac{1024}{64}}_{Q}\times \underbrace{4}_{nwarps}2048=B2×H16×Q641024×nwarps4

可以计算

15360 = 2048 × 16 − 1 2 15360=2048\times \frac{16-1}{2}15360=2048×2161

所以2048是阶段2和公共部分的执行次数,15360是阶段2的执行次数,阶段2平均循环了7.5次,两个阶段指令数基本一致(除了因果遮罩那里,阶段1没有),所以平均执行次数是8704 87048704

张量核

张量核是CUDA从Volta开始引入的一个指令,专门用于矩阵加速,它用一条指令让一个线程束一起完成一个小块矩阵乘法,不仅简化了矩阵乘法的编写,也加快计算速度,减少指令发射耗时。张量核仅支持F16(最新架构也支持FP8的),不支持F32,这可能是FlashAttention不支持F32的一个重要原因

从PTX汇编来看,张量核的关键指令是

ldmatrix.sync.aligned.m8n8.x4.shared.b16{%r11,%r12,%r13,%r14},[%r80+2048];ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16{%r15,%r16,%r21,%r22},[%r89];mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32{%r7,%r8,%r9,%r10},{%r11,%r12,%r13,%r14},{%r15,%r16},{%r7,%r8,%r9,%r10};

对应SASS的汇编就是

LDSM.16.MT88.4 LDSM.16.M88.4 HMMA.16816.F32

PTX是中间指令,SASS是实际汇编,PTX的可读性比SASS高多了,并且有官方文档,但它不是最终结果,ncu也看不到torch的ptx

ldmatrix.sync.aligned.m8n8.x4.shared.b16从指令可以看出,这是从共享内存读取数据的,同步对齐读取,m=8,n=8,意思是一次读取8x8的数据,b16表示加载的是16bit的数据,.x4表示一次性读取4个寄存器,也就是4 × 8 × 8 4\times 8\times 84×8×8个数据,也有.x2这种指令

这个指令是整个线程束协同完成的,而且寄存器是32位,一个32位存放两个f16,这样一个线程束的一个寄存器就存放8 × 8 8\times 88×8条数据
8 × 8 32 × 2 = 1 \frac{8\times 8}{32\times 2}=132×28×8=1

顺便一提,f32转f16的汇编是F2FP.PACK_AB R114, R114, R113,明明是单个值转换,却有两个输入,这其实就是把两个f16打包到一个f32上,节省寄存器数量和指令数量

ldmatrix.sync.aligned.m8n8.trans.x4.shared.b16就是转置版的,应该是列优先

实际上mma计算的时候A B ⊤ AB^\topAB的时候反而不需要转置,A B ABAB的时候才需要转置

mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32这里m16n8k16名字很明确,是这样一个乘法
D = A B ⊤ + C C , D ∈ R 16 × 8 ; A ∈ R 16 × 16 ; B ∈ R 8 × 16 D = AB^\top+C\quad C,D\in \mathbb{R}^{16\times 8};A\in \mathbb{R}^{16\times 16};B\in \mathbb{R}^{8\times 16}D=AB+CC,DR16×8;AR16×16;BR8×16
它有四个操作数,分别对应D,A,B,C,ACD用4个寄存器,B用两个:
C , D : 16 × 8 32 × 1 = 4 A : 16 × 16 32 × 4 = 4 B : 16 × 8 32 × 2 = 2 C,D: \frac{16\times 8}{32\times 1}=4\\ A: \frac{16\times 16}{32 \times 4}=4\\ B: \frac{16\times 8}{32 \times 2}=2C,D:32×116×8=4A:32×416×16=4B:32×216×8=2

张量核指令数量计算

n_warps=1,K=16的情况

此时没有累加,没有分线程束,所以读取一次只会用于计算一次,B每次只用一半,所以是两次,得到

A加载次数是M/16,B加载次数是N/16,计算次数是

M N 16 × 8 \frac{MN}{16\times 8}16×8MN

n_warps>1,K=16的情况

要么固定A要么固定B,把另一个并行,比如2并行,把B并行就是这样

A [ B 0 B 1 ] A\left[\begin{matrix} B_0\\B_1\end{matrix}\right]A[B0B1]

所以加载次数是

M 16 + N 16 × n \frac{M}{16}+\frac{N}{16\times n}16M+16×nN

计算次数简单除以n

M N 16 × 8 × n \frac{MN}{16\times 8 \times n}16×8×nMN

n_warps=1,K≠16的情况

此时必须有累加

加载次数没有影响

M K 16 × 16 + N K 16 × 16 \frac{MK}{16\times 16}+\frac{NK}{16\times 16}16×16MK+16×16NK

计算增加

M N K 16 × 16 × 8 \frac{MNK}{16\times 16\times 8}16×16×8MNK

n_warps≠1,K≠16的情况

此时并行就需要注意,累加必须在同一个线程束中

所以虽然划分方向多了一个,但不能同时划分A和B,或者同时划分行列,只能还是按照行划分

加载次数和K=16的情况一致,计算次数按照上面计算

差异分析

torch的指令数量分析

64x128和64x128的乘积num_warps=4

加载次数

64 × 128 16 × 16 × 4 + 64 × 128 16 × 16 = 8 + 32 = 40 \frac{64\times 128}{16\times 16\times 4}+\frac{64\times 128}{16\times 16}=8+32=4016×16×464×128+16×1664×128=8+32=40

计算次数

64 × 128 × 64 16 × 8 × 16 × 4 = 64 \frac{64\times 128\times 64}{16\times 8\times 16 \times 4}=6416×8×16×464×128×64=64

观察发现,torch是对Q并行,而简单的triton是对K并行

然后计算attn V,这个过程attn没有加载,直接用寄存器,V则用的是 LDSM.xxx.trans 版本,加载次数简单除以大小和并行

attn V的计算次数和QK^top一致,都是64

所以torch的FlashAttnV2中有40条LDSM,32条LDSM.trans(显然是左侧并行,和QK^top一样),128条HMMA

但由于做了2阶段,所以全部乘2,80条LDSM,64条LDSM.trans,256条HMMA

triton指令数量分析

triton实现把attn存到共享内存,然后又加载出来,加载次数计算就是
64 × 64 16 × 16 = 16 \frac{64\times 64}{16\times 16}=1616×1664×64=16
这样就在基础的40上又增加16条,2倍就是112条,然后triton是V并行,V的LDSM.trans加载

64 × 128 16 × 16 × 4 = 8 \frac{64\times 128}{16\times 16\times 4}=816×16×464×128=8

2倍就是16条

查了一圈,triton好像tl.dot要强制加载共享内存,不能直接由寄存器计算,这里可能有一些代价

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/9 3:42:29

Cursor插件no-secrets:编码时实时检测API密钥泄露的AI助手

1. 项目概述:为什么我们需要一个“代码守门人”?在代码仓库里意外提交一个API密钥或者数据库连接字符串,这事儿听起来像是新手才会犯的错误,但说实话,我见过太多经验丰富的开发者,包括我自己,都…

作者头像 李华
网站建设 2026/5/9 3:41:31

构建个人技能知识库:模块化设计与实践指南

1. 项目概述:一个技能聚合与管理的开源工具箱 最近在GitHub上闲逛时,发现了一个名为“mega-mind-skills”的项目,作者是k1lgor。这个标题本身就挺有意思的,直译过来是“超级大脑技能”。点进去一看,发现它并非一个单一…

作者头像 李华
网站建设 2026/5/9 3:41:28

破解研发数字化转型中的协同效率瓶颈

在制造业研发数字化转型的浪潮中,产品生命周期管理系统的选型,已成为企业突破研发协同效率瓶颈、迈向协同创新的关键决策。这一选择不仅关乎一套软件工具的引入,更是一场涉及流程再造、数据治理与组织协同的战略规划。本文将探讨如何规划一条…

作者头像 李华
网站建设 2026/5/9 3:33:43

多模态AI框架MMClaw:从编码融合到实战部署全解析

1. 项目概述:一个面向多模态内容理解的“机械爪” 最近在折腾一些多模态项目时,发现一个挺有意思的仓库,叫 leadersboat/MMClaw 。光看名字, MM 大概率指的是 Multimodal(多模态) ,而 Cl…

作者头像 李华
网站建设 2026/5/9 3:24:52

示波器探头核心原理与工程实践:从负载效应到高频测量避坑指南

1. 从一份老测验聊起:为什么你的示波器读数总是不准?前几天在整理资料时,翻到一份2016年EE Times上的“周五小测验”,主题是“示波器探头”。测验本身只有六个选择题,但底下工程师们的讨论却很有意思。一位叫David Ash…

作者头像 李华