flash attention

https://www.bilibili.com/video/BV1UT421k7rA

梯度检查点

梯度检查点:gradient checkpointing,也叫做activation checkpointing

来看一个正常的单个神经元的传播,我们发现:公式反向传播需要

$a_4,a_3$参数$w_3的梯度需要用到a_2,a_2需要用到a_1,a_1的梯度需要用到x$,

所以训练框架默认会我们保存前向传播所有的激活值,但是当batch很大,或者输入的序列很长。处理方法很简单在前向传播时处理一部分的激活值,假如我们这里,释放掉$a_1,a_3$,那么在反向传播时,可以用$a_2\times w_3$重新得到$a_3$,而不需要从头计算,同理$a_1$也是一样。hugging face框架已经帮我们实现了这一点。当如果在训练的时候batch_size已经为1了,但是显存还是不足时,可以尝试这个配置。

有的时候重算比在显存里加载更块的话,那就是没有什么副作用了。比如下图的P,S。

硬件

bfloat16(Brain Floating Point 16)是一种浮点数格式,主要用于加速深度学习模型的训练和推理。它由 Google 提出,并广泛应用于 Tensor Processing Unit (TPU) 等硬件加速器中。bfloat16 的设计目的是在保持计算效率的同时,尽量减少数值精度的损失。

目前的主流是bfloat16(之后可能会变成fp8)一般来说半精度的运算,会是单精度的速度的一倍。

tensor core运算主要用于深度学习和AI的句子运算。是NVIDIA引入的硬件单元,目的是优化深度学习训练推理中的大规模矩阵乘法运算。

pcie类似单买一张显卡,右边包括了NVLINK,通信优化。机内速度可以差10倍。右边是稀疏矩阵(有很多元素为0的矩阵,科学计算中有用,ai中较少用)。

回顾

快速和有效的现存,IO感知下的精确注意力。

标准的是自注意力是${S}={Q}*{K}^{\top}$

然后对S按行求P=softmax(S),最后得到注意力矩阵P。然后和V相乘。QKV都是输入乘以一个矩阵得到的。

Q和K假设都是8192*128的大小(大小一共是2MB),矩阵乘法$[m,n]\times [n,k]$计算量是2mnk(17,179,869,184),Q和K使用bfloat16。P的大小是$8192\times 8192$也就是128MB。(因为这里的长宽比特别夸张,所以数据膨胀的速度很快。)

假设计算速度为312TFLOPS(A 100的tensor core),也就是每秒执行312万亿次。显存带宽为2TB/S,然后除以计算速度就是计算时间。

计算一次55微秒。而从显存中读取的话大概是61微妙。

这里没有对S作缩放,没有丢弃。

矩阵Q,K,V存储在HBM(显存)中。

  1. 从HBM加载Q,K到SRAM(显存缓存)
  2. 计算$S=QK^\top$
  3. 将S写到HBM
  4. 将S加载到SRAM
  5. 计算P=softmax(S)
  6. 将P写出到HBM
  7. 将HBM加载P和V到SRAM
  8. 计算O=PV
  9. 把O写出到HBM
  10. 返回O(output)

随着序列长度N增长,缓存$N^2$的增长

llama3-8B,N=8192 d=128

模型制约

模型制约有俩中形式:

compute-bound:计算限制

大的矩阵乘法,多channel的卷积。这些操作都是数据量不大计算很复杂。

还有一种是

memory-Bound:

这类操作数据的加载速度跟不上运算速度主要的操作包括利俩类

  • 按位操作:relu,Dropout
  • 规约操作,sum,softmax

访问数据多一定是访问瓶颈,要看访存速度。(访存:从显存读取数据的时间)

对于memory-bound的优化进行fusion融合操作。不对中间结果进行缓存,减少HBM的访问时间。fusion可以提升效率,但是模型训练时要保留中间结果给反向传播用。

显存里的存储是分级的,有芯片内的缓存SRAM,和芯片外的缓存HBM。

之前对attention的改进着眼于减少计算量,flash Attention着眼于减少IO量。

如何优化

  1. 通过分块计算,融合多个操作。
  2. 反向传播时,重新计算中间结果。

这里为了简单先不考虑P,把结果当成直接$S\times V$。这里我们要进行分块,因为缓存太小了(缓存一定会比显存小),放不下整个矩阵。


首先拿Q的前俩行,K转至的前三列,计算得到S然后和V做计算。因为O是对于所有V的加权瓶颈,后面需要对O进行更新

所以这里用浅一点的颜色来表示,一次类推,得到其他的几行

然后读取K的后三行,V的下三列,在从HBM中读取之间结果。对V的前三行的加权平均值进行家和这就是O的最终结果。

以此类推,最后得到完整的O

我们发现通过将矩阵分块,以及将多部计算进行融合,中途没有将中间结果存入HBM。大大减少了IO的时间。

但是SOFTmax是按行进行的。

现在的训练都在混合精度fp16下进行的。如果x为12,那么e的12次方就大于了FP的最大表示。为了解决这个问题,人们用了一种safe_softmax。将softmax的分子分母同时除以e的m次方,softmax的结果不变。这样就不会有数值溢出的问题了。

softmax也可以通过分块计算,$softmax=(x)=p(x)/l(x)$。

只是需要额外保存几个变量,还有分块进行合并的时候要调整计算,但是这对于节省的时间还是非常划算的。

前向传播保留softmax中的统计值,最大值m和累计和l的值,反向传播可以快速重新计算激活的值。可以看做是另一中国形式的梯度(激活值)检查点。

flash attention2

  1. 减少了非矩阵乘法计算,可以录用tensorcore加速
  2. 调整了内外训练,Q为外层训练,KV为内存循环。减少HBM读写。
  3. 如果一个Block处于矩阵上三角部分(被mask掉的部分),不进行attention计算。
Last modification:November 14, 2024
如果觉得我的文章对你有用,请随意赞赏