Reading Notes: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”
date
Mar 9, 2025
slug
flash-attn
status
Published
tags
MLSys
summary
type
Post
Motivation
Attention 模块是 transformer 模型的核心组件,但是由于 attention 算子在时间和空间复杂度上对于序列长度都是平方级别增长,它在长序列上表现不佳。为了提升 attention 的效率,有研究者提出了一些近似 attention 的算法,但是这些方法常常是以牺牲部分模型质量来换取速度上的提升,并且,尽管这些方法降低了 attention 的时间复杂度,却没有带来实际执行时间的下降。本文认为,这是由于这些方法没有考虑 attention 算子中访存带来的影响。
基于上述考虑,本文提出了一种 IO-aware 的 exact attention 算法 flash-attention, 通过减少 HBM 和片上 SRAM 存储之间的数据移动来提升 attention 算子的效率。
Approach
Flash-Attention 主要通过减少 HBM 和片上 SRAM 存储之间的数据移动(也就是对应题目中的 IO-Awareness)来提升 attention 算子的效率(尽管实际上的 FLOPS 数更高),使用的主要方法是:
- Tiling: 将参与运算的输入分块,每次执行块和块之间的计算。(参考矩阵乘法的分块,但是这里更复杂)
- Recomputation: 丢弃中间结果,而不需要把它们写回 HBM, 需要的时候再重新算。(不需要把整个 attention score 矩阵存下来在 backward pass 用,而是直接根据保存的某些其他结果重新算)
- Kernel Fusion: 把 attention 中的操作都融合到一个 kernel 里面,避免反复写回和读取中间结果,以及减少 kernel launch 开销。
通过上面的三个优化,在整个过程中我们不需要真的 materialize attention score 矩阵,而是直接计算 output 结果,从而避免了对于序列长度平方增长的访存以及存储开销。
Standard Attention

标准的 attention 算子依次执行 matmul + softmax + matmul,每两个操作之间需要把中间结果写回 HBM, 在 HBM 和 SRAM 之间的数据移动需要读取: 以及写入: ,总共是:
Flash Attention Prelude: Online Softmax
Tiling 是一个常用的用来减少算子中数据移动的优化手段,如果考虑用和分块 GEMM 类似的方法来直接优化 attention 会发现一个主要的障碍在于中间的 softmax 需要 reduce 一整个行来计算 scaling 系数以及最大值(因为需要用数值稳定的 softmax),因此直接对 attention 做分块是没办法得到非近似的结果的。
在考虑如何做分块的 attention 之前,先来考虑如何做分块的 softmax,一个简单的实现需要 3 个 pass:

其中第二个 pass 主要是因为要先计算了最大值才能求 scaling 系数,然而,如果定义 ,那么我们有:
这样,在计算 的时候也可以一边计算 ,注意到 ,在一个 pass 里面就可以同时完成 scaling 系数和最大值的计算,这样整个算法只需要两个 pass:

Flash Attention: Algorithm
那么,能否继续优化使得整个算法只需要 1 个 Pass 呢,对于 Softmax, 答案是不可以,但是对于 attention 算子,我们的目标是求出最后的输出,而不是 attention score 矩阵。
如果能够像 online softmax 那样推导出一个递归式直接从输入计算到输出(或者说计算到输出的某种中间结果,使得我们在整个 pass 结束后得到完整的输出) ,那么就可以在一个 pass 内完成 attention 运算。
根据上面的 2 Pass Softmax, 容易写出如下的 2 Pass 的 attention 实现(这里给出输出的每行的计算,由于行之间计算是独立的,以此类推):

由上算法可以得到 ,这里由于以来计算完的最大值和 scaling 系数所以不能 1 pass 做完,考虑使用和 online softmax 类似的 trick,把出现 的地方都换成循环变量 , 定义 ,注意到有:
当 ,这样我们就能在一个 pass 得到 attention 的输出。
由上,可以得到 flash attention 的分块实现:

此外,由于我们在 forward pass 中保存了最大值和 scaling 系数,即使没有保存 attention score 矩阵,在 backward pass 也可以做 recomputation.
Analysis: I/O Complexity of Flash Attention
上述算法中,可以在 i 这个轴做并行化,考虑对于 的分块,如果 SRAM 的大小是 ,那么每次可以加载 的元素,因此对于 需要经过 个 pass,在每个 pass 要遍历整个 ,因此需要加载以及写回 这么多元素,因此总的 I/O Complexity 是 ,通常 , 而标准的 attention 需要 这么多的 HBM 访问。
此外,由于不需要把整个 attention 矩阵写回 HBM, 需要存储在 HBM 的只有输出,最大值,以及 scaling 系数,只需要占 的空间,而原来需要 ,这也使得 flash attention 的 memory efficiency 更高,能够支持更长的 context length.
Experiments and Results
- 训练速度:在训练 benchmark 比如 GPT2 以及 BERT 上取得比 Huggingface, Megatron 等 SOTA 方法显著更快的结果 (1.15x to 3x)
- 模型质量:在更长的上下文上训练,得到的模型在多个 benchmark 上显著好于 SOTA 的准确率/loss.
- 随上下文长度 scaling:测试运行时间和显存使用随上下文长度 scaling 的关系,验证了显存使用随上下文长度线性增长,以及运行速度最快达到标准 attention 的 3x
Future Directions
- Compiling to CUDA: 目前需要使用 low level 的语言来实现这个 flash attention 的 kernel,使用 pytorch 这样的高级语言实现 flash attention 然后编译到 kernel 是一个可以探索的方向。
- I/O Aware Deep Learning: 其他层也可以采用类似的方式进行 I/O Aware 的优化。
- Multi-GPU I/O Awareness: 目前只考率 1 个 GPU 上的 I/O, 可以进一步考虑多个 GPU 之间的通信 I/O.
References
[1] From Online Softmax to FlashAttention: 从 Online Softmax 到 Flash Attention 的推导