news 2026/4/30 4:09:23

【Flash Attention】原理深度解析:IO-Aware 算法为什么能快 3 倍、省 10 倍显存

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【Flash Attention】原理深度解析:IO-Aware 算法为什么能快 3 倍、省 10 倍显存

⚡ 【FlashAttention】原理深度解析:IO-Aware 算法为什么能快 3 倍、省 10 倍显存

文章目录

  • ⚡ 【FlashAttention】原理深度解析:IO-Aware 算法为什么能快 3 倍、省 10 倍显存
    • 📖 第一章:问题的根源——GPU 的存储层次
      • GPU 内存的两层结构
      • 标准 Attention 是 Memory-Bound 的
    • 🧩 第二章:FlashAttention 的核心思路
      • 分块计算(Tiling)
      • 难点:分块 Softmax 怎么做
    • 🔢 第三章:FlashAttention 完整算法
      • 前向传播
      • 用 Python 展示核心逻辑
    • 💾 第四章:IO 复杂度分析
      • 标准 Attention vs FlashAttention
      • 为什么省显存
    • 🔄 第五章:反向传播——重计算(Recomputation)
    • 📊 第六章:实际性能数字
      • 用 PyTorch 调用 FlashAttention
    • 🚀 第七章:FlashAttention-2 和 FlashAttention-3
      • FlashAttention-2 的改进(2023)
      • FlashAttention-3(2024,H100 专用)
    • 🎯 第八章:面试高频问题
      • Q1:FlashAttention 为什么快?不是因为减少了计算量
      • Q2:Online Softmax 是什么?为什么分块 Attention 需要它?
      • Q3:FlashAttention 显存复杂度为什么是 O(N)?
      • Q4:FlashAttention 对训练和推理的影响有什么不同?
    • 🎁 总结速查
    • 📣 最后

写在前面:训练 GPT-3 级别的模型,标准 Attention 的显存会直接爆掉——不是因为参数太多,而是因为注意力矩阵太大。一个 2048 长度的序列,单头的注意力矩阵就是 2048×2048 的 float16,占 16MB;32 个头的一个 batch,就是 512MB——光这一层的中间结果就把显存打满了。FlashAttention 解决的不是计算量问题,而是内存访问(IO)问题。它通过分块(Tiling)让整个 Attention 计算都在 SRAM 里完成,大幅减少对 HBM 的读写,速度快 2-3 倍,显存减少 5-20 倍。这篇从硬件原理讲起,把 FlashAttention 的每一步都推导清楚。


📖 第一章:问题的根源——GPU 的存储层次

GPU 内存的两层结构

理解 FlashAttention,必须先理解 GPU 的存储体系。GPU 有两种主要存储:

HBM(High Bandwidth Memory,高带宽内存) ├── 容量:大(A100 = 40GB 或 80GB) ├── 带宽:高(A100 = 2TB/s) ├── 延迟:高(读写慢) └── 位置:在 GPU 芯片外 SRAM(Static RAM,静态随机访问内存) ├── 容量:极小(A100 每个 SM = 192KB,共 108 个 SM,总计约 20MB) ├── 带宽:极高(~19TB/s,比 HBM 快约 10 倍) ├── 延迟:极低(读写极快) └── 位置:在 GPU 芯片上(即 L1/L2 Cache + Shared Memory)

关键洞察:GPU 的算力(FLOPS)远比带宽增长快得多。


A100 GPU:
算力:312 TFLOPS(float16)
HBM 带宽:2 TB/s


算术强度(Arithmetic Intensity)= FLOPS / Bytes
如果一个运算的算术强度太低(需要大量读写,但计算少)
→ 受内存带宽限制(Memory-Bound)
→ GPU 的算力大量浪费在等待数据上

标准 Attention 是 Memory-Bound 的

标准 Scaled Dot-Product Attention:

输入:Q, K, V ∈ R^{N×d}(N = 序列长度,d = 每头维度)

Step 1:S = QK^T / √d 计算注意力分数 形状 (N, N)
Step 2:P = softmax(S) softmax 形状 (N, N)
Step 3:O = PV 输出 形状 (N, d)


HBM 访问量分析:
读:Q, K, V → O(Nd)
写:S(N×N)→ HBM ← 这是瓶颈!
读:S 再读回来做 softmax
写:P(N×N)→ HBM
读:P 再读回来乘 V

总 HBM 访问量 ≈ O(N²)


N = 2048,d = 64,float16:
S 矩阵大小 = 2048 × 2048 × 2 bytes = 16 MB(单头)
32 头 = 512 MB
反向传播还要保存 S、P 用于梯度计算
→ 一个中等长度序列,显存直接打满

计算量和 IO 量的矛盾:

importnumpyasnp N=2048# 序列长度d=64# 头维度dtype_bytes=2# float16# 计算量(FLOPs)# QK^T: (N,d) × (d,N) → 2*N*N*d FLOPs# softmax: ~5*N*N FLOPs# PV: (N,N) × (N,d) → 2*N*N*d FLOPscompute=4*N*N*d# ≈ 1.07 GFLOPs# HBM 访问量(Bytes)# 写 S:N×N# 读 S 做 softmax:N×N# 写 P:N×N# 读 P 做 PV:N×Nio=4*N*N*dtype_bytes# ≈ 32 MBarithmetic_intensity=compute/io# FLOPs / Byteprint(f"算术强度:{arithmetic_intensity:.1f}FLOPs/Byte")# 约 33.5 FLOPs/Byte# A100 的计算/带宽上限# 312 TFLOPS / 2 TB/s = 156 FLOPs/Byte# 算术强度 33.5 << 156 → Memory-Bound!# GPU 算力只用了 33.5/156 ≈ 21%,其余时间在等内存

🧩 第二章:FlashAttention 的核心思路

分块计算(Tiling)

FlashAttention 的核心思路一句话:把 Q、K、V 切成小块,每次只把一小块加载到 SRAM,在 SRAM 里完成所有计算,不把注意力矩阵 S 和 P 写回 HBM。

标准 Attention: Q, K, V → HBM → 计算 S (N×N) → 写回 HBM → 读 S → 计算 P (N×N) → 写回 HBM → 读 P → 计算 O → 写回 HBM FlashAttention: Q, K, V 切块 → SRAM → 分块计算注意力 → 直接在 SRAM 里得到 O 的块 → 写回 HBM N×N 的 S、P 矩阵从未出现在 HBM 中!

分块大小的计算:

# SRAM 大小约束SRAM_size=20*1024*1024# 20 MB(A100 粗略估计)dtype_bytes=2# float16# 分块大小 B# 每个 SRAM 需要存:Q_block(B×d) + K_block(B×d) + V_block(B×d) + O_block(B×d)# 4 × B × d × 2 bytes ≤ SRAM_sized=64B=SRAM_size//(4*d*dtype_bytes)print(f"最大块大小 B ≈{B}")# ≈ 40960,远大于 N=2048 时无需分块# 对于超长序列(N=100K+),确实需要分块# B 的选择是关键超参数,影响 SRAM 利用率

难点:分块 Softmax 怎么做

分块计算 QK^T 很简单,但 Softmax 需要全局信息(分母需要对所有 key 求和),这是分块计算最大的难点。

Online Softmax(在线 Softmax)是解决方案。

# 问题:standard softmax 需要两遍扫描defsoftmax_standard(x):# 第一遍:找最大值(数值稳定)m=max(x)# 第二遍:计算 exp 和归一化exp_x=[exp(xi-m)forxiinx]Z=sum(exp_x)return[ei/Zforeiinexp_x]# 如果分块处理,每次只看一块,不知道全局最大值!# 解决:Online Softmax(Milakov & Gimelshein, 2018)# 维护两个统计量,随着看到新的块不断更新:# m_i:目前见过的最大值# l_i:目前的归一化分母(已修正)defonline_softmax_update(m_prev,l_prev,s_new_block):""" 处理新的一块 scores,更新统计量 m_prev:之前块的最大值 l_prev:之前块的归一化分母(基于 m_prev) s_new_block:新块的原始分数 """# 新块的最大值m_new=max(m_prev,max(s_new_block))# 更新归一化分母# 旧的 exp(x_i - m_prev) 需要修正为 exp(x_i - m_new)# exp(x_i - m_new) = exp(x_i - m_prev) * exp(m_prev - m_new)l_new=(l_prev*exp(m_prev-m_new)+sum(exp(xi-m_new)forxiins_new_block))returnm_new,l_new

🔢 第三章:FlashAttention 完整算法

前向传播

输入:Q, K, V ∈ R^{N×d},存在 HBM
输出:O ∈ R^{N×d}

算法流程: ① 设定块大小 B_c, B_r(由 SRAM 大小决定) B_c = ⌈M/(4d)⌉,B_r = min(⌈M/(4d)⌉, d) (M = SRAM 大小) ② 将 Q 分成 T_r 块:Q_1, ..., Q_{T_r},每块大小 B_r × d 将 K, V 分成 T_c 块:K_1, ..., K_{T_c},每块大小 B_c × d ③ 初始化 O = 0(N×d),l = 0(N),m = -∞(N) 以上存在 HBM ④ 外循环(对 K, V 的块): for j = 1 to T_c: 从 HBM 加载 K_j, V_j 到 SRAM 内循环(对 Q 的块): for i = 1 to T_r: 从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM 在 SRAM 中计算: S_ij = Q_i K_j^T / √d ← (B_r × B_c) 的小矩阵 m̃_ij = max(S_ij) ← 当前块的最大值(行最大) P̃_ij = exp(S_ij - m̃_ij) ← 当前块的 exp l̃_ij = rowsum(P̃_ij) ← 当前块的 sum 更新统计量: m_i_new = max(m_i, m̃_ij) l_i_new = exp(m_i - m_i_new) * l_i + exp(m̃_ij - m_i_new) * l̃_ij 更新输出(关键!修正之前的 O_i): O_i = (l_i * exp(m_i - m_i_new) * O_i + exp(m̃_ij - m_i_new) * P̃_ij V_j) / l_i_new 将 O_i, l_i_new, m_i_new 写回 HBM ⑤ 返回 O

用 Python 展示核心逻辑

importtorchimportmathdefflash_attention_forward(Q,K,V,block_size=64):""" FlashAttention 前向传播(教学版本,非真实 CUDA 实现) Q, K, V: (N, d) 输出: O: (N, d) """N,d=Q.shape scale=1.0/math.sqrt(d)# 初始化输出和统计量(存在 HBM)O=torch.zeros(N,d,dtype=Q.dtype)l=torch.zeros(N)# 归一化分母m=torch.full((N,),-float('inf'))# 行最大值# K, V 分块forjinrange(0,N,block_size):K_j=K[j:j+block_size]# (B_c, d),加载到 SRAMV_j=V[j:j+block_size]# (B_c, d),加载到 SRAMB_c=K_j.shape[0]# Q 分块foriinrange(0,N,block_size):Q_i=Q[i:i+block_size]# (B_r, d),加载到 SRAMO_i=O[i:i+block_size]# 从 HBM 加载l_i=l[i:i+block_size]# 当前块的归一化分母m_i=m[i:i+block_size]# 当前块的行最大值B_r=Q_i.shape[0]# 在 SRAM 中计算(不写回 HBM)S_ij=Q_i @ K_j.T*scale# (B_r, B_c)# 当前块的 max 和 expm_tilde=S_ij.max(dim=-1).values# (B_r,)P_tilde=torch.exp(S_ij-m_tilde.unsqueeze(-1))# (B_r, B_c)l_tilde=P_tilde.sum(dim=-1)# (B_r,)# 更新全局统计量(Online Softmax)m_new=torch.maximum(m_i,m_tilde)# (B_r,)l_new=(torch.exp(m_i-m_new)*l_i+torch.exp(m_tilde-m_new)*l_tilde)# (B_r,)# 更新输出(修正之前的 O_i 并加上新贡献)alpha=torch.exp(m_i-m_new)# 旧贡献的修正系数beta=torch.exp(m_tilde-m_new)# 新贡献的系数O_i_new=(alpha.unsqueeze(-1)*l_i.unsqueeze(-1)*O_i+beta.unsqueeze(-1)*P_tilde @ V_j)/l_new.unsqueeze(-1)# 写回 HBM(只有 O, l, m,没有 S 和 P!)O[i:i+block_size]=O_i_new l[i:i+block_size]=l_new m[i:i+block_size]=m_newreturnO# 验证正确性defstandard_attention(Q,K,V):scale=1.0/math.sqrt(Q.shape[-1])S=Q @ K.T*scale P=torch.softmax(S,dim=-1)returnP @ V# 测试torch.manual_seed(42)N,d=128,64Q=torch.randn(N,d)K=torch.randn(N,d)V=torch.randn(N,d)O_flash=flash_attention_forward(Q,K,V,block_size=32)O_standard=standard_attention(Q,K,V)print(f"最大误差:{(O_flash-O_standard).abs().max().item():.2e}")# 最大误差: ~1e-5(数值误差,在 float32 精度范围内)print(f"结果一致:{torch.allclose(O_flash,O_standard,atol=1e-4)}")# True

💾 第四章:IO 复杂度分析

标准 Attention vs FlashAttention

标准 Attention 的 HBM 访问量: 操作 读(bytes) 写(bytes) Q, K, V 初始读 3 × N × d × 2 - 写 S - N² × 2 读 S 做 softmax N² × 2 - 写 P - N² × 2 读 P N² × 2 - 写 O - N × d × 2 总计 ≈ Θ(Nd + N²),序列长时 N² 项主导 N=2048, d=64: 约 32 MB(单头,仅 QK^T 部分) ──────────────────────────────────────────────── FlashAttention 的 HBM 访问量: 操作 读(bytes) 写(bytes) K, V 每块加载 N × d × 2 -(每外循环轮) Q, O, l, m 每块加载 N × d × 2 -(每内循环轮) 写回 O, l, m - N × d × 2 总计 ≈ Θ(N²d/M) (M = SRAM 大小,N²d/M 通常远小于 N²,因为 M >> d) 对比: 标准 Attention:O(Nd + N²) FlashAttention:O(N²d/M) M ≈ 20MB,d = 64,2 bytes → M/d ≈ 163840 N²/(N²d/M)= M/d ≈ 10 倍 → FlashAttention 的 HBM 访问量约是标准 Attention 的 1/10!

为什么省显存

# 显存占用对比N=4096# 序列长度d=64# 头维度h=32# 头数B=1# batch sizedtype_bytes=2# float16defmemory_usage_standard(N,d,h,B):"""标准 Attention 显存占用(MB)"""Q_K_V=3*B*h*N*d*dtype_bytes# 输入S_mat=B*h*N*N*dtype_bytes# 注意力矩阵(训练时保存用于反向传播)P_mat=B*h*N*N*dtype_bytes# softmax 后的矩阵O_mat=B*h*N*d*dtype_bytes# 输出total=Q_K_V+S_mat+P_mat+O_matreturntotal/1024**2defmemory_usage_flash(N,d,h,B):"""FlashAttention 显存占用(MB)"""Q_K_V=3*B*h*N*d*dtype_bytes# 输入O_mat=B*h*N*d*dtype_bytes# 输出L_mat=B*h*N*dtype_bytes# logsumexp 统计量(很小)# S 和 P 不需要保存!反向传播时重新计算total=Q_K_V+O_mat+L_matreturntotal/1024**2print(f"N={N}, h={h}:")print(f" 标准 Attention 显存:{memory_usage_standard(N,d,h,B):.1f}MB")print(f" FlashAttention 显存:{memory_usage_flash(N,d,B,B):.1f}MB")# N=4096, h=32:# 标准 Attention 显存: ~4096 MB(仅 S、P 就 4GB!)# FlashAttention 显存: ~384 MB(减少 10x+)

🔄 第五章:反向传播——重计算(Recomputation)

FlashAttention 训练时,反向传播不保存 S 和 P,而是在需要时重新计算(Recomputation)。

标准 Attention 反向传播: 前向:计算并保存 S、P(用于反向传播求梯度) 反向:读取保存的 S、P → 计算 dQ、dK、dV 显存:O(N²)(需要保存 S、P) FlashAttention 反向传播: 前向:只保存 O 和 logsumexp(L = m + log(l)) 反向:从 Q、K、V 和 L 重新计算 S、P(在 SRAM 中完成) 然后立即计算梯度,不需要把 S、P 写到 HBM 显存:O(N)(只需要 L 向量,O(N) 大小) 代价:额外的计算量(大约 1.33x FLOPs) 权衡:多 33% 计算,换来 10x 显存节省 + 更快的速度(IO 减少更多)
# 保存的统计量:logsumexp L# L_i = m_i + log(l_i)# 反向时从 L 重构 softmax:P_ij = exp(S_ij - L_i)defrecompute_attention(Q_i,K_j,V_j,L_i,scale):"""反向传播时在 SRAM 中重新计算注意力"""S_ij=Q_i @ K_j.T*scale# 重新计算分数P_ij=torch.exp(S_ij-L_i.unsqueeze(-1))# 从 L 重构 softmax 输出# 现在用 P_ij 计算梯度...returnP_ij

📊 第六章:实际性能数字

FlashAttention 在 A100 上的实测数据(来自论文): 序列长度 vs 速度(前向 + 反向,Batch=8,头数=12,d=64) N=512: 标准 Attention:~150 GFLOPS/s(利用率 ~0.5%) FlashAttention:~350 GFLOPS/s 加速比:~2.3x N=1024: 标准 Attention:~200 GFLOPS/s FlashAttention:~600 GFLOPS/s 加速比:~3.0x N=2048: 标准 Attention:~180 GFLOPS/s FlashAttention:~650 GFLOPS/s 加速比:~3.6x 规律:序列越长,加速比越大 原因:N 越大,IO 瓶颈越严重,FlashAttention 节省的 IO 越多 显存对比(N=2048,Batch=8,12 头): 标准 Attention:~3.1 GB FlashAttention:~0.3 GB 节省:~10x

用 PyTorch 调用 FlashAttention

importtorchimporttorch.nn.functionalasF# PyTorch 2.0+ 内置了 FlashAttention# scaled_dot_product_attention 会自动选择最优实现Q=torch.randn(2,8,1024,64,device='cuda',dtype=torch.float16)K=torch.randn(2,8,1024,64,device='cuda',dtype=torch.float16)V=torch.randn(2,8,1024,64,device='cuda',dtype=torch.float16)# 方式1:PyTorch 内置(自动 Flash)withtorch.backends.cuda.sdp_kernel(enable_flash=True,# 启用 FlashAttentionenable_math=False,# 关闭标准实现enable_mem_efficient=False,):O=F.scaled_dot_product_attention(Q,K,V)# 方式2:直接安装 flash-attn 库# pip install flash-attn --no-build-isolationfromflash_attnimportflash_attn_qkvpacked_func,flash_attn_func# QKV packed 格式(省一次内存操作)qkv=torch.randn(2,1024,3,8,64,device='cuda',dtype=torch.float16)O_packed=flash_attn_qkvpacked_func(qkv,dropout_p=0.0,causal=True)# 分别传 Q K VO_sep=flash_attn_func(Q.transpose(1,2),# (batch, seq, heads, dim)K.transpose(1,2),V.transpose(1,2),dropout_p=0.0,causal=True,# 因果掩码(GPT 类模型))# 检查是否用了 FlashAttentionprint(torch.backends.cuda.flash_sdp_enabled())# True

🚀 第七章:FlashAttention-2 和 FlashAttention-3


FlashAttention-2 的改进(2023)

FlashAttention-1 的问题: GPU 核心有两种并行单元: warp(线程束):32 个线程组成一个 warp 多个 warp 共享 SRAM FlashAttention-1 的工作分配:不够均匀,部分 warp 空闲 FlashAttention-2 的改进: ① 减少非矩阵乘法操作(rescaling 等)的频率 → 对 GPU 更友好(矩阵乘法有专用硬件 Tensor Core) ② 更好的工作分配: - 外循环改为 Q(之前是 K/V) - 每个 warp 负责一段 Q,减少 warp 间通信 - 不同序列块可以在不同 SM 上并行 ③ 多查询注意力(MQA)和分组查询注意力(GQA)原生支持 性能提升: 相比 FlashAttention-1:~2x 加速 相比标准 Attention:训练 ~6x,推理 ~3x

FlashAttention-3(2024,H100 专用)

H100 新特性: ① Tensor Core 第 4 代:支持 FP8 ② WGMMA(Warpgroup Matrix Multiply-Accumulate) ③ TMA(Tensor Memory Accelerator):异步内存传输 FlashAttention-3 的改进: ① 生产者-消费者异步流水线: 一部分 warp 异步加载数据(使用 TMA)

另一部分 warp 同时在计算(不等待加载完成) → 内存加载和计算重叠,充分利用带宽


我们可以通过手动执行一些调度来改进这一点。例如,如果我们有两个 warpgroup(分别标记为 1 和 2——每个 warpgroup 由 4 个 warp 组成),我们可以使用同步屏障(bar.sync),使 warpgroup 1 首先执行其 GEMM 操作(例如,一次迭代的 GEMM1 和下一次迭代的 GEMM0),然后 warpgroup 2 执行其 GEMM 操作,同时 warpgroup 1 执行其 softmax 操作,依此类推。

② FP8 支持: 更低精度,进一步提升吞吐量 ③ 归一化/softmax 低精度近似 性能: A100 上 FA-2 达到 ~72% 的 MFU(模型 FLOPs 利用率) H100 上 FA-3 达到 ~75% MFU(fp16),fp8 更高


🎯 第八章:面试高频问题

Q1:FlashAttention 为什么快?不是因为减少了计算量

这是最常见的误解! FlashAttention 的 FLOPs(计算量)和标准 Attention 基本相同 甚至反向传播时多了约 33% 的 FLOPs(用于重计算 S、P) 它快的原因:减少了 HBM 访问(IO-Aware) 标准 Attention: HBM 读写量 ≈ O(N²) 算术强度 ≈ 33 FLOPs/Byte → Memory-Bound FlashAttention: HBM 读写量 ≈ O(N²d/M),减少约 10x → 算术强度提高,更接近 Compute-Bound → GPU 算力利用率从 21% 提升到 70%+

Q2:Online Softmax 是什么?为什么分块 Attention 需要它?

问题:Softmax 需要全局信息 softmax(x_i) = exp(x_i) / Σ_j exp(x_j) 计算 x_i 的 softmax 需要知道所有 j 的 x_j 分块时:只看到部分 x,全局 sum 未知 Online Softmax 解决方案: 维护两个统计量: m:目前见过的行最大值 l:目前的修正后 exp 的和(相对于 m 的) 每看到新块(新的 x_{j'}),更新: m_new = max(m, max(x_{j'})) l_new = l × exp(m - m_new) + Σ exp(x_{j'} - m_new) 最终:softmax(x_i) = exp(x_i - m_final) / l_final 关键:O 也随之更新(rescaling),保证最终输出等价于标准 softmax

Q3:FlashAttention 显存复杂度为什么是 O(N)?

标准 Attention 显存:O(N²) → 需要保存 N×N 的 S 和 P 矩阵用于反向传播 FlashAttention 显存:O(N) → 不保存 S 和 P → 只保存 logsumexp L = m + log(l),O(N) 大小 → 反向传播时从 Q、K、V 重新计算 S、P(在 SRAM 中) 代价:多 33% FLOPs(重计算 S、P) 收益:10x 显存减少 + 3x 速度提升(因为 IO 减少) 权衡:显然值得

Q4:FlashAttention 对训练和推理的影响有什么不同?

训练(前向+反向): 加速:~3x(主要来自减少 HBM 读写) 显存:减少 ~10x(不存 S、P,只存 logsumexp) 代价:额外计算(重计算 S、P,+33% FLOPs) 结论:训练是 IO-Bound 的,额外计算换来的 IO 节省值得 推理(只有前向): 加速:~1.5-2x(反向传播的重计算不存在) 显存:减少 ~3-5x(不存 S、P,但推理时 S、P 本身就不需要保存) 推理的主要瓶颈:KV Cache 的读写(FlashAttention 配合 PagedAttention 效果更好) 长序列(N > 8K): 加速比进一步增大 FlashAttention 使超长上下文(100K+)成为可能 → Claude 的 200K 上下文、Gemini 的 1M 上下文都依赖 FlashAttention

🎁 总结速查

FlashAttention 核心原理: 1. 问题: 标准 Attention 的 N×N 注意力矩阵反复读写 HBM → IO-Bound,GPU 算力浪费 2. 解决:Tiling(分块)+ Online Softmax 把 Q、K、V 切块加载到 SRAM(~10x 带宽) 在 SRAM 里完成计算(不写 S、P 到 HBM) 用 Online Softmax 处理分块带来的全局信息问题 3. 效果: 速度:训练 ~3x,推理 ~1.5-2x 显存:减少 ~10x(训练),~3x(推理) IO:减少 ~10x(从 O(N²) 到 O(N²d/M)) 4. 代价: 额外计算:+33% FLOPs(反向传播重计算) 复杂性:实现更复杂(CUDA kernel 级别) 5. 使用: PyTorch 2.0+ 内置: F.scaled_dot_product_attention(Q, K, V) flash-attn 库: pip install flash-attn from flash_attn import flash_attn_func 6. 演进: FA-1(2022):基础分块 + Online Softmax FA-2(2023):更好的并行化,~2x 于 FA-1 FA-3(2024):H100 异步流水线 + FP8

📣 最后

如果这篇让你从"知道 FlashAttention 很快"升级到"理解为什么快":

  • 👍点赞让更多做大模型训练的同学看到
  • 收藏面试被问 FlashAttention 原理时翻出来
  • 🔔关注持续更新 AI 算法原理,一个正在学 AI 的大学生 👨‍🎓

下期预告:《Speculative Decoding 投机解码:让推理快 2-3 倍的正确姿势》


📚相关阅读

  • 《Attention 注意力机制演化全图:从 MHA 到 GQA、MLA、SWA》
  • 《KV Cache 原理与优化:推理成本和序列长度的关系》
  • 《大模型面试题精选 50 问》
  • 《最小成本体验集群大模型训练实战》

📖参考资料

  • Dao et al., 2022:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness(arXiv:2205.14135)
  • Dao et al., 2023:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
  • Shah et al., 2024:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
  • Milakov & Gimelshein, 2018:Online normalizer calculation for softmax
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/30 4:08:09

17.18.动态规划,背包问题

没加记事本的模板 加记事本的模板 198. 打家劫舍 思路 dfs(i) 从一共i家偷&#xff0c;最多可以偷多少 不偷第i家&#xff0c;dfs&#xff08;i&#xff09;》dfs(i-1) 偷第i家&#xff0c;dfs&#xff08;i&#xff09;》dfs&#xff08;i-2&#xff09;nums[i] 只回溯&…

作者头像 李华
网站建设 2026/4/30 4:07:23

Token经济:一场正在展开的“智能定价革命”

Token并不是答案&#xff0c;它更像是一个信号。 在人工智能产业快速演进的今天&#xff0c;一个原本只在技术圈流行的术语——Token&#xff0c;正悄然成为理解AI经济形态的关键入口。 根据全球最大AI模型API聚合平台OpenRouter最新数据显示&#xff0c;3月16日至22日&#…

作者头像 李华
网站建设 2026/4/30 4:03:49

项目中**LabVIEW 位操作逻辑**的完整、清晰解释,以及与 C# 实现的对应关系

以下是针对项目中LabVIEW 位操作逻辑的完整、清晰解释,以及与 C# 实现的对应关系。 LabVIEW 中关键位操作函数 你的描述(“数字转换成 bool 数组 → 反转一维数组 → 循环检查”)主要涉及以下两个核心 LabVIEW 函数: Number To Boolean Array(数值转布尔数组) 位置:Pr…

作者头像 李华
网站建设 2026/4/30 3:57:24

CaTok:1D因果图像标记化方法解析与应用

1. 项目概述CaTok是一种创新的1D因果图像标记化方法&#xff0c;它基于MeanFlow解码器架构&#xff0c;专门针对序列建模任务中的图像处理需求而设计。这个方法的核心思想是将二维图像数据转化为一维的因果标记序列&#xff0c;同时保持空间信息的完整性。我在计算机视觉和序列…

作者头像 李华
网站建设 2026/4/30 3:56:11

SSH隧道与Tailscale实现AI代理远程运行时本地化连接

1. 项目概述&#xff1a;当本地浏览器需要连接远程大脑时在AI智能体与自动化工具的开发实践中&#xff0c;我们常常会遇到一个经典的“身体与大脑”分离困境。一个强大的AI运行时&#xff08;大脑&#xff09;可能运行在拥有充足算力、稳定网络或特定依赖的远程服务器上&#x…

作者头像 李华
网站建设 2026/4/30 3:50:26

Go分布式爬虫框架clawjob:架构解析与生产部署指南

1. 项目概述与核心价值最近在折腾一些数据采集和自动化任务时&#xff0c;发现了一个挺有意思的项目&#xff0c;叫clawjob。乍一看这个名字&#xff0c;结合它的仓库地址jackychen129/clawjob&#xff0c;就能猜到这玩意儿跟“爬虫”和“任务”脱不了干系。没错&#xff0c;它…

作者头像 李华