Flax 调试:利用哈希解决问题Flax debugging: making a hash of things
在调试 JAX/Flax NNX 训练循环时,定位梯度是否真正传递并更新到模型参数是一个棘手问题。开发者通常难以通过直接打印参数来判断优化器或损失函数是否生效,因为底层数据结构的更新方式十分隐蔽。作者分享了一个巧妙的小技巧,通过计算和比对参数的哈希值来追踪训练过程中的实际变化。这种方法有效隔离了模型缺陷、损失函数错误与底层“管道”通信不畅等常见机器学习调试障碍。
Giles Thomas
归档
分类
友情链接
发表于 2026年6月17日,收录于 AI、TIL、JAX、Python
前几天我在调试一个 JAX/Flax NNX 训练循环的问题,并发现了一个有助于调试的巧妙小技巧。具体来说,我想确认问题究竟是出在我的模型、损失函数、优化器设置,还是训练循环本身的“管道”上——梯度是否真的在反向传递并应用到了参数上?
我可以打印出损失和梯度,但打印参数来观察它们是否发生变化并没有什么帮助——单次更新可能只会改变一小部分参数,或者改变的幅度极小以至于我根本察觉不到——考虑到该模型有 7700 万个参数,情况更是如此!
让我们来看一看。
世界上最差的 LLM
我正在用 JAX 和 Flax NNX 从头构建一个 LLM,目前我正致力于让训练循环正确运行。作为一个简单的测试,我只实现了 LLM 的“外壳”——即输入端的 token embeddings 和用于输出头的最终线性层,并将它们直接连接在一起。我的计划是训练它,使得给定一个序列时,它不是预测每个位置的下一个 token,而是“预测”序列本身——也就是说,我可能会用这样的输入来训练它
The fat cat sat on the mat……以及目标
The fat cat sat on the mat……而不是 LLM 的常规设置,在常规设置中你会输入
The fat cat sat on the……并给它以下目标
fat cat sat on the mat因此,用 LLM 的术语来说,我是在训练一个模型,将词汇表空间(vocab space)投影到一个学习到的嵌入空间,在这个嵌入空间中,每个 token 都有足够独特的 embedding,使得输出头能够可靠地将其投影回词汇表空间中的 logits。如果你觉得这难以理解,这里有一些背景知识。
这是我当时正在处理的核心代码,即 train_step 函数。这似乎是 JAX 中的一个传统命名,用于指代代码中经过 JIT 编译的部分,该部分负责执行模型的前向传播、计算梯度,然后应用梯度来更新模型:
@jax.jit
def train_step(model, optimizer, inputs, targets):
loss, grads = nnx.value_and_grad(calculate_loss)(model, inputs, targets)
optimizer.update(model, grads)
return loss我是基于 Flax 网站首页上目前的“基本用法”示例编写的。经验丰富的 Flax 老手可能会立刻发现问题所在,但我当时并没有看出来——所以是时候深入探究了。
处理损失
问题在于损失值没有下降——事实上,保留两位小数的话,它一直卡在 10.82。虽然之后的小数位会随每个批次变化,但前四位数字是不变的。由于这个模型使用的是 GPT-2 分词器(tokeniser),如果模型本质上是在随机猜测,10.82 正是你预期的损失值——如果你通过计算 e^10.82 将其转换为困惑度(perplexity),结果大约是 50,011——这非常接近 GPT-2 的词汇表大小 50,257。通俗地说,困惑度是模型在处理典型输入时需要从中选择的 token 数量,因此,对于一个每 50,257 次才猜对一次的随机模型来说,困惑度等于词汇表大小正是你所预期的结果。
话虽如此,能够持续获得这样的损失值,是对我的损失函数的有力验证!如果我把损失函数搞砸了,它几乎不可能如此稳定地得出那个特定的数字。我在小数点后第三位及之后看到的微小变化也是合理的,因为它们很容易归因于不同批次内容的变化。
滑向疯狂的梯度下降
那么究竟是梯度变成了零,还是变成了 NaN,或者是其他什么优化器无法有效应用到模型上的值?我在 train_step 函数中将它们打印了出来(去掉了 jit 装饰器,否则 print 语句只会在函数编译的初始 JIT 阶段执行,而不是在处理实际数据 1 时执行)。
结果打印出了类似这样的值:
State({
'output_head': {
'kernel': Param( # 38,597,376 (154.4 MB)
value=Array([[-2.6879393e-06, -1.2799728e-04, 2.6441864e-09, ...,
-1.0780521e-09, -1.9232946e-09, 1.2057198e-04],
[ 7.2428256e-06, -9.0873800e-05, 1.9621261e-08, ...,
1.9959407e-08, 2.0515712e-08, -1.1401048e-06],
[-2.4080187e-05, 1.0717572e-04, -4.7910085e-09, ...,
-7.3136892e-09, -5.4990306e-09, 1.4717734e-04],
...,
[ 1.9500087e-05, 1.4264552e-05, -3.0880422e-08, ...,
-3.0595814e-08, -3.7087858e-08, -1.2066610e-06],
[ 1.8085115e-05, 7.6247423e-05, -3.0720415e-08, ...,
-3.1052533e-08, -3.1693808e-08, -9.7857817e-05],
[ 5.2281484e-06, -1.4398852e-04, 6.2573882e-08, ...,
5.5977843e-08, 6.6571232e-08, -1.0639715e-05]], dtype=float32)
)
},
...这些值看起来相当合理——虽然很小,但还没小到在我看来会在 0.0014 的学习率下毫无作用的地步。是时候深入探究一下训练循环的内部机制了。
深入探究
最明显的疑点在于更新步骤——调用 optimizer.update 究竟有没有改变参数?与 JAX 常规的函数式编程范式相比,Flax 的 NNX API 显得有些奇特。在原生的 JAX 代码中,你通常会像下面这样来应用梯度:
new_parameters = jax.tree.map(
lambda p, g: p - g * learning_rate,
old_parameters,
grads,
)也就是说,通过对旧参数进行变换来获得新的参数。
相比之下,NNX 更带有 PyTorch 的风格。它通过一种带有副作用(即修改其传入参数)的函数来原地(in-place)更新参数:
optimizer.update(model, grads)……而不是像下面这个假想 API 那样更具函数式风格的写法:
model = optimizer.apply(model, grads)我不难想象,自己可能是在什么地方弄错了,从而导致原地更新失效,毕竟要在像 JAX 这样的函数式系统之上实现这种功能,感觉需要极其精细的操作。
但是,既然模型有 7700 万个参数,而且它们的更新量(基于类似 -2.6879393e-06 的梯度和 1.4e-3 的学习率)发生在小数点后第九位甚至更靠后,我又怎么能看出参数到底变没变呢?直接把这些数组打印出来根本行不通!
用哈希一探究竟
稍加思索后,我意识到解决办法是使用哈希值。参数值中极其微小的变化也会导致其哈希值发生剧烈改变。因此,如果真如我所料参数没有被更新,我就会看到固定不变的哈希值;而如果它们被更新了,哪怕只有一丁点儿,哈希值也会随之改变。
这篇 GitHub 讨论为我指明了方向:如果我能将参数提取为纯 JAX 数组,我就可以这么做:
print(hash(np.asarray(some_array).tobytes()))……这里的 np 就是 numpy。这会生成一个在本次运行期间保持稳定的哈希值——相同的参数其哈希值总是相同的,不同的参数哈希值则不同,这正是我们想要的效果。不同次运行的哈希值可能会不一样(Python 在每个新的解释器中会使用不同的哈希种子),但这对于此类调试来说无关紧要。
我起初并不清楚我的 Flax 模型参数具有怎样的结构,但在训练循环中将其打印出来后,我得知了:
Embed( # Param: 38,597,376 (154.4 MB)
embedding=Param( # 38,597,376 (154.4 MB)
value=Array(shape=(50257, 768), dtype=dtype('float32'))
),
...
)
Linear( # Param: 38,597,376 (154.4 MB)
kernel=Param( # 38,597,376 (154.4 MB)
value=Array(shape=(768, 50257), dtype=dtype('float32'))
),
...
)因此,受此启发,我在训练循环中加入了这几行代码:
print(hash(np.asarray(model.token_embedding.embedding.value).tobytes()))
print(hash(np.asarray(model.output_head.kernel.value).tobytes()))显然,像这样把数组复制并转换过来会拖慢运行速度,但出于调试目的,这个方法看起来很稳妥。
我启动了训练循环,问题立刻一目了然了:
0%| | 43/530640 [00:06<13:39:02, 10.80it/s, loss=10.824, tps=43,576]
5694185712877458479
-5759723708627894111
0%| | 43/530640 [00:06<13:39:02, 10.80it/s, loss=10.824, tps=43,897]
5694185712877458479
-5759723708627894111……以此类推。哈希值根本没有改变,说明模型的参数没有得到更新,哪怕是一丁点儿都没变。找到问题了!
果然如我所料,问题就出在 NNX 执行的原地更新上。正如我早先所说,我的训练循环是基于 Flax 官网上的“Basic Usage”示例编写的——但我弄错了一个关键的地方。我写的是这样:
@jax.jit
def train_step(model, optimizer, inputs, targets):
loss, grads = nnx.value_and_grad(calculate_loss)(model, inputs, targets)
optimizer.update(model, grads)
return loss……而官方示例是这样写的:
@nnx.jit # automatic state propagation
def train_step(model, optimizer, x, y):
loss_fn = lambda model: ((model(x) - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads) # in-place updates
return loss你可以看到许多不同之处——例如,他们通过词法闭包将输入和目标值绑定到了用于损失函数的 lambda 表达式中,这意味着他们只需将模型传给被 value_and_grad 包装的版本。但这些都不重要!真正的区别其实已在一处注释中被很好地凸显出来了,但我却完全没注意到。就在最开始,我使用的是 @jax.jit,而他们使用的是:
@nnx.jit # automatic state propagation为了支持这种非函数式的、原地的模型参数更新,你必须使用修改版的 JIT 装饰器,这百分之百是合理的。而我当时使用的只是标准的、函数式的纯 JAX 装饰器。
修复此处后,问题就解决了:
0%| | 1/530640 [00:06<903:18:25, 6.13s/it, loss=10.824, tps=1,003]
5024998356359528747
-4835662927486742764
0%| | 2/530640 [00:06<397:16:33, 2.70s/it, loss=10.785, tps=1,914]
6231090084827524676
8293831317336780907
0%| | 3/530640 [00:06<228:14:32, 1.55s/it, loss=10.741, tps=2,791]
7896237091035346857
-7117477486466304738哈希值改变了!更棒的是,如果你向右滚动,就会看到 loss 正在缓慢下降。在大约 1 万次迭代后,我看到了 0.000:我那个什么也不做的“LLM”终于跑起来了。
总结
这是一段令人满意的调试之旅——虽然我觉得自己以后不会再犯这种特定的错误,但我认为参数哈希这个技巧确实是一个非常实用的工具箱技巧。如果你不确定参数是否被更新了,光盯着它们看可能没什么用。但是查看它们的哈希值就能帮你查明到底有没有发生变化。
而且我认为自己用来聚焦问题的排查模式也很有用。我总是会追踪 loss,所以这是一个很好的切入点(事实上,正是看到 loss 没有下降,才让我意识到出问题了)。但检查 loss 是否具有一个合理的——或者理想情况下,像本例中这样有意义的——数值,是一个很好的合理性检查,能确保我们拥有一个有效的损失函数,且模型没有表现出完全病态的行为。接下来,检查是否有某种梯度在流动是稳妥的下一步(随着模型变深,梯度可能会消失或爆炸,这一步可能会变得越来越关键)。最后我们就可以检查参数了——特别是,它们发生变化了吗? 2
看看我在推进这个 LLM 项目的过程中还能掌握多少新技巧吧。
需要完整排版与评论请前往来源站点阅读。