在这里整理了一些闪卡,帮助自己记住以下内容。
和某人聊了聊预训练经常失败的原因,很有意思。能真切感受到各种可能导致事情搞砸的具体方式,以及为何训练如此脆弱。从宏观上看,破坏因果性和引入偏差似乎是两大元凶。
破坏因果性:
在做专家路由时,首先会经过路由器,它会为每对 token 和 expert 打分,表示该 token 希望由哪个 expert 处理。接下来有两种处理方式:1. Token 路由,即从 token 的视角读取分数,并为每个 token 分配其 top k 个专家。问题是,这可能导致专家之间的分配极度不均衡,严重影响性能。另一种选择是(仅在训练中)采用专家选择(expert choice),即根据每个专家相对更偏好的 token 来划分 token。这样就能确保每个专家获得大致相同数量的 token。但一个大问题是破坏了因果性——因为 token n 被分配到哪个专家,可能取决于 token n + k 会被路由到哪个专家。而破坏因果性非常糟糕,因为在训练过程中你会接收到在部署时根本看不到的信息,并据此更新模型。有传言说 Llama 4 表现平平正是因为这个原因。也许你可以在预填充推理阶段也做专家选择?但实际操作中,将 token 分配给那些在真实训练中本不会收到该 token 的专家,可能效果不佳。说实话,我还没完全理解为什么破坏因果性这么严重。我知道在真实推理中无法超越因果关系,但这种微小的偏离为何会造成如此大的问题呢?另一个破坏因果性的情况是 token 丢弃。即某些专家会忽略批次中它们本该处理的、但匹配度不够强的 token,从而避免超出填充长度。但如果后续某个 token 与这个专家匹配度更高,就可能导致前面的 token 被忽略,这就破坏了因果性。据说 Gemini 2 Pro 就存在这个问题。引入偏差:
偏差比方差更可怕——方差可以通过平均化来缓解,但偏差会不断累积。据说原始 GPT-4 的训练初期进展缓慢,就是因为以下 bug:他们在集体通信操作(如 all-reduce)中使用了 FP16 格式。FP16 的分数粒度是按对数密度分布的——在 1 到 2 之间,尾数比特会将区间划分为约 0.001 的间隔;但在 1024 及以上时,尾数可能将区间划分为多个整数值。假设某个集体操作需要累加 1 + 1……共 10,000 次——一旦达到 1024,你加 1 变成 1025,再向下取整到最近的区间就是 1024,再加一次。结果计算值会比真实值大 10 倍。如果你试图将许多小梯度累加到一个大累加器中,这将是灾难性的。想象一下要找出这个 bug 有多难!对 AI 训练的启示:
认为我们能治愈衰老的人中,有些提出一个观点:人类衰老死亡的原因基本上有五种(如心脏病、癌症等),若能攻克这五种疾病,衰老问题就基本解决了。类似地,我们也可以问这些失败的预训练运行是否也有五种失败模式——如果是,那么一旦某个实验室解决了数值计算等问题,后续进展就会一帆风顺;还是说在每个新的规模层级上,仍会不断涌现全新的定制化难题?与我交谈的人似乎倾向于后者,他指出即使在数值计算领域,就有无数种可能导致失败的方式,而且随着规模扩大,新的问题也会持续出现。对 AI 短期内完全自动化编写内核持悲观态度。他推测这是因为他认为内核编写更像是一个 AGI 完整问题,而非一些人认为的那么简单。另一种观点则认为:嘿,在这个扩展规模下,哪个内核或 MLP 运行得最快是一个高度可验证的领域,因此我们可以通过强化学习轻松达到超人类性能。但他指出,即使是拥有全球顶尖内核工程师的英伟达,也花了很长时间才优化出 Blackwell,这表明这项工作其实相当困难,可能并不容易实现闭环优化。有时人们会说 RL 生成中的推理和面向最终用户的推理本质上是一样的。但此人指出,在 RL 推理中,推理引擎与训练引擎之间的数值漂移会导致这些微妙的非策略偏差,这对高质量训练至关重要;然而,如果只是用于向用户提供服务,这些问题则无关紧要。强调了制定严谨流程来整合计算乘数的重要性,因为累积带有微妙偏差的 bug 会带来风险。来自 Horace He 为我们朋友和我所做的一场精彩讲座的笔记。
这场讲座之所以出色,是因为 Horace 将整个主题构建成一系列问题与解决方案的链条:我们想做什么,为什么会失败,我们如何修复,以及为何这种修复最终也会失效。大多数解释都只是罗列各种策略,从未将它们与所解决的问题联系起来,也未说明为何选择其中一种而非另一种。
预训练 FLOPs 公式为 6ND。前向传播每参数每 token 需要 2 次浮点运算(乘加各一次)。反向传播是前向的 2 倍,因为要分别计算两个输入矩阵的梯度。所以总共有 2 + 4 = 6 次 FLOPs。好吧,我们无法在一块 GPU 上完成所有这些计算。那该如何拆分这个问题呢?显而易见的方案是数据并行——即把模型权重复制到每个 GPU 上,每个 GPU 只处理批次的一部分。显而易见的问题是每个 GPU 上的高带宽内存(HBM)容量有限——B300 是 288GB——这不足以存储越来越大的模型的权重,更不用说激活值了。接下来我们尝试的是完全分片数据并行(FSDP)——每个 GPU 只存储每层参数的 1/N。在处理每一层之前,所有 GPU 会通过 all-gather 操作获取该层的完整参数(此时每个 GPU 仍只持有该层的 1/N)。处理完成后,各 GPU 会丢弃已收集到的参数。这一点被强调为首选默认方案,只有当 GPU 数量过多迫使你升级策略时(原因稍后说明),才考虑其他方式。之所以成为默认,是因为计算与通信时间极易重叠——因为传输的只是权重,而这些权重不依赖于当前层之前的计算结果,因此可以在当前层仍在计算的同时,就开始对下一层进行 all-gather。相比之下,张量并行或专家并行需要在处理下一层前共享本层的激活值。而流水线并行的主要问题在于存在气泡(如下所述)。 从通信量角度看,FSDP 初看似乎代价极高——每层都要在所有 GPU 间 all-gather 完整的权重,用于一次矩阵乘法后即丢弃。但这忽略了常规数据并行本身已有的开销:在标准 DP 中,反向传播时仍需对每层的梯度执行 all-reduce 以同步各 GPU 上的批次梯度,其通信量为 params × 2。FSDP 则额外增加了两次 all-gather:正向和反向各一次,每层一次。但 all-gather 的通信量仅为 all-reduce 的一半。因此,粗略估算下 FSDP 的总通信量为 params × 4(正向 all-gather + 反向 all-gather + 反向 all-reduce)。实际上还能进一步优化:由于每个梯度分片只需最终出现在拥有它的那个 GPU 上,可将 all-reduce 替换为 reduce-scatter(省去最后的广播步骤),这样总通信量降至 params × 3,仅比 vanilla DP 多出 50%。那为什么不能总是使用 FSDP 呢?通信交叉点:我们希望计算时间大于通信时间——不希望被通信所限制。但随着 GPU 数量的增加,FSDP 的计算时间会减少,而通信时间不会随之下降。因此,在扩展 FSDP 的 GPU 数量时,MFU(模型利用效率)可能会急剧下降。当这种情况发生时,还需要引入流水线并行。计算时间 = (6 × token 数 × 激活参数) / (每 GPU 算力 × GPU 数量),随着 GPU 数量增加而减少。通信时间 = (总参数量 × 3) / (NVLink 域大小 × InfiniBand 带宽),通信时间不会随域的增加而上升。这一点曾让我非常困惑。每个域共同持有所有参数,并在反向传播每一层后同步梯度。直觉上,增加域意味着环中跳数增多,all-reduce 会变慢。但标准的环形算法将消息分割为每个参与者一个分片。更多域意味着更多跳数,但每跳的分片也更小。(当分片小到足以让每跳延迟占主导时,这种优势就消失了,此时会切换到树形算法。)技术上,可以对所有域之间的梯度进行比朴素单次 all-reduce 更优的操作。采用分层集体通信来优化跨多个 NVLink 域的通信时间。关键是要记住:每个域中的 GPU 都能独立访问 InfiniBand 带宽。因此,应充分利用这一带宽,因为互连带宽是瓶颈所在。为此,尽量在一个扩展开销内完成尽可能多的操作,再向外迁移。具体做法是:在扩展开销内执行 reduce scatter,使每个 GPU 获得该层分片的域级归约梯度;然后在各域对应的 GPU 之间执行 all-reduce;最后在域内执行 all gather。这降低了通信时间线,从而将交叉点右移。用 Cursor 和 Composer 2 制作了一个动画来演示这个过程:从公式可以看出,增大 batch size 会使交叉点右移,而提高模型稀疏性则会使交叉点左移。此外,TPUs 在 FSDP 上表现更好,是因为一个域内有更多加速器。Batch size 下限:FSDP 是数据并行的,因此每个 GPU 至少处理一个序列。注意力机制是在序列内部计算的,且无法(轻易地)跨 GPU 拆分。如果你的关键 batch size 是 1000 万个 tokens,序列长度为 1 万,那么你只有 1000 个序列——即使还有充足的通信带宽,纯 FSDP 也无法扩展到超过 1000 个 GPU。流水线并行的问题(这是你接下来要添加到 FSDP 中解决这些问题的方法):流水线并行的不同之处在于,开始时负责最后几层的 GPU 未被使用,而结束时负责前几层的 GPU 也未被使用,从而产生“气泡”。在训练过程中无法重叠批次来解决流水线气泡问题,因为必须在处理下一批数据之前完成梯度的聚合和模型的更新。此外,这样做还会引入架构限制——例如 Kimi 的注意力到残差机制(每个块都关注所有先前层的残差)在残差分布在不同流水线阶段时会变得非常困难。同样,交错滑动窗口和全注意力层可能会导致各阶段的负载不平衡。处理所有这些情况会减缓研究迭代速度,这是你所犯的最大错误。没有帖子