从零开始写Qwen3目录
概述
经过前文的提速,耗时已经从官方的214%降低到112%,本文将从汇编角度猜测一下差距的原因
概述
使用上一节的输入参数,设置为BM=BN=64,和torch相同,分析汇编指令
torch的指令统计如下
triton实现的指令统计如下
HMMA 是 Half Matrix Multiply Accumulation的意思,这是FlashAttn的核心指令,使用张量核进行矩阵乘法加速,对比两个统计,发现不管是指令数量还是实际执行次数都是一样的,差别可能在共享内存加载部分
指令执行次数分析
单条汇编执行次数常见有这么几种数字:
- 2048
- 15360
2048是无循环的执行次数,15360是执行循环的次数
2048 = 2 ⏟ B × 16 ⏟ H × 1024 64 ⏟ Q × 4 ⏟ n w a r p s 2048 = \underbrace{2}_{B} \times \underbrace{16}_{H}\times \underbrace{\frac{1024}{64}}_{Q}\times \underbrace{4}_{nwarps}2048=B2×H16×Q641024×nwarps4
可以计算
15360 = 2048 × 16 − 1 2 15360=2048\times \frac{16-1}{2}15360=2048×216−1
所以2048是阶段2和公共部分的执行次数,15360是阶段2的执行次数,阶段2平均循环了7.5次,两个阶段指令数基本一致(除了因果遮罩那里,阶段1没有),所以平均执行次数是8704 87048704
张量核
张量核是CUDA从Volta开始引入的一个指令,专门用于矩阵加速,它用一条指令让一个线程束一起完成一个小块矩阵乘法,不仅简化了矩阵乘法的编写,也加快计算速度,减少指令发射耗时。张量核仅支持F16(最新架构也支持FP8的),不支持F32,这可能是FlashAttention不支持F32的一个重要原因
从PTX汇编来看,张量核的关键指令是
ldmatrix.sync.aligned.m8n8.x4.shared.b16{%r11,%r12,%r13,%r14},[%r80+2048];ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16{%r15,%r16,%r21,%r22},[%r89];mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32{%r7,%r8,%r9,%r10},{%r11,%r12,%r13,%r14},{%r15,%r16},{%r7,%r8,%r9,%r10};对应SASS的汇编就是
LDSM.16.MT88.4 LDSM.16.M88.4 HMMA.16816.F32PTX是中间指令,SASS是实际汇编,PTX的可读性比SASS高多了,并且有官方文档,但它不是最终结果,ncu也看不到torch的ptx
ldmatrix.sync.aligned.m8n8.x4.shared.b16从指令可以看出,这是从共享内存读取数据的,同步对齐读取,m=8,n=8,意思是一次读取8x8的数据,b16表示加载的是16bit的数据,.x4表示一次性读取4个寄存器,也就是4 × 8 × 8 4\times 8\times 84×8×8个数据,也有.x2这种指令
这个指令是整个线程束协同完成的,而且寄存器是32位,一个32位存放两个f16,这样一个线程束的一个寄存器就存放8 × 8 8\times 88×8条数据
8 × 8 32 × 2 = 1 \frac{8\times 8}{32\times 2}=132×28×8=1
顺便一提,f32转f16的汇编是F2FP.PACK_AB R114, R114, R113,明明是单个值转换,却有两个输入,这其实就是把两个f16打包到一个f32上,节省寄存器数量和指令数量
ldmatrix.sync.aligned.m8n8.trans.x4.shared.b16就是转置版的,应该是列优先
实际上mma计算的时候A B ⊤ AB^\topAB⊤的时候反而不需要转置,A B ABAB的时候才需要转置
mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32这里m16n8k16名字很明确,是这样一个乘法
D = A B ⊤ + C C , D ∈ R 16 × 8 ; A ∈ R 16 × 16 ; B ∈ R 8 × 16 D = AB^\top+C\quad C,D\in \mathbb{R}^{16\times 8};A\in \mathbb{R}^{16\times 16};B\in \mathbb{R}^{8\times 16}D=AB⊤+CC,D∈R16×8;A∈R16×16;B∈R8×16
它有四个操作数,分别对应D,A,B,C,ACD用4个寄存器,B用两个:
C , D : 16 × 8 32 × 1 = 4 A : 16 × 16 32 × 4 = 4 B : 16 × 8 32 × 2 = 2 C,D: \frac{16\times 8}{32\times 1}=4\\ A: \frac{16\times 16}{32 \times 4}=4\\ B: \frac{16\times 8}{32 \times 2}=2C,D:32×116×8=4A:32×416×16=4B:32×216×8=2
张量核指令数量计算
n_warps=1,K=16的情况
此时没有累加,没有分线程束,所以读取一次只会用于计算一次,B每次只用一半,所以是两次,得到
A加载次数是M/16,B加载次数是N/16,计算次数是
M N 16 × 8 \frac{MN}{16\times 8}16×8MN
n_warps>1,K=16的情况
要么固定A要么固定B,把另一个并行,比如2并行,把B并行就是这样
A [ B 0 B 1 ] A\left[\begin{matrix} B_0\\B_1\end{matrix}\right]A[B0B1]
所以加载次数是
M 16 + N 16 × n \frac{M}{16}+\frac{N}{16\times n}16M+16×nN
计算次数简单除以n
M N 16 × 8 × n \frac{MN}{16\times 8 \times n}16×8×nMN
n_warps=1,K≠16的情况
此时必须有累加
加载次数没有影响
M K 16 × 16 + N K 16 × 16 \frac{MK}{16\times 16}+\frac{NK}{16\times 16}16×16MK+16×16NK
计算增加
M N K 16 × 16 × 8 \frac{MNK}{16\times 16\times 8}16×16×8MNK
n_warps≠1,K≠16的情况
此时并行就需要注意,累加必须在同一个线程束中
所以虽然划分方向多了一个,但不能同时划分A和B,或者同时划分行列,只能还是按照行划分
加载次数和K=16的情况一致,计算次数按照上面计算
差异分析
torch的指令数量分析
64x128和64x128的乘积num_warps=4
加载次数
64 × 128 16 × 16 × 4 + 64 × 128 16 × 16 = 8 + 32 = 40 \frac{64\times 128}{16\times 16\times 4}+\frac{64\times 128}{16\times 16}=8+32=4016×16×464×128+16×1664×128=8+32=40
计算次数
64 × 128 × 64 16 × 8 × 16 × 4 = 64 \frac{64\times 128\times 64}{16\times 8\times 16 \times 4}=6416×8×16×464×128×64=64
观察发现,torch是对Q并行,而简单的triton是对K并行
然后计算attn V,这个过程attn没有加载,直接用寄存器,V则用的是 LDSM.xxx.trans 版本,加载次数简单除以大小和并行
attn V的计算次数和QK^top一致,都是64
所以torch的FlashAttnV2中有40条LDSM,32条LDSM.trans(显然是左侧并行,和QK^top一样),128条HMMA
但由于做了2阶段,所以全部乘2,80条LDSM,64条LDSM.trans,256条HMMA
triton指令数量分析
triton实现把attn存到共享内存,然后又加载出来,加载次数计算就是
64 × 64 16 × 16 = 16 \frac{64\times 64}{16\times 16}=1616×1664×64=16
这样就在基础的40上又增加16条,2倍就是112条,然后triton是V并行,V的LDSM.trans加载
64 × 128 16 × 16 × 4 = 8 \frac{64\times 128}{16\times 16\times 4}=816×16×464×128=8
2倍就是16条
查了一圈,triton好像tl.dot要强制加载共享内存,不能直接由寄存器计算,这里可能有一些代价