⚡ 【FlashAttention】原理深度解析:IO-Aware 算法为什么能快 3 倍、省 10 倍显存
文章目录
- ⚡ 【FlashAttention】原理深度解析:IO-Aware 算法为什么能快 3 倍、省 10 倍显存
- 📖 第一章:问题的根源——GPU 的存储层次
- GPU 内存的两层结构
- 标准 Attention 是 Memory-Bound 的
- 🧩 第二章:FlashAttention 的核心思路
- 分块计算(Tiling)
- 难点:分块 Softmax 怎么做
- 🔢 第三章:FlashAttention 完整算法
- 前向传播
- 用 Python 展示核心逻辑
- 💾 第四章:IO 复杂度分析
- 标准 Attention vs FlashAttention
- 为什么省显存
- 🔄 第五章:反向传播——重计算(Recomputation)
- 📊 第六章:实际性能数字
- 用 PyTorch 调用 FlashAttention
- 🚀 第七章:FlashAttention-2 和 FlashAttention-3
- FlashAttention-2 的改进(2023)
- FlashAttention-3(2024,H100 专用)
- 🎯 第八章:面试高频问题
- Q1:FlashAttention 为什么快?不是因为减少了计算量
- Q2:Online Softmax 是什么?为什么分块 Attention 需要它?
- Q3:FlashAttention 显存复杂度为什么是 O(N)?
- Q4:FlashAttention 对训练和推理的影响有什么不同?
- 🎁 总结速查
- 📣 最后
写在前面:训练 GPT-3 级别的模型,标准 Attention 的显存会直接爆掉——不是因为参数太多,而是因为注意力矩阵太大。一个 2048 长度的序列,单头的注意力矩阵就是 2048×2048 的 float16,占 16MB;32 个头的一个 batch,就是 512MB——光这一层的中间结果就把显存打满了。FlashAttention 解决的不是计算量问题,而是内存访问(IO)问题。它通过分块(Tiling)让整个 Attention 计算都在 SRAM 里完成,大幅减少对 HBM 的读写,速度快 2-3 倍,显存减少 5-20 倍。这篇从硬件原理讲起,把 FlashAttention 的每一步都推导清楚。
📖 第一章:问题的根源——GPU 的存储层次
GPU 内存的两层结构
理解 FlashAttention,必须先理解 GPU 的存储体系。GPU 有两种主要存储:
HBM(High Bandwidth Memory,高带宽内存) ├── 容量:大(A100 = 40GB 或 80GB) ├── 带宽:高(A100 = 2TB/s) ├── 延迟:高(读写慢) └── 位置:在 GPU 芯片外 SRAM(Static RAM,静态随机访问内存) ├── 容量:极小(A100 每个 SM = 192KB,共 108 个 SM,总计约 20MB) ├── 带宽:极高(~19TB/s,比 HBM 快约 10 倍) ├── 延迟:极低(读写极快) └── 位置:在 GPU 芯片上(即 L1/L2 Cache + Shared Memory)关键洞察:GPU 的算力(FLOPS)远比带宽增长快得多。
A100 GPU:
算力:312 TFLOPS(float16)
HBM 带宽:2 TB/s
算术强度(Arithmetic Intensity)= FLOPS / Bytes
如果一个运算的算术强度太低(需要大量读写,但计算少)
→ 受内存带宽限制(Memory-Bound)
→ GPU 的算力大量浪费在等待数据上
标准 Attention 是 Memory-Bound 的
标准 Scaled Dot-Product Attention:
输入:Q, K, V ∈ R^{N×d}(N = 序列长度,d = 每头维度)
Step 1:S = QK^T / √d 计算注意力分数 形状 (N, N)
Step 2:P = softmax(S) softmax 形状 (N, N)
Step 3:O = PV 输出 形状 (N, d)
HBM 访问量分析:
读:Q, K, V → O(Nd)
写:S(N×N)→ HBM ← 这是瓶颈!
读:S 再读回来做 softmax
写:P(N×N)→ HBM
读:P 再读回来乘 V
总 HBM 访问量 ≈ O(N²)
N = 2048,d = 64,float16:
S 矩阵大小 = 2048 × 2048 × 2 bytes = 16 MB(单头)
32 头 = 512 MB
反向传播还要保存 S、P 用于梯度计算
→ 一个中等长度序列,显存直接打满
计算量和 IO 量的矛盾:
importnumpyasnp N=2048# 序列长度d=64# 头维度dtype_bytes=2# float16# 计算量(FLOPs)# QK^T: (N,d) × (d,N) → 2*N*N*d FLOPs# softmax: ~5*N*N FLOPs# PV: (N,N) × (N,d) → 2*N*N*d FLOPscompute=4*N*N*d# ≈ 1.07 GFLOPs# HBM 访问量(Bytes)# 写 S:N×N# 读 S 做 softmax:N×N# 写 P:N×N# 读 P 做 PV:N×Nio=4*N*N*dtype_bytes# ≈ 32 MBarithmetic_intensity=compute/io# FLOPs / Byteprint(f"算术强度:{arithmetic_intensity:.1f}FLOPs/Byte")# 约 33.5 FLOPs/Byte# A100 的计算/带宽上限# 312 TFLOPS / 2 TB/s = 156 FLOPs/Byte# 算术强度 33.5 << 156 → Memory-Bound!# GPU 算力只用了 33.5/156 ≈ 21%,其余时间在等内存🧩 第二章:FlashAttention 的核心思路
分块计算(Tiling)
FlashAttention 的核心思路一句话:把 Q、K、V 切成小块,每次只把一小块加载到 SRAM,在 SRAM 里完成所有计算,不把注意力矩阵 S 和 P 写回 HBM。
标准 Attention: Q, K, V → HBM → 计算 S (N×N) → 写回 HBM → 读 S → 计算 P (N×N) → 写回 HBM → 读 P → 计算 O → 写回 HBM FlashAttention: Q, K, V 切块 → SRAM → 分块计算注意力 → 直接在 SRAM 里得到 O 的块 → 写回 HBM N×N 的 S、P 矩阵从未出现在 HBM 中!分块大小的计算:
# SRAM 大小约束SRAM_size=20*1024*1024# 20 MB(A100 粗略估计)dtype_bytes=2# float16# 分块大小 B# 每个 SRAM 需要存:Q_block(B×d) + K_block(B×d) + V_block(B×d) + O_block(B×d)# 4 × B × d × 2 bytes ≤ SRAM_sized=64B=SRAM_size//(4*d*dtype_bytes)print(f"最大块大小 B ≈{B}")# ≈ 40960,远大于 N=2048 时无需分块# 对于超长序列(N=100K+),确实需要分块# B 的选择是关键超参数,影响 SRAM 利用率难点:分块 Softmax 怎么做
分块计算 QK^T 很简单,但 Softmax 需要全局信息(分母需要对所有 key 求和),这是分块计算最大的难点。
Online Softmax(在线 Softmax)是解决方案。
# 问题:standard softmax 需要两遍扫描defsoftmax_standard(x):# 第一遍:找最大值(数值稳定)m=max(x)# 第二遍:计算 exp 和归一化exp_x=[exp(xi-m)forxiinx]Z=sum(exp_x)return[ei/Zforeiinexp_x]# 如果分块处理,每次只看一块,不知道全局最大值!# 解决:Online Softmax(Milakov & Gimelshein, 2018)# 维护两个统计量,随着看到新的块不断更新:# m_i:目前见过的最大值# l_i:目前的归一化分母(已修正)defonline_softmax_update(m_prev,l_prev,s_new_block):""" 处理新的一块 scores,更新统计量 m_prev:之前块的最大值 l_prev:之前块的归一化分母(基于 m_prev) s_new_block:新块的原始分数 """# 新块的最大值m_new=max(m_prev,max(s_new_block))# 更新归一化分母# 旧的 exp(x_i - m_prev) 需要修正为 exp(x_i - m_new)# exp(x_i - m_new) = exp(x_i - m_prev) * exp(m_prev - m_new)l_new=(l_prev*exp(m_prev-m_new)+sum(exp(xi-m_new)forxiins_new_block))returnm_new,l_new🔢 第三章:FlashAttention 完整算法
前向传播
输入:Q, K, V ∈ R^{N×d},存在 HBM
输出:O ∈ R^{N×d}
算法流程: ① 设定块大小 B_c, B_r(由 SRAM 大小决定) B_c = ⌈M/(4d)⌉,B_r = min(⌈M/(4d)⌉, d) (M = SRAM 大小) ② 将 Q 分成 T_r 块:Q_1, ..., Q_{T_r},每块大小 B_r × d 将 K, V 分成 T_c 块:K_1, ..., K_{T_c},每块大小 B_c × d ③ 初始化 O = 0(N×d),l = 0(N),m = -∞(N) 以上存在 HBM ④ 外循环(对 K, V 的块): for j = 1 to T_c: 从 HBM 加载 K_j, V_j 到 SRAM 内循环(对 Q 的块): for i = 1 to T_r: 从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM 在 SRAM 中计算: S_ij = Q_i K_j^T / √d ← (B_r × B_c) 的小矩阵 m̃_ij = max(S_ij) ← 当前块的最大值(行最大) P̃_ij = exp(S_ij - m̃_ij) ← 当前块的 exp l̃_ij = rowsum(P̃_ij) ← 当前块的 sum 更新统计量: m_i_new = max(m_i, m̃_ij) l_i_new = exp(m_i - m_i_new) * l_i + exp(m̃_ij - m_i_new) * l̃_ij 更新输出(关键!修正之前的 O_i): O_i = (l_i * exp(m_i - m_i_new) * O_i + exp(m̃_ij - m_i_new) * P̃_ij V_j) / l_i_new 将 O_i, l_i_new, m_i_new 写回 HBM ⑤ 返回 O用 Python 展示核心逻辑
importtorchimportmathdefflash_attention_forward(Q,K,V,block_size=64):""" FlashAttention 前向传播(教学版本,非真实 CUDA 实现) Q, K, V: (N, d) 输出: O: (N, d) """N,d=Q.shape scale=1.0/math.sqrt(d)# 初始化输出和统计量(存在 HBM)O=torch.zeros(N,d,dtype=Q.dtype)l=torch.zeros(N)# 归一化分母m=torch.full((N,),-float('inf'))# 行最大值# K, V 分块forjinrange(0,N,block_size):K_j=K[j:j+block_size]# (B_c, d),加载到 SRAMV_j=V[j:j+block_size]# (B_c, d),加载到 SRAMB_c=K_j.shape[0]# Q 分块foriinrange(0,N,block_size):Q_i=Q[i:i+block_size]# (B_r, d),加载到 SRAMO_i=O[i:i+block_size]# 从 HBM 加载l_i=l[i:i+block_size]# 当前块的归一化分母m_i=m[i:i+block_size]# 当前块的行最大值B_r=Q_i.shape[0]# 在 SRAM 中计算(不写回 HBM)S_ij=Q_i @ K_j.T*scale# (B_r, B_c)# 当前块的 max 和 expm_tilde=S_ij.max(dim=-1).values# (B_r,)P_tilde=torch.exp(S_ij-m_tilde.unsqueeze(-1))# (B_r, B_c)l_tilde=P_tilde.sum(dim=-1)# (B_r,)# 更新全局统计量(Online Softmax)m_new=torch.maximum(m_i,m_tilde)# (B_r,)l_new=(torch.exp(m_i-m_new)*l_i+torch.exp(m_tilde-m_new)*l_tilde)# (B_r,)# 更新输出(修正之前的 O_i 并加上新贡献)alpha=torch.exp(m_i-m_new)# 旧贡献的修正系数beta=torch.exp(m_tilde-m_new)# 新贡献的系数O_i_new=(alpha.unsqueeze(-1)*l_i.unsqueeze(-1)*O_i+beta.unsqueeze(-1)*P_tilde @ V_j)/l_new.unsqueeze(-1)# 写回 HBM(只有 O, l, m,没有 S 和 P!)O[i:i+block_size]=O_i_new l[i:i+block_size]=l_new m[i:i+block_size]=m_newreturnO# 验证正确性defstandard_attention(Q,K,V):scale=1.0/math.sqrt(Q.shape[-1])S=Q @ K.T*scale P=torch.softmax(S,dim=-1)returnP @ V# 测试torch.manual_seed(42)N,d=128,64Q=torch.randn(N,d)K=torch.randn(N,d)V=torch.randn(N,d)O_flash=flash_attention_forward(Q,K,V,block_size=32)O_standard=standard_attention(Q,K,V)print(f"最大误差:{(O_flash-O_standard).abs().max().item():.2e}")# 最大误差: ~1e-5(数值误差,在 float32 精度范围内)print(f"结果一致:{torch.allclose(O_flash,O_standard,atol=1e-4)}")# True💾 第四章:IO 复杂度分析
标准 Attention vs FlashAttention
标准 Attention 的 HBM 访问量: 操作 读(bytes) 写(bytes) Q, K, V 初始读 3 × N × d × 2 - 写 S - N² × 2 读 S 做 softmax N² × 2 - 写 P - N² × 2 读 P N² × 2 - 写 O - N × d × 2 总计 ≈ Θ(Nd + N²),序列长时 N² 项主导 N=2048, d=64: 约 32 MB(单头,仅 QK^T 部分) ──────────────────────────────────────────────── FlashAttention 的 HBM 访问量: 操作 读(bytes) 写(bytes) K, V 每块加载 N × d × 2 -(每外循环轮) Q, O, l, m 每块加载 N × d × 2 -(每内循环轮) 写回 O, l, m - N × d × 2 总计 ≈ Θ(N²d/M) (M = SRAM 大小,N²d/M 通常远小于 N²,因为 M >> d) 对比: 标准 Attention:O(Nd + N²) FlashAttention:O(N²d/M) M ≈ 20MB,d = 64,2 bytes → M/d ≈ 163840 N²/(N²d/M)= M/d ≈ 10 倍 → FlashAttention 的 HBM 访问量约是标准 Attention 的 1/10!为什么省显存
# 显存占用对比N=4096# 序列长度d=64# 头维度h=32# 头数B=1# batch sizedtype_bytes=2# float16defmemory_usage_standard(N,d,h,B):"""标准 Attention 显存占用(MB)"""Q_K_V=3*B*h*N*d*dtype_bytes# 输入S_mat=B*h*N*N*dtype_bytes# 注意力矩阵(训练时保存用于反向传播)P_mat=B*h*N*N*dtype_bytes# softmax 后的矩阵O_mat=B*h*N*d*dtype_bytes# 输出total=Q_K_V+S_mat+P_mat+O_matreturntotal/1024**2defmemory_usage_flash(N,d,h,B):"""FlashAttention 显存占用(MB)"""Q_K_V=3*B*h*N*d*dtype_bytes# 输入O_mat=B*h*N*d*dtype_bytes# 输出L_mat=B*h*N*dtype_bytes# logsumexp 统计量(很小)# S 和 P 不需要保存!反向传播时重新计算total=Q_K_V+O_mat+L_matreturntotal/1024**2print(f"N={N}, h={h}:")print(f" 标准 Attention 显存:{memory_usage_standard(N,d,h,B):.1f}MB")print(f" FlashAttention 显存:{memory_usage_flash(N,d,B,B):.1f}MB")# N=4096, h=32:# 标准 Attention 显存: ~4096 MB(仅 S、P 就 4GB!)# FlashAttention 显存: ~384 MB(减少 10x+)🔄 第五章:反向传播——重计算(Recomputation)
FlashAttention 训练时,反向传播不保存 S 和 P,而是在需要时重新计算(Recomputation)。
标准 Attention 反向传播: 前向:计算并保存 S、P(用于反向传播求梯度) 反向:读取保存的 S、P → 计算 dQ、dK、dV 显存:O(N²)(需要保存 S、P) FlashAttention 反向传播: 前向:只保存 O 和 logsumexp(L = m + log(l)) 反向:从 Q、K、V 和 L 重新计算 S、P(在 SRAM 中完成) 然后立即计算梯度,不需要把 S、P 写到 HBM 显存:O(N)(只需要 L 向量,O(N) 大小) 代价:额外的计算量(大约 1.33x FLOPs) 权衡:多 33% 计算,换来 10x 显存节省 + 更快的速度(IO 减少更多)# 保存的统计量:logsumexp L# L_i = m_i + log(l_i)# 反向时从 L 重构 softmax:P_ij = exp(S_ij - L_i)defrecompute_attention(Q_i,K_j,V_j,L_i,scale):"""反向传播时在 SRAM 中重新计算注意力"""S_ij=Q_i @ K_j.T*scale# 重新计算分数P_ij=torch.exp(S_ij-L_i.unsqueeze(-1))# 从 L 重构 softmax 输出# 现在用 P_ij 计算梯度...returnP_ij📊 第六章:实际性能数字
FlashAttention 在 A100 上的实测数据(来自论文): 序列长度 vs 速度(前向 + 反向,Batch=8,头数=12,d=64) N=512: 标准 Attention:~150 GFLOPS/s(利用率 ~0.5%) FlashAttention:~350 GFLOPS/s 加速比:~2.3x N=1024: 标准 Attention:~200 GFLOPS/s FlashAttention:~600 GFLOPS/s 加速比:~3.0x N=2048: 标准 Attention:~180 GFLOPS/s FlashAttention:~650 GFLOPS/s 加速比:~3.6x 规律:序列越长,加速比越大 原因:N 越大,IO 瓶颈越严重,FlashAttention 节省的 IO 越多 显存对比(N=2048,Batch=8,12 头): 标准 Attention:~3.1 GB FlashAttention:~0.3 GB 节省:~10x用 PyTorch 调用 FlashAttention
importtorchimporttorch.nn.functionalasF# PyTorch 2.0+ 内置了 FlashAttention# scaled_dot_product_attention 会自动选择最优实现Q=torch.randn(2,8,1024,64,device='cuda',dtype=torch.float16)K=torch.randn(2,8,1024,64,device='cuda',dtype=torch.float16)V=torch.randn(2,8,1024,64,device='cuda',dtype=torch.float16)# 方式1:PyTorch 内置(自动 Flash)withtorch.backends.cuda.sdp_kernel(enable_flash=True,# 启用 FlashAttentionenable_math=False,# 关闭标准实现enable_mem_efficient=False,):O=F.scaled_dot_product_attention(Q,K,V)# 方式2:直接安装 flash-attn 库# pip install flash-attn --no-build-isolationfromflash_attnimportflash_attn_qkvpacked_func,flash_attn_func# QKV packed 格式(省一次内存操作)qkv=torch.randn(2,1024,3,8,64,device='cuda',dtype=torch.float16)O_packed=flash_attn_qkvpacked_func(qkv,dropout_p=0.0,causal=True)# 分别传 Q K VO_sep=flash_attn_func(Q.transpose(1,2),# (batch, seq, heads, dim)K.transpose(1,2),V.transpose(1,2),dropout_p=0.0,causal=True,# 因果掩码(GPT 类模型))# 检查是否用了 FlashAttentionprint(torch.backends.cuda.flash_sdp_enabled())# True🚀 第七章:FlashAttention-2 和 FlashAttention-3
FlashAttention-2 的改进(2023)
FlashAttention-1 的问题: GPU 核心有两种并行单元: warp(线程束):32 个线程组成一个 warp 多个 warp 共享 SRAM FlashAttention-1 的工作分配:不够均匀,部分 warp 空闲 FlashAttention-2 的改进: ① 减少非矩阵乘法操作(rescaling 等)的频率 → 对 GPU 更友好(矩阵乘法有专用硬件 Tensor Core) ② 更好的工作分配: - 外循环改为 Q(之前是 K/V) - 每个 warp 负责一段 Q,减少 warp 间通信 - 不同序列块可以在不同 SM 上并行 ③ 多查询注意力(MQA)和分组查询注意力(GQA)原生支持 性能提升: 相比 FlashAttention-1:~2x 加速 相比标准 Attention:训练 ~6x,推理 ~3xFlashAttention-3(2024,H100 专用)
H100 新特性: ① Tensor Core 第 4 代:支持 FP8 ② WGMMA(Warpgroup Matrix Multiply-Accumulate) ③ TMA(Tensor Memory Accelerator):异步内存传输 FlashAttention-3 的改进: ① 生产者-消费者异步流水线: 一部分 warp 异步加载数据(使用 TMA)另一部分 warp 同时在计算(不等待加载完成) → 内存加载和计算重叠,充分利用带宽
我们可以通过手动执行一些调度来改进这一点。例如,如果我们有两个 warpgroup(分别标记为 1 和 2——每个 warpgroup 由 4 个 warp 组成),我们可以使用同步屏障(bar.sync),使 warpgroup 1 首先执行其 GEMM 操作(例如,一次迭代的 GEMM1 和下一次迭代的 GEMM0),然后 warpgroup 2 执行其 GEMM 操作,同时 warpgroup 1 执行其 softmax 操作,依此类推。
② FP8 支持: 更低精度,进一步提升吞吐量 ③ 归一化/softmax 低精度近似 性能: A100 上 FA-2 达到 ~72% 的 MFU(模型 FLOPs 利用率) H100 上 FA-3 达到 ~75% MFU(fp16),fp8 更高🎯 第八章:面试高频问题
Q1:FlashAttention 为什么快?不是因为减少了计算量
这是最常见的误解! FlashAttention 的 FLOPs(计算量)和标准 Attention 基本相同 甚至反向传播时多了约 33% 的 FLOPs(用于重计算 S、P) 它快的原因:减少了 HBM 访问(IO-Aware) 标准 Attention: HBM 读写量 ≈ O(N²) 算术强度 ≈ 33 FLOPs/Byte → Memory-Bound FlashAttention: HBM 读写量 ≈ O(N²d/M),减少约 10x → 算术强度提高,更接近 Compute-Bound → GPU 算力利用率从 21% 提升到 70%+Q2:Online Softmax 是什么?为什么分块 Attention 需要它?
问题:Softmax 需要全局信息 softmax(x_i) = exp(x_i) / Σ_j exp(x_j) 计算 x_i 的 softmax 需要知道所有 j 的 x_j 分块时:只看到部分 x,全局 sum 未知 Online Softmax 解决方案: 维护两个统计量: m:目前见过的行最大值 l:目前的修正后 exp 的和(相对于 m 的) 每看到新块(新的 x_{j'}),更新: m_new = max(m, max(x_{j'})) l_new = l × exp(m - m_new) + Σ exp(x_{j'} - m_new) 最终:softmax(x_i) = exp(x_i - m_final) / l_final 关键:O 也随之更新(rescaling),保证最终输出等价于标准 softmaxQ3:FlashAttention 显存复杂度为什么是 O(N)?
标准 Attention 显存:O(N²) → 需要保存 N×N 的 S 和 P 矩阵用于反向传播 FlashAttention 显存:O(N) → 不保存 S 和 P → 只保存 logsumexp L = m + log(l),O(N) 大小 → 反向传播时从 Q、K、V 重新计算 S、P(在 SRAM 中) 代价:多 33% FLOPs(重计算 S、P) 收益:10x 显存减少 + 3x 速度提升(因为 IO 减少) 权衡:显然值得Q4:FlashAttention 对训练和推理的影响有什么不同?
训练(前向+反向): 加速:~3x(主要来自减少 HBM 读写) 显存:减少 ~10x(不存 S、P,只存 logsumexp) 代价:额外计算(重计算 S、P,+33% FLOPs) 结论:训练是 IO-Bound 的,额外计算换来的 IO 节省值得 推理(只有前向): 加速:~1.5-2x(反向传播的重计算不存在) 显存:减少 ~3-5x(不存 S、P,但推理时 S、P 本身就不需要保存) 推理的主要瓶颈:KV Cache 的读写(FlashAttention 配合 PagedAttention 效果更好) 长序列(N > 8K): 加速比进一步增大 FlashAttention 使超长上下文(100K+)成为可能 → Claude 的 200K 上下文、Gemini 的 1M 上下文都依赖 FlashAttention🎁 总结速查
FlashAttention 核心原理: 1. 问题: 标准 Attention 的 N×N 注意力矩阵反复读写 HBM → IO-Bound,GPU 算力浪费 2. 解决:Tiling(分块)+ Online Softmax 把 Q、K、V 切块加载到 SRAM(~10x 带宽) 在 SRAM 里完成计算(不写 S、P 到 HBM) 用 Online Softmax 处理分块带来的全局信息问题 3. 效果: 速度:训练 ~3x,推理 ~1.5-2x 显存:减少 ~10x(训练),~3x(推理) IO:减少 ~10x(从 O(N²) 到 O(N²d/M)) 4. 代价: 额外计算:+33% FLOPs(反向传播重计算) 复杂性:实现更复杂(CUDA kernel 级别) 5. 使用: PyTorch 2.0+ 内置: F.scaled_dot_product_attention(Q, K, V) flash-attn 库: pip install flash-attn from flash_attn import flash_attn_func 6. 演进: FA-1(2022):基础分块 + Online Softmax FA-2(2023):更好的并行化,~2x 于 FA-1 FA-3(2024):H100 异步流水线 + FP8📣 最后
如果这篇让你从"知道 FlashAttention 很快"升级到"理解为什么快":
- 👍点赞让更多做大模型训练的同学看到
- ⭐收藏面试被问 FlashAttention 原理时翻出来
- 🔔关注持续更新 AI 算法原理,一个正在学 AI 的大学生 👨🎓
下期预告:《Speculative Decoding 投机解码:让推理快 2-3 倍的正确姿势》
📚相关阅读:
- 《Attention 注意力机制演化全图:从 MHA 到 GQA、MLA、SWA》
- 《KV Cache 原理与优化:推理成本和序列长度的关系》
- 《大模型面试题精选 50 问》
- 《最小成本体验集群大模型训练实战》
📖参考资料:
- Dao et al., 2022:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness(arXiv:2205.14135)
- Dao et al., 2023:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Shah et al., 2024:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
- Milakov & Gimelshein, 2018:Online normalizer calculation for softmax