从头编写LLM(第34a部分)—— 为LLM训练构建JAX训练循环Writing an LLM from scratch, part 34a -- building a JAX training loop for an LLM training run
基于 Sebastian Raschka 的 LLM 构建指南,作者展示了如何脱离参考书籍,仅凭个人笔记从头搭建大语言模型的训练管线。本篇重点聚焦于系列教程的第 34a 部分,详细演示了如何使用高性能计算框架 JAX 编写完整的 LLM 训练循环。这一从零开始的工程实践为深入理解现代 AI 模型的底层训练机制提供了极佳的技术参考。
Giles Thomas
写下那篇我希望自己在刚开始学习相关技术时就能看到的文章……
归档
分类
友情链接
一年多来,我一直把 Sebastian Raschka 的著作《Build a Large Language Model (from Scratch)》——以及阅读该书衍生出的众多副项目——当作学习现代 AI 的教程。我给自己设定的最后一个任务是,仅凭笔记从零构建并训练一个 LLM——既不参考那本书,也不参考我跟着书本写的模型代码。
对于输出结果,我希望能达到我基于 Raschka 代码编写的最佳 PyTorch 模型的水平——那是一个在 3.2B tokens 上训练的基础模型,据我(虽然有限的)评估,其质量已接近原版 GPT-2 small。
我想使用一个不同的框架,以确保自己不是在机械地默写记下来的代码。因此我在 Twitter 上问大家该用哪个,最终胜出的是 JAX。
我采取了与 Raschka 的书稍有不同的路线;他采用由内而外(inside-out)的视角,先解释注意力(attention)等概念,逐步构建出一个完整的 GPT-2 风格模型,然后再在此基础上构建训练循环。而我想由外而内(outside-in)进行:先组装一个训练框架(training harness),用类似真实 LLM 的 API 去训练一个尽可能简单的模型,让它达到我满意的效果,然后再逐一为这个简单模型添加特性,直到它具备完整的架构。这个计划(最终效果非常好!)的初衷是,我能够展示每一次改动是如何带来改进的。
现在这些都已完成,我将分两部分发文介绍;在这一部分,我将解释我是如何构建训练框架的,而在下一部分,我将展示 LLM 的实际构建和训练过程。
那么,让我们开始吧!
JAX 之上选哪个框架?
JAX 本身的 API 相当精简,不包含诸如线性层(linear layers)等标准神经网络组件。同样,它也没有任何内置的优化器(optimisers)、数据加载器(data loaders)或类似的机器学习工具。
其实,我本可以决定仅使用纯 JAX 来构建我的 LLM,就像我之前处理玩具 XOR 模型那样。但我认为,按照现实世界中 JAX 代码的编写风格来构建会更好,这就意味着要使用众多实用库中的一些。
在 JAX 官网上,有一个看起来很有用的链接:“如果你打算用 JAX 训练神经网络,请查看 JAX AI Stack!” 在链接到的页面上,明确指出该技术栈的两个核心部分是:
我看了看这两个库,它们似乎都很容易上手。确实,乍一看,我觉得 NNX 看起来非常像 PyTorch!在它们的教程示例中,唯一真正明显的区别就是 JAX 风格的求导式梯度计算,以及随机数的处理方式。而且,即使是随机数的处理,也没有采用像纯 JAX 那种纯函数式的方式——你不必费力去拆分密钥(splitting keys),只需传入一个看起来是有状态的变量即可,它会在需要时以某种方式在内部自行拆分。
因此,NNX 和 Optax 就成了我要使用的框架。与其死磕教程,我决定直接上手实操,在实践中边做边学。
这能有多难呢……?
A到A语言模型与训练循环
为了构建一个能正常运行的训练循环,我需要一个用于训练的最小化模型——不是真正的 LLM,而是至少行为上有点像它的东西。它需要接收一个 token 序列,并为每个 token 输出 logits。
在我偏好的 LLM 工作原理模型中,对于模型的顶层架构,我们输入一系列 token ID,然后:
所有这些启发了我,为了起步我能写出的最笨的“LLM”,就是一个仅仅将 token ID 投影到 embedding 空间,然后再投影回词汇表空间的东西。完全没有 Transformer 层。
然后我会训练它,让它不再尝试预测下一个 token,而是尝试“预测”一开始输入给它的内容。换句话说,你会向训练循环输入以下内容:
The fat cat sat on the mat……以及这个目标:
The fat cat sat on the mat……而不是 LLM 的常规设置,在那种情况下你输入:
The fat cat sat on the……并给它以下目标:
fat cat sat on the mat如果我能让它跑通——而且感觉这种东西不需要大量训练就能把 loss 降到接近于零——那么我就可以相当确信我的训练循环是正常工作的。1
我决定将其称为 A-to-A 模型。编写模型本身的代码极其简单,看起来像这样:
from flax import nnx
class GPTModel(nnx.Module):
def __init__(
self,
vocab_size, context_length,
emb_dim,
n_heads, n_layers,
qkv_bias,
drop_rate,
rngs,
):
self.token_embedding = nnx.Embed(
num_embeddings=vocab_size,
features=emb_dim,
rngs=rngs,
)
self.output_head = nnx.Linear(
in_features=emb_dim,
out_features=vocab_size,
use_bias=False,
rngs=rngs,
)
def __call__(self, xs):
input_embeddings = self.token_embedding(xs)
return self.output_head(input_embeddings)里面包含了大量的样板代码——用于存放我知道在构建完整的 LLM 时模型会需要的参数——其篇幅和真正执行操作的代码一样多!
但是训练循环就更有趣一些了。
移植基础代码
正如我所说,我在这里的计划是通过仅凭笔记重建一个 LLM,来确保我对 LLM 内部原理的理解是正确的。“仅限笔记”的限制并不适用于训练循环本身,因此我允许自己稍微借鉴一点我之前在云端用来训练原始模型的 PyTorch DistributedDataParallel 代码。
我使用的第一个版本就在这里。让我们从底部拥有 main 函数的地方开始看。
它以一些处理“runs(运行)”概念的样板代码开始。这是我发现自己在大多数项目中都在使用的模式。在开发模型时,能够进行多次训练运行且每次更改一些内容是非常有用的。你会希望保留每次运行的 checkpoints、metadata 和训练图表,以供将来参考。
所以在我的代码库里,会有一个“runs”目录,里面为我想要追踪的每次训练运行分别建一个子目录。
在这些子目录中,有一些 JSON 文件——一个用于配置模型,即 model.json,另一个用于配置训练超参数及类似内容,即 train.json。(值得一提的是,在这个阶段,有一堆超参数是未被使用的;出于偷懒我把它们留在了那里,因为我知道以后会用到它们。)因此,我们通过加载这些文件来开始我们的主函数。
我们的下一步是完全忽略其中一个训练超参数:gradient_accumulation_steps。我确实想做梯度累积,但决定把它留到以后。我觉得最好先完成一次扎实且更简单的训练运行。
接下来,我们使用 download_dataset 将我们要用的数据集下载到本地磁盘(只有在本地没有最新副本时才会执行下载)。
下一步是调用 load_dataset 将其加载到 RAM 中。你可以看到那里还有一个硬编码的变量 world_size。这是这段代码所基于的多 GPU DistributedDataParallel 代码遗留下来的;在这篇博文中,我只介绍单 GPU 训练的代码,但我决定保留 DDP 相关的内容用于整理数据集,并将其硬编码为单 GPU,这样如果我以后决定在 JAX 中实现类似功能时,重新引入它们会更容易。
让我们来看看 load_dataset 及其相关内容。如果你跳到第 39 行,就会看到代码。首先,有一个 BigTrainDataset 用于跟踪我们的训练数据。如果你仔细观察,可能会发现那个类中有一个奇怪的地方。我们有这个:
self.xs = all_tokens[:-1].reshape(-1, microbatch_size, seq_length)
self.ys = all_tokens[:-1].reshape(-1, microbatch_size, seq_length)请记住,在这个阶段,计划是训练模型将 token 映射到它们自身,而不是进行下一个 token 的预测。因此,目标(targets)与输入(inputs)是相同的,而不是更常见的下一个 token,后者看起来会像这样(并且在下一篇博文中也会变成这样):
self.xs = all_tokens[:-1].reshape(-1, microbatch_size, seq_length)
self.ys = all_tokens[1:].reshape(-1, microbatch_size, seq_length)接下来,我们有一个 load_dataset 函数,用于将适当的数据子集从本地磁盘上的副本加载到其中一个 BigTrainDataset 对象中。当我运行它的第一个版本时遇到了内存溢出(out-of-memory)的问题。它试图将数据加载到我的 GPU 的 VRAM 中——如果你有 GPU 并且安装了 CUDA 版本的 JAX,这是 JAX 的默认行为——但数据量太大,无法全部装进去。经过一番探索,我了解了如何更改 JAX 的默认设备,以便将数据加载到普通的系统 RAM 中。
不幸的是,完成这一步后,我发现遍历它的速度超级慢——从数组中获取一个包含 6,144 个 token 的训练批次大约需要 1.2 秒,这意味着仅此一项就会将我的训练速度限制在 5,120 token/秒。我最终了解到,数据虽然已加载到主 RAM 中,但由于尚未真正提交给主 RAM,它仍然会被复制到 GPU 进行处理——详情见此。修复这个问题(通过显式调用 jax.device_put)意味着从数据集中获取一个训练批次并将其放到 GPU 上只需不到 0.001 秒,这就好多了。
所以,这就是被浓缩到代码第 55 到 58 行的数小时的工作成果:
cpu0 = jax.devices("cpu")[0]
with jax.default_device(cpu0):
full_dataset = load_file(dataset_dir / f"{split}.safetensors")["tokens"]
full_dataset = jax.device_put(full_dataset, cpu0)load_dataset 中剩余的逻辑,只是为了确保我们得到的数据集大小,能够与我们正在使用的 world size(尽管现在始终为 1)、微批次大小(microbatch size)、梯度累积步数(gradient accumulation steps)以及序列长度(sequence length)完全匹配,
让我们再次回到 main 函数。在加载了数据集之后,我们创建模型,传入模型配置参数以及(目前未使用的)dropout 训练超参数,然后我们创建了一个封装了 Optax 优化器的 Flax NNX 优化器。这基本上是从 Flax 教程中复制粘贴过来的,只不过我们使用了训练配置中的学习率和权重衰减超参数来配置该优化器:
optimizer = nnx.Optimizer(
model,
optax.adamw(
learning_rate=train_conf["learning_rate"],
weight_decay=train_conf["weight_decay"],
),
wrt=nnx.Param
)最后,我们调用 train 来启动训练循环,并传入一些合适的参数。接下来我们来看看这个函数。
我们先做一些初始化的常规处理,然后进入主循环。你可以看到它似乎在暗示梯度累积(gradient accumulation):
for global_step in progress_bar:
for accumulation_step in range(gradient_accumulation_steps):……但如果你看看这个循环的实际主体,它根本没有做任何类似的事情。它只是获取训练批次,将它们放到 GPU 上,执行一个完整的训练步骤,并记录一些指标:
inputs, targets = train_dataset[((global_step * gradient_accumulation_steps) + accumulation_step) * world_size + rank]
inputs = jax.device_put(inputs, model_device)
targets = jax.device_put(targets, model_device)
train_loss = train_step(model, optimizer, inputs, targets)
train_losses.append(train_loss.item())
microbatch_size, sequence_length = inputs.shape
tokens_seen_this_rank += microbatch_size * sequence_length所以,我们现在只是在做一个传统的逐批(batch-by-batch)训练循环,没有进行梯度累积。但是已经有一些相关的基础架构了,因为这是我让基本循环跑通之后打算添加的下一个功能。
train 函数的其余部分只是常规处理和检查点保存;我们很快会回到检查点这部分,但首先让我们看看 train_step 函数及其关联的 calculate_loss 函数,前者负责在一组输入和目标上实际训练模型,它们的定义就在 train 函数的正上方。
现在,正如你可能从我第一篇关于 JAX 的文章中回忆起的那样,对训练循环进行 JIT 编译的最佳方式是在尽可能高的层级上进行。所以当我最初编写这段代码时,我将其整合到了按传统命名的 train_step 函数中,就像这样:
@jax.jit
def train_step(model, optimizer, inputs, targets):
loss, grads = nnx.value_and_grad(calculate_loss)(model, inputs, targets)
optimizer.update(model, grads)
return loss当我第一次实际运行它时,loss 根本没有下降,在苦思冥想了一番之后,我意识到我应该使用 nnx.jit 而不是 jax.jit,修复这个问题后,我重新启动了训练。loss 马上开始下降了。哎!
现在让我们来看看 loss。交叉熵损失(cross entropy loss)显然是我训练 LLM 所需要的,而且对于 A 到 A 模型来说似乎也是合适的选择。
Optax 有五个与交叉熵相关的损失函数;其中三个看起来比我需要的要复杂一些:
所以选择范围缩小到了
后者是正确的选择——softmax_cross_entropy 期望标签(即目标 token ID)是独热向量(one-hot vectors),而 softmax_cross_entropy_with_integer_labels,正如函数名所示,期望的是整数标签,这正是我们所拥有的。
这听起来和 PyTorch 的 cross_entropy 非常相似,但有一个重要的区别。在正常使用情况下(如果你没有使用 K 维 loss,不管那是什么),PyTorch 期望输入要么只是一个包含 c 个 logit 的一维张量,要么最坏情况下是一个 bxc 的矩阵,其中 b 是批次大小(batch size)。
在学习 Raschka 的书的这一部分时,我曾注意到我们写的代码把数据展平了。因此,一个包含六个序列的批次,每个序列长度为 1,024 个 token,词表大小为 50,257,将会得到一个形状如下的 logits 张量:
(6, 1024, 50257)第一个轴是批次,第二个轴是序列长度——记住,对于序列中的每个输入 token,我们都有对应的 logits,这些 logits 包含了在该 token 左侧所有 token 的上下文中,对下一个 token 的预测。而最后一个轴的大小等于我们分词器的词汇表大小,它就是 logits 本身。
展平后,它看起来像是一个包含 6*1024=6144 个 logits 向量的“批次”:
(6144, 50257)同样,我们的目标——也就是我们希望模型预测的 token ID——也是分批的,并且每个序列中的每个 token 都有一个对应的目标,因此该张量是
(6, 1024)展平后,它看起来像是一个包含 6*1024=6144 个目标的“批次”:
(6144,)最后,PyTorch 函数返回了一个标量值——当然,它被包装在一个 PyTorch Tensor 对象中,以便它能参与反向传播,但它就是一个单一的数字。
但我在编写这部分 JAX 代码时把这一切都忘了,直接把输入和目标塞进了 JAX 函数里。结果很有趣。我是这样开始的:
loss = optax.losses.softmax_cross_entropy_with_integer_labels(
logits, targets
)打印出每个变量的形状后,得到了这样的结果:
logits.shape=(6, 1024, 50257)
targets.shape=(6, 1024)
loss.shape=(6, 1024)它返回了所有批次中,每个序列中每个元素的交叉熵数值!
有趣的是,softmax_cross_entropy_with_integer_labels 的文档暗示它具有与 PyTorch 相同的限制——它期望传入的张量只有一个批次轴。也许是文档过时了?或者 Optax 只是假设你知道,在 JAX 中“一个批次轴”应该被理解为“你想要多少个批次轴都可以”?好吧,无论如何——它起作用了,而且我验证过这些数字是可靠的。
当然,现在我们不能用那个 6×1024 的矩阵向 JAX 请求梯度——损失函数需要返回一个标量——但 JAX 数组上的 mean 函数恰好能满足我们的需求。因此,我得到了一个可靠的损失计算方法,你可以在 calculate_loss 中看到:
def calculate_loss(model, inputs, targets):
logits = model(inputs)
loss = optax.losses.softmax_cross_entropy_with_integer_labels(
logits, targets
).mean()
return loss这样我们就介绍了损失函数以及使用它的经过 JIT 编译的 train_step。
在这个版本的 train.py 脚本中,我还没有讲解的唯一剩下的代码就是 calculate_loss 正上方的部分——get_training_data 和 generate_training_charts。它们都是作为我在 train 函数中略过的内部维护代码的一部分被调用的,发生在我们保存检查点之后。
它们只是利用目前所有检查点元数据中存储的信息,重新绘制损失和其他训练指标的图表。这意味着有一种很好的图形化方式来跟踪训练过程。这些内容相当枯燥,所以没必要深入讲解,但检查点代码本身还是值得一看的。
你可以在这里看到我当时正在使用的版本。这其实算不上一个真正的检查点;我保存了模型本身以及绘图代码所需的元数据,但没有保存优化器,而这是真正的检查点所必需的。毕竟,检查点的目的是在训练循环崩溃时能够重新恢复,如果没有优化器的状态,你就无法做到这一点。尽管如此,这也足够用来起步了。
话虽如此,在编写那个简单的检查点代码时,我遇到的一个小问题是,将它们保存为 Safetensors 格式有点棘手——你可以在这里看到细节。
所以,这就是我最初的训练代码。是时候让它大显身手了:我能把我那个愚蠢的“LLM”训练成把 A 映射到 A 吗?
第一次 A 到 A 的运行
正如我之前提到的,最开始的运行完全没有收敛——loss 从大约 10.82 开始,这原本是个好兆头(这完全符合你对一个尝试预测 GPT-2 token 的随机初始化网络的预期——详情见这里),但随后它就停滞不前了。
但当我修复了“jax.jit 应该是 nnx.jit”的问题后,它开始下降。在处理了 92,160,000 个 token 之后,它似乎已经降到了零(至少在我打印出的三位小数看来是这样),所以我将这个设定写入了 train.json,并针对该 token 数量进行了另一次固定训练。大约 14 分钟后,训练完成了:
Training complete in 843.547 seconds
Tokens seen: 92,160,000
Throughput: 109,253 tokens/second
Final train loss: 0.000
2026-06-17 19:29:31.667194 Done最终的 loss 非常可观,尽管那仅仅是我们在最后一个 batch 上得到的值!实际的 loss 走势图是这样的:
如果你习惯了我之前文章中的 loss 曲线图,这里有一点需要指出:我已经将 Y 轴切换为对数坐标,所以末端附近的那些起伏实际上只是偏离 0.001 的极小偏差。
我认为有必要展示一下此时模型到底做了什么。实际上,我是在过了一段时间后才编写代码,从这些训练运行中加载模型 checkpoint 并进行一些冒烟测试的,但我现在就给你展示一些结果。
我根据之前关于 JAX safetensors 的文章编写了一些代码,用于从 checkpoint 的 model.safetensors 文件中加载模型参数:
def load_model(model, file):
model_state_simple_dict = load_file(file)
dict_flat_state = {}
for key, array in model_state_simple_dict.items():
elements = key.split(".")
list_key = []
for element in elements:
try:
list_key.append(int(element))
except ValueError:
list_key.append(element)
dict_flat_state[tuple(list_key)] = array
new_flat_state = nnx.from_flat_state(dict_flat_state)
nnx.update(model, new_flat_state)……然后编写了两个测试脚本。
首先,它真的是从 A 到 A 的映射吗?我想确保这个 loss 数值确实反映了我希望它反映的内容。我编写了一个简单的脚本,该脚本接收命令行传入的 Safetensors 文件,并将《古舟子咏》(The Rime of the Ancient Mariner)的第一小节(选择它是因为它使用的是较古老的英语,因此里面包含一些奇怪的 token)输入到从该文件加载的 LLM 中进行运行。
以下是运行结束时模型给出的结果:
giles@perry:~/Dev/jax-gpt2-from-scratch (main)$ uv run test_a_to_a.py runs/a-to-a/checkpoints/best/model.safetensors
Input:
---
It is an ancient Mariner,
And he stoppeth one of three.
'By thy long grey beard and glittering eye,
Now wherefore stopp'st thou me?
---
Output:
---
It is an ancient Mariner,
And he stoppeth one of three.
'By thy long grey beard and glittering eye,
Now wherefore stopp'st thou me?
---太棒了!它显然能够处理这种映射。出于好奇,我决定看看它学会正确处理这个问题的速度有多快。在训练运行结束时的那个“最佳” checkpoint 中,平均训练 loss 为 0.0001,那么在训练运行刚开始时,映射是如何改进的,loss 又是多少呢?
对于第一个 checkpoint,当我们刚刚跑完一个 batch 时,平均训练 loss 为 10.8242。使用当时保存的模型参数,我们得到了以下输出:
giles@perry:~/Dev/jax-gpt2-from-scratch (main)$ uv run test_a_to_a.py runs/a-to-a/checkpoints/20260617Z185827-iteration-0/model.safetensors
Input:
---
It is an ancient Mariner,
And he stoppeth one of three.
'By thy long grey beard and glittering eye,
Now wherefore stopp'st thou me?
---
Output:
---
LOADRecommend ptwtacid cheek lunch KaLOAD blondrient Sole Broken engages CrimsplitrelyLOAD Consortium hopefully Fisheries qualardiestern565Financial gallery talked KaLOADhuge admit disappeared SoleERT Heardearth showcasingurancesLOAD
---从这个 loss 值就可以猜到,这完全是 token 大杂烩。
现在让我们看看下一个 checkpoint,它是在 375 个“全局步骤”(global steps)之后获取的——也就是 6,000 个 batch。在这个阶段,自第一个 checkpoint 以来的平均训练 loss 为 2.9323。但这掩盖了一些重要信息——在刚开始时,最大 loss 为 10.78524(正如你所料),并不比前一个 checkpoint 的平均 loss 低多少。但最小 loss(我们可以放心地假设它出现在这个 checkpoint 周期的末期)为 0.54155,因此我们可以合理地认为,模型在此时得到了非常迅速的提升。而 A 到 A 的测试也证实了这一点:
giles@perry:~/Dev/jax-gpt2-from-scratch (main)$ uv run test_a_to_a.py runs/a-to-a/checkpoints/20260617Z185848-iteration-375/model.safetensors
Input:
---
It is an ancient Mariner,
And he stoppeth one of three.
'By thy long grey beard and glittering eye,
Now wherefore stopp'st thou me?
---
Output:
---
It is an ancient Mariner,
And he stoppeth one of three.
'By thy long grey beard and glittering eye,
Now wherefore stopp'st thou me?
---因此,我们可以看到绝大部分的改进都发生在刚开始的时候!仅仅在总共处理了 6,001 个 batch(每个 batch 包含 6 个长度为 1,024 token 的序列)之后,它就能够通过那个相当不寻常的序列的 A 到 A 测试。
训练运行的后半部分可能只是在不断优化那些较罕见的 token,或者让模型对已经正确的预测更有把握。毕竟,测试脚本只是简单地打印每个位置上最有可能的 token,所以在当前状态下,它可能只是以 51% 的概率预测出其中的一些 token。这意味着即使在预测结果实际上正确的情况下,损失函数也会对其施加惩罚。
那个脚本挺有趣的;我想再写一个——也就是我一直使用的标准冒烟测试,基于 Raschka 的提示:当要求模型续写句子时,它会怎么补全“Every effort moves you”?这是脚本,以及它生成的内容:
giles@perry:~/Dev/jax-gpt2-from-scratch (main)$ uv run test_generation.py runs/a-to-a/checkpoints/best/model.safetensors
Every effort moves you you you you you you you you you you you you you you you you you you you you you这完全说得通。为了在自回归循环中生成下一个 token,我们查看的是提示词中最后一个 token 的 logits。当它第一次运行时,最后一个 token 是“ you”,而我们的模型被训练成将 A 映射到 A,所以它的结果是“ you”。我们将其附加到提示词中,再次运行,最后一个 token 仍然是“ you”,所以它当然会再次“预测”出“ you”这个 token。以此类推。
所以这些结果都是好消息!A 到 A 的映射起作用了,并且在损失方面迅速收敛——而在我们的诗意测试中收敛得更快。
那么,下一步该做什么呢?我希望训练循环尽可能与我用于本地训练的最佳 PyTorch 模型的代码相似。那个模型使用了三个在此阶段我尚未加入到训练循环中的功能:学习率调度、梯度裁剪和梯度累积。PyTorch 代码还具有从检查点重启的能力——在像这样 14 分钟的训练过程中并不是特别重要,但我认为这以后会变得很重要。毕竟,我在本地机器上运行 PyTorch 曾花了将近两天时间,如果在中途出了什么问题(比如猫跳到了 PC 的电源按钮上等等),我绝对不想从头再来。
我决定先处理梯度累积。
梯度累积
在 PyTorch 中,进行梯度累积非常简单:没有梯度累积的典型训练循环的核心可能看起来像这样:
optimizer.zero_grad()
result = model(inputs)
loss = loss_function(result, targets)
loss.backward()
optimizer.step()我们首先清除模型参数上存储的任何梯度,然后进行前向传播,计算损失,进行反向传播以将新的梯度放到参数上,最后执行优化器步骤来应用这些梯度。
累积梯度仅仅意味着将其更改为类似这样的内容:
optimizer.zero_grad()
for step in range(gradient_accumulation_steps):
result = model(inputs)
loss = loss_function(result, targets)
(loss / gradient_accumulation_steps).backward()
optimizer.step()也就是说,我们进行 gradient_accumulation_steps 次前向和反向传播。因为我们没有在它们之间将现有梯度清零,所以参数会随着时间的推移累积梯度——每次反向传播都会将其贡献添加到已有的梯度上。每次我们都会将损失除以 gradient_accumulation_steps,这样放到参数上的梯度就会相应变小,这意味着在循环结束时,我们得到的梯度,等于将所有这些微批次作为一个大批次处理时所得到的梯度的平均值。最后,一旦退出循环,我们就执行优化器步骤来应用这些平均梯度。
当我开始考虑在 JAX 中实现这一点时,我注意到 Optax 有一个关于如何实现它的帮助页面,但随后我产生了一个人们偶尔会有的那种绝妙的“洗澡时的顿悟”。到了我这个年纪,我本该明白这种顿悟很少能带来好结果,但这一次,我决定试一试,而不是按官方的方法来做。
我绝妙的想法是,通过一些巧妙的处理,我们可以把整个梯度累积循环放到经过 JIT 编译的代码中。根据我目前所学到的经验,在代码中我们把 JIT 装饰器的位置放得越靠上——也就是说,它覆盖的训练循环部分越多——执行速度就会越快。这个想法本身其实还不错。
但我的第一次实现就没那么明智了:
def calculate_loss(model, inputs, targets):
loss = 0
for microbatch_inputs, microbatch_targets in zip(inputs, targets):
logits = model(microbatch_inputs)
loss += optax.losses.softmax_cross_entropy_with_integer_labels(
logits, microbatch_targets
).mean()
return loss
@nnx.jit
def train_step(model, optimizer, inputs, targets):
loss, grads = nnx.value_and_grad(calculate_loss)(model, inputs, targets)
optimizer.update(model, grads)
return loss输入是全步数数组(例如,对于 6 个包含 1024 个序列的 microbatch 进行 16 步梯度累积,形状为 (16, 6, 1024)),目标数组也是如此。这看起来非常巧妙!但现在回想起来,它显然注定会失败,而且当我运行它时, VRAM 就耗尽了。
梯度累积的意义在于,你随着时间不断累积的正是梯度。因此,你必须对每个 microbatch 在模型上执行一次完整的正向传播和反向传播,让梯度逐渐累积起来,然后像 PyTorch 代码那样一次性应用它们。
遗憾的是,我的代码实际所做的,本质上是一步步执行所有的正向传播,让激活值以及记录了已执行计算的 JAX 内部数据结构(而不是梯度)不断累积,然后再对所有这些内容进行一次反向传播。从数学上讲,这没有问题——如果我有足够的 VRAM,我本可以得到正确的结果——但在节约显存方面,它并不比直接处理一个大小为 gradient_accumulation_steps * batch_size 的单次批次好多少。结果就是立刻报了 CUDA OOM。
我的第二次尝试稍微理智了一些,并且在不使用 JIT 的情况下运行良好:
#@nnx.jit
def train_step(model, optimizer, inputs, targets):
loss_list = []
grads_list = []
for microbatch_inputs, microbatch_targets in zip(inputs, targets):
microbatch_loss, microbatch_grads = nnx.value_and_grad(calculate_loss)(model, microbatch_inputs, microbatch_targets)
loss_list.append(microbatch_loss)
grads_list.append(microbatch_grads)
average_grads = jax.tree.map(
lambda *items: jnp.array(items).mean(axis=0),
*grads_list
)
optimizer.update(model, average_grads)
return jnp.array(loss_list).mean()你可以看到,现在我是在循环内同时进行正向传播和反向传播,然后用那个 jax.tree.map 计算平均梯度,再将这些平均梯度传递给优化器。
这一切都说得通,而且当我运行它时,它似乎也奏效了:
Training complete in 1,146.173 seconds
Tokens seen: 92,209,152
Throughput: 80,450 tokens/second
Final train loss: 0.001
2026-06-18 19:00:25.739249 Done……而且考虑到没有使用 JIT 编译,它并没有我想象的那么慢:1146 秒对 843 秒。
有趣的是,最终的训练损失比没有进行梯度累积的运行结果要高,但更大的有效批次大小并不总是件好事:这在很大程度上取决于你正在训练的模型和数据。我使用的批次大小和梯度累积步数是我为完整的 163M 参数 GPT-2 风格 LLM 优化过的,而不是为这个模型优化的。所以稍微差一点也没关系。
不管怎样,我尝试给那个函数加上 @nnx.jit,然后运行它:
jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 20.71GiB. [tf-allocator-allocation-error=''] [executable_name='jit_train_step']哎哟。查看 traceback 后,发现似乎是实际的 JIT 编译过程耗尽了 VRAM。也许和循环展开有关?我摸索了一阵子,尝试使用 JAX 的 fori_loop 而不是普通的 Python 循环,但无济于事——总是导致 GPU 内存耗尽。
最终,在几个小时后,我“支线任务探测器”上的警报声响得已经无法忽视了。我极不情愿地放弃了手动编写梯度累积,转而采用 Optax 的方式来实现它。
那其实非常棒且简单。代码在这里,但改动很小,解释起来也很容易。还记得我们曾经有这段用来设置优化器的代码吗:
optimizer = nnx.Optimizer(
model,
optax.adamw(
learning_rate=train_conf["learning_rate"],
weight_decay=train_conf["weight_decay"],
),
wrt=nnx.Param
)这会创建一个 Flax NNX 优化器,它在底层使用的是 Optax AdamW 优化器。
Optax 实现梯度累积的方法是将优化器包装在一个 MultiSteps 辅助工具中,再由 NNX 优化器包装该结果,最终看起来像这样:
optimizer = nnx.Optimizer(
model,
optax.MultiSteps(
optax.adamw(
learning_rate=train_conf["learning_rate"],
weight_decay=train_conf["weight_decay"],
),
every_k_schedule=gradient_accumulation_steps
),
wrt=nnx.Param
)MultiSteps 包装器非常巧妙。它具有与常规优化器相同的接口,因此可以通过传入一组梯度来调用其 update 方法。但它并不会立即应用这些梯度,而是不断累积,直到 update 方法被调用了特定次数后,它才会真正应用累积梯度的平均值,并重置计数器以便重新开始累积。
这个 API 确实非常棒。这意味着我本可以简化训练循环。还记得吧,我们之前是这样的:
for global_step in progress_bar:
for accumulation_step in range(gradient_accumulation_steps):
inputs, targets = train_dataset[((global_step * gradient_accumulation_steps) + accumulation_step) * world_size + rank]
inputs = jax.device_put(inputs, model_device)
targets = jax.device_put(targets, model_device)
train_loss = train_step(model, optimizer, inputs, targets)
train_losses.append(train_loss.item())
microbatch_size, sequence_length = inputs.shape
tokens_seen_this_rank += microbatch_size * sequence_lengthPyTorch 代码中之所以需要这种循环嵌套,是因为我们需要在最后执行优化器步骤来应用累积的梯度。但有了 Optax 包装器,我们本可以只在一个顶层循环中遍历样本,依靠 MultiSteps 每隔 gradient_accumulation_steps 次迭代执行一次更新。
然而,我决定保留它——以全局步数来跟踪训练进度,意味着我的 JAX 模型的训练输出将更容易与 PyTorch 版本进行对比。如果我是完全从零开始构建训练循环,也许我会做出不同的选择。
不管怎样,修改完代码后,我运行了一下,结果是:
Training complete in 836.409 seconds
Tokens seen: 92,209,152
Throughput: 110,244 tokens/second
Final train loss: 0.001
2026-06-18 20:56:15.220326 Done最终我得到了与手动且未经过 JIT 编译版本相同的 loss,这让人很放心。而且它比不使用梯度累积的版本稍微快一点,但差距足够小,可能仅仅是在误差范围内。
这就是梯度累积!以下是添加了该功能后的代码。
接下来,我想实现学习率的图表绘制和调度,以及梯度裁剪功能。
绘制学习率图表
学习率调度意味着我们将在运行过程中不断改变它——就像我的某次 PyTorch 训练运行中的这个例子一样:
拥有这样的图表非常有用,因为它可以让你检查对学习率所做的更改是否确实合理。因此,我想先添加图表绘制功能,然后再添加调度功能。假设检查点的元数据中已经包含了学习率数值,那么实际生成图表的样板代码其实已经有了,所以我只需要弄清楚如何从优化器中提取当前的学习率值,然后将其保存到检查点中即可。
这是一个显而易见的起点。Optax 优化器本身并不存储学习率,但如果你像这样创建它们:
optax.inject_hyperparams(optax.adam)(...)……其中括号里的 ... 是你在创建优化器时通常会传入的常规参数,这样你以后就可以提取学习率了。
然而,那个帮助页面上的代码直接使用了 Optax 优化器,而我在训练代码中的优化器被包装在 MultiSteps 里面,MultiSteps 进而又被包装在一个 NNX Optimizer 对象中,就像这样:
optimizer = nnx.Optimizer(
model,
optax.MultiSteps(
optax.adamw(
learning_rate=train_conf["learning_rate"],
weight_decay=train_conf["weight_decay"],
),
every_k_schedule=gradient_accumulation_steps
),
wrt=nnx.Param
)尽管如此,解决方案似乎相当清晰。我可以对我正在创建的 adamw 使用 inject_hyperparameters 技巧,然后像这样将其传入以进行包装:
optax_optimizer = optax.inject_hyperparams(optax.adamw)(
learning_rate=train_conf["learning_rate"],
weight_decay=train_conf["weight_decay"],
)
optimizer = nnx.Optimizer(
model,
optax.MultiSteps(
optax_optimizer,
every_k_schedule=gradient_accumulation_steps
),
wrt=nnx.Param
)下一个问题是如何从该优化器中实际读取学习率。
Optax 文档中的示例代码如下所示:
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)
params = initial_params
state = optimizer.init(params)
print('initial learning rate:', state.hyperparams['learning_rate'])
_, state = fit(initial_params, optimizer)
print('final learning rate:', state.hyperparams['learning_rate'])同样,那也是直接使用 Optax 优化器,而不是尝试使用封装在 NNX 优化器内部的优化器。不过,在 NNX 优化器的文档中,我注意到它将封装的 Optax 优化器的状态暴露为 opt_state。我加入了一些临时的调试代码将其打印出来,发现它是 MultiSteps 的状态,这很合理——而它又包含了被包装的 adamw 的状态,即 inner_opt_state。
inner_opt_state 有一个名为 hyperparameters 的字段,它是一个字典,其中包含了 learning_rate 这个键。
最后,该键指向的值是一个 Variable 对象。要从中获取实际的值,你需要调用它的 get_value() 方法,这会返回一个 JNP 数组,因此我们还需要对它调用 item() 方法。
所有这些最终导致了以下令上帝、人类以及迪米特法则(Law of Demeter)所厌恶的怪物:
current_learning_rate = optimizer.opt_state.inner_opt_state.hyperparams["learning_rate"].get_value().item()呃。我的意思是,真的,太让人难受了。
不管怎样,我把执行此操作的代码放到了 train 函数中,并将该数字作为 metadata 的一部分保存起来。我进行了一次部分训练,时间刚好足够确认学习率图表已经生成,并且在 0.0014 处有一条水平线,这正是我当时使用的恒定学习率。
不过,我不能说对此感到非常自豪。
学习率调度
回顾一下,我想要的学习率调度是这样的:
它由两个阶段组成:首先是初始 warmup 阶段,学习率从期望峰值的 0.00001 倍开始,然后线性上升至峰值;接着通过一个余弦波将其衰减至峰值的 0.1 倍。
在 PyTorch 中,我不得不使用不同的学习率调度器对象来处理每个阶段,并用一个 SequentialLR 包装器将它们拼接在一起:
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=0.00001,
end_factor=1.0,
total_iters=warmup_steps
)
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=decay_steps,
eta_min=learning_rate / 10
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer,
schedulers=[warmup_scheduler, cosine_scheduler],
milestones=[warmup_steps],
)然而,这是训练循环中常见的模式,而 Optax 方便地提供了一个 warmup_cosine_decay_schedule 类,可以为你完成所有这些工作。
其中唯一奇怪的地方是 decay_steps 有点名不副实;它实际上是总步数,包括了 warmup 在内。因此我最终写下了这段代码:
total_steps = (len(train_dataset) // world_size) // gradient_accumulation_steps
warmup_steps = (total_steps * train_conf["warmup_period_percent"]) // 100
learning_rate = train_conf["learning_rate"]
schedule = optax.warmup_cosine_decay_schedule(
init_value=learning_rate * 0.00001,
peak_value=learning_rate,
warmup_steps=warmup_steps,
decay_steps=total_steps, # !!
end_value=learning_rate / 10,
)
optax_optimizer = optax.inject_hyperparams(optax.adamw)(
learning_rate=schedule,
weight_decay=train_conf["weight_decay"],
)我用它进行了一次训练,完成时输出了以下内容:
2026-06-19 00:17:25.777629 Tokens seen: 92,209,152
2026-06-19 00:17:25.777633 Throughput: 106,225 tokens/second
2026-06-19 00:17:25.777644 Final train loss: 0.006
2026-06-19 00:17:25.777646 Done损失又稍微变差了一点,但就像 gradient accumulation 步数一样,我指定的学习率调度是专门为训练一个真实的(尽管规模较小的)LLM 而设计的,而不是为了我用来测试训练循环的这个玩具般的 A-to-A 任务。重要的是学习率图表,它看起来是这样的:
完美!这是此时的代码。
在我得到一个真正可以用来构建 LLM 的训练循环之前,还有两个待办事项:gradient clipping 和从 checkpoint 恢复的能力。我决定先做 gradient clipping。
Gradient clipping(梯度裁剪)
Gradient clipping 是指在每次更新时,寻找异常大的梯度并将其截断,从而防止它们对模型造成过度的改变。
Optax 的文档让它看起来非常简单:
optimizer = optax.chain(
optax.clip(1.0),
optax.adamw(learning_rate=schedule),
)因此,你使用 optax.chain 将执行裁剪的组件和实际的优化器按顺序链接起来——大概链条中的第一个组件会获取梯度并对其进行处理,然后第二个组件接收第一个组件返回的结果。
现在,问题是,我们应该在 MultiSteps 之外还是之内进行这个链式操作?也就是说,我们是在每次步进 MultiSteps 优化器之前裁剪梯度,还是先累积梯度,并在步进内部的 AdamW 优化器之前裁剪平均值?
查看旧的 PyTorch 代码,我当时运行了梯度累积循环,然后在最后进行裁剪。因此,梯度裁剪是针对累积的梯度进行的。
这实际上在直觉上不如另一种方案好,但我决定我们应该尝试与 PyTorch 代码的做法保持一致。所以:
optax_optimizer = optax.chain(
optax.clip(train_conf["clipping_max_norm"]),
optax.inject_hyperparams(optax.adamw)(
learning_rate=schedule,
weight_decay=train_conf["weight_decay"],
)
)所以,adamw 优化器将接收到裁剪后的梯度。因为它被包裹在 MultiSteps 中,所以每次该对象达到其 every_k_schedule 限制时,它接收到的都是累积的梯度。
遗憾的是,仍然存在一个问题:这个改动意味着我们在 train 函数中通过下面这段糟糕的代码来读取学习率的优化器:
current_learning_rate = optimizer.opt_state.inner_opt_state.hyperparams["learning_rate"].get_value().item()……现在又多了一层嵌套——chain 对象。因此,理所当然地,当我运行它时,程序报错崩溃了:
AttributeError: 'tuple' object has no attribute 'hyperparams'我使用了一些调试打印来查明发生了什么,并确定 chain 对象的状态是一个元组,第一个元素是 clipper 的基本为空的状态,第二个元素是注入了超参数的 adamw 状态。
所以这意味着获取学习率的新的正确代码应该是这样的:
current_learning_rate = optimizer.opt_state.inner_opt_state[1].hyperparams["learning_rate"].get_value().item()注意,我们加了那个 [1] 来查找 chain 的元组状态。我记得很久以前在一个代码库中看到过一条注释:“原谅我们在该方法中犯下的罪过”,我非常能体会作者的感受。
不过,我确实有一个主意,至少能稍微限制一下波及范围。在代码的这个位置,我在主函数中进行了复杂的优化器设置,而在 train 中获取学习率的操作则极其糟糕。我决定在优化器设置旁边定义一个名为 get_learning_rate 的函数,并将其传递给 train。这样,虽然可怕的东西还在,但至少都集中在一个地方了,就像这样:
optax_optimizer = optax.chain(
optax.clip(train_conf["clipping_max_norm"]),
optax.inject_hyperparams(optax.adamw)(
learning_rate=schedule,
weight_decay=train_conf["weight_decay"],
)
)
optimizer = nnx.Optimizer(
model,
optax.MultiSteps(
optax_optimizer,
every_k_schedule=gradient_accumulation_steps
),
wrt=nnx.Param
)
def get_learning_rate():
return (
optimizer
.opt_state
.inner_opt_state[1]
.hyperparams["learning_rate"]
.get_value()
.item()
)
log("Start train")
start_global_step = 0 ## checkpointing
train(
run_dir,
model, optimizer,
get_learning_rate,
train_dataset,
rank, world_size,
gradient_accumulation_steps,
start_global_step,
train_conf["checkpoint_interval"],
)……在需要的地方,train 会调用 get_learning_rate。
我正准备开始运行,但碰巧仔细看了一下 clip 的文档,发现上面写着:
逐元素地裁剪更新,使其在 [-max_delta, +max_delta] 范围内
这让我想起了什么!当我最初在研究 PyTorch 训练循环的梯度裁剪时,我注意到这是一种完全有效的梯度裁剪方式,但这并不是我最终选择的方式。相反,我是基于 L2 范数进行裁剪的。
JAX 训练代码本应与 PyTorch 代码的工作方式相同,所以发现得好;我从使用 optax.clip 切换到了 optax.clip_by_global_norm,然后开启了另一次训练运行:
2026-06-19 01:22:41.964291 Tokens seen: 92,209,152
2026-06-19 01:22:41.964295 Throughput: 105,022 tokens/second
2026-06-19 01:22:41.964308 Final train loss: 0.006
2026-06-19 01:22:41.964311 Done一切看起来都很好;我猜测最终的 loss 如此相似,是因为像 A 到 A 映射这样简单的任务,加上这么浅的网络,不太可能导致梯度爆炸。
但如果能确定就好了。有没有什么方法可以让我跟踪梯度,看看是否需要进行裁剪呢?
我们在 PyTorch 代码中有一个很棒的功能,那就是我们可以跟踪裁剪前的梯度范数:
pre_clip_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_max_norm).item()
grad_norms.append(pre_clip_norm)
clipped_steps.append(pre_clip_norm > clipping_max_norm)遗憾的是,clip_by_global_norm 和常规的 Optax API 没有提供任何方法来访问裁剪前的范数:我们在那段糟糕的读取学习率代码中读取的 chain 状态的第 0 个元素 ClipByGlobalNormState,其实是 EmptyState 的别名。
我考虑过使用 optax.tree_utils.tree_norm 直接计算范数并记录下来,但这会很棘手——因为我们要裁剪的梯度并不是在 train_step 函数中生成的那些,而是在多个梯度累积步骤中累积在 MultiSteps 对象内部的那些。
这听起来像是投入大量工作却得不到太大收益,所以我决定在这个项目中放弃这个想法。
不过,在折腾梯度的时候,我还想做一个小的修改——如果梯度中出现了非有限数字该怎么办。
当我最初研究梯度裁剪时,我有些震惊地发现,我用来告诉 PyTorch 在它认为有帮助的情况下进行 16 位训练(自动混合精度,或称 AMP)的那个 scaler 对象,会静默丢弃所有具有非有限梯度的更新;而如果你不使用 AMP,这些梯度就会被直接应用到你的模型中,极有可能因为将参数设置为非有限值而彻底毁掉模型。
我觉得这种逻辑放错了地方——我认为它应该属于优化器,或者至少属于技术栈中与混合精度训练这个完全正交的任务无关的其他部分。
我检查了 JAX 处理非有限梯度的默认行为,结果发现它只是直接应用这些梯度——但是,在 Optax 中,这实际上是可以在优化器层面解决的问题。如果你用 apply_if_finite 包装一个 Optax 优化器,它将只会应用有限的梯度,因此我们可以像这样将其添加到优化器设置中:
optimizer = nnx.Optimizer(
model,
optax.apply_if_finite(
optax.MultiSteps(
optax_optimizer,
every_k_schedule=gradient_accumulation_steps
),
max_consecutive_errors=math.inf,
),
wrt=nnx.Param
)我将 max_consecutive_errors 设置为无穷大,以镜像 PyTorch 代码的行为。
现在,很显然,这需要那个糟糕透顶的获取学习率函数中再增加一层间接性:
def get_learning_rate():
return (
optimizer
.opt_state
.inner_state
.inner_opt_state[1]
.hyperparams["learning_rate"]
.get_value()
.item()
)如果你还在跟进度的话,那就是里面的 .inner_state。哎,没办法。
所以,又到了再次运行它的时候了:
Training complete in 892.715 seconds
2026-06-19 02:27:16.657056 Tokens seen: 92,209,152
2026-06-19 02:27:16.657059 Throughput: 103,291 tokens/second
2026-06-19 02:27:16.657070 Final train loss: 0.006
2026-06-19 02:27:16.657072 Done看起来没问题——和之前相比没有变化。这是代码。
现在,是时候采取最后一步来完成训练循环了:从检查点重启的能力。
从检查点重启
此时,检查点代码还非常基础——它会将模型保存为 Safetensors 文件,并附带一些元数据,比如自上次检查点以来的最小、最大和平均损失,当前的全局步数,以及这是否是迄今为止最好的检查点(以平均训练损失为准)。
为了从检查点恢复,我们需要更多信息。在旧的 PyTorch 代码中,除了模型和元数据之外,我们还需要三样额外的东西:
这就是接下来的工作:在 save_checkpoint 中保存优化器,然后实现一个 load_checkpoint,以便我们可以从检查点重启。然后我就可以尝试启动一次训练,等一会儿,终止它,然后从最近的检查点重启。损失和学习率图表会告诉我重启后是否真的从上次中断的地方继续了。
最初我想直接用 pickle 来保存优化器,但这感觉就像是一颗定时炸弹。当你更改 Python 版本或已安装包的版本时,pickle 会出问题,这平时看着似乎没什么大不了的,但在现实中却常常把事情搞砸。2
使用 Safetensors 看起来有点棘手——即使它有明确的支持,之前让它和 Flax 模型一起工作也很困难。
目前,JAX 代码中进行检查点保存的推荐库叫做 Orbax。我之前研究过它,它看起来有点重,所以我就略过了。但是再深入挖掘一下,我发现它提供了一个看似简单的用于保存 PyTrees 的 API,从而绕过了那些复杂性。
不过,要让它正常运行还是有点棘手。
首先,在文档中,他们给出了这个示例:
import orbax.checkpoint.experimental.v1 as ocp
...
ocp.save(path, pytree)我在 save_checkpoint 函数中尝试了这一点,代码如下:
ocp.save(checkpoint_dir / "optimizer", optimizer.opt_state)……然后得到了这个错误:
AttributeError: module 'orbax.checkpoint.experimental.v1' has no attribute 'save'哈。从命令行深入查看该库后发现,该函数实际上叫 save_pytree。如果文档和 API 不匹配,这就不太妙了(不过公平地说,它在包名中确实标明了是实验性的)。
不管怎样,改掉之后似乎就能用了:
ocp.save_pytree(checkpoint_dir / "optimizer", optimizer.opt_state)……接着,在我的 checkpoint 目录下名为 model.safetensors 的 295 MB 文件旁边,出现了一个名为 optimizer 的 353 MB 目录。在 PyTorch 的世界里,优化器的大小通常是模型 3 的两倍,但考虑到实际使用的文件格式大不相同,只要它在数量级上与模型相同并且稍大一些,我就觉得可以接受了。也许 Orbax 做了某种压缩或类似的处理。
接下来,该编写 load_checkpoint 了。我首先编写了 load_model 函数来加载 safetensors 文件——这就是我之前展示过的那个,当时我演示了最初的 A-to-A 模型是如何学习将一首诗映射到自身的,并且如果你问它如何补全“Every effort moves you”,它会回答“ you you you you you”等等。
完成这一步后,我创建了一个 load_checkpoint,它调用了 load_model,然后加载元数据并计算出迄今为止我们取得的最佳 loss(当从 checkpoint 继续训练时,这是必要的,这样随着你继续训练,就能判断每个新的全局步骤的 loss 是否优于当前的最佳记录)。这很简单:
with open(checkpoint_dir / "meta.json", "r") as f:
meta = json.load(f)
restart_global_step = meta["global_step"] + 1
with open(checkpoints_dir / "best" / "meta.json") as f:
best_loss = json.load(f)["avg_train_loss"]事实证明,恢复优化器要稍微棘手一些。首先,当然,就像保存时一样,Orbax 的函数叫做 load_pytree,而不是文档中记载的 load。接下来的问题是,如何以一种优化器能够接受的方式去加载它。
如果你像这样加载一个 checkpoint 过的 PyTree:
ocp.load_pytree(checkpoint_dir / "optimizer")那么你得到的是一个“基础” PyTree——它将由列表、字典、元组、字符串等基本 Python 类型以及 JAX 数组组成。问题在于,优化器的状态是由可以映射为这些类型的对象组成的——例如,一个对象可以映射为一个字典,其中每个字段都是字典中的一个项——但它们实际上并不是那些特定类型的对象。
所以如果你这样做:
optimizer.opt_state = ocp.load_pytree(checkpoint_dir / "optimizer")……你会得到一个错误,类似于这样:
AttributeError: 'list' object has no attribute 'items'……同样,如果你使用我在 load_model 代码中使用的 nnx.update 函数:
nnx.update(optimizer.opt_state, ocp.load_pytree(checkpoint_dir / "optimizer"))……你会得到一个略有不同但同样令人困惑的错误。
在经历了一番由于缺乏文档(而且文档似乎与我所看到的 API 不匹配)而导致的盲目摸索之后,我灵机一动,查看了 load_pytree 的 docstring,结果发现它写得非常棒。在 IPython 中:
In [1]: import orbax.checkpoint.experimental.v1 as ocp
In [2]: ocp.load_pytree?
Signature:
ocp.load_pytree(
path: 'path_types.PathLike',
abstract_pytree: 'AbstractPyTree | CheckpointMetadata[AbstractPyTree] | None' = None,
*,
checkpointable_name: 'str | None' = 'AUTO',
) -> 'tree_types.PyTreeOf[tree_types.Leaf]'
Docstring:
Loads a PyTree.
Loads from a ``PyTree`` checkpoint. A ``PyTree`` checkpoint must be a path
containing a subdirectory with the name provided by ``checkpointable_name``,
with default value ``AUTO``. See ``checkpointable_name`` for more details.
This function must be called on all available controller processes.
The operation blocks until complete. For improved performance, consider using
:py:func:``.load_pytree_async`` instead.
If ``abstract_pytree`` is not provided, the ``PyTree`` will be loaded exactly as
saved.
IMPORTANT: Loading is more brittle and error-prone when not providing
``abstract_pytree``. Always provide ``abstract_pytree`` if possible. Note that
you can always obtain the tree structure from a saved checkpoint using
:py:func:``.pytree_metadata``.
Providing the ``abstract_pytree`` guarantees two things:
1. The restored tree will exactly match the structure of ``abstract_pytree`` (or
raise an error if it is impossible to guarantee this). For example, if
``abstract_pytree`` is a custom object registered as a ``PyTree``, the checkpoint
will be restored as the same object, if possible.
2. The leaves of the restored tree will be restored with the properties
indicated by the abstract leaves. For example, if a leaf in ``abstract_pytree``
is a ``jax.ShapeDtypeStruct``, the restored leaf will be a ``jax.Array`` with the
same shape and ``dtype``. Each ``AbstractLeaf`` has a corresponding ``Leaf``
that is restored. See ``orbax.checkpoint.v1.tree`` for a table
of standard supported leaf types.
...很明显,解决方案就是那个 abstract_pytree。当你提供它时,它会被用作模板。如果它在抽象的 PyTree 中发现了一个 Foo 对象,而在加载的 PyTree 中同一位置有一个包含键 bar、baz 和 quz 的字典,它就会创建一个 Foo 对象,并将这些值设置到对应的字段中。
这意味着你拥有了具有正确结构且可供应用的对象,因此我最终写出了下面这段相对简单的代码,将 checkpoint 加载到优化器中:
optimizer.opt_state = ocp.load_pytree(checkpoint_dir / "optimizer", optimizer.opt_state)我们使用优化器的现有状态作为模板,以告诉 Orbax 如何构建加载后的数据结构。
我启动了一次训练,中途按下了 control-C,然后从检查点重启了训练,最终的损失曲线图如下所示:
……而学习率曲线图则是这样的:
完美!中断大约发生在全局第 400 步,损失继续正常下降,而学习率也完美地遵循了其预设的调度。
这是加载检查点的代码和训练脚本。
是时候构建模型了!
这样一来,第一阶段就完成了。我得到了一个训练脚本。对于训练这个微小的 A-to-A 模型来说,它有些过度设计了,但如果要从头开始训练一个小型的 LLM,它却刚刚好。
现在正是做这件事的时候——我将在下一篇文章中介绍这部分内容。
需要完整排版与评论请前往来源站点阅读。