JAX:承诺问题(设备分配与执行陷阱)JAX: commitment issues
深入探讨了在配置了 CUDA 的机器上运行 JAX 代码时遇到的设备分配与执行陷阱。JAX 的异步计算模型和显式设备放置机制常常会让开发者产生“承诺问题(commitment issues)”,导致难以预测的计算阻塞或同步错误。理解 JAX 底层的调度逻辑和数组在 CPU/GPU 间的传输机制是解决这些问题的关键。掌握这些细节能显著提升高性能计算代码的调试效率。
Giles Thomas
归档
分类
友情链接
发布于 2026 年 6 月 15 日,收录于 AI、TIL、JAX
假设你有这样一段 JAX 代码,并在配置了 CUDA 的机器上运行它:
key = jax.random.key(42)
cpu0 = jax.devices("cpu")[0]
with jax.default_device(cpu0):
array = jax.random.randint(
key,
(530640, 6, 1024),
0, 50_000,
dtype=jax.numpy.uint16
)
array.block_until_ready()
item = array[0]
item.block_until_ready()我们创建了一个大数组,并阻塞等待直到它准备就绪(JAX 是异步的,这样可以确保它确实完成了创建),然后获取第一个元素,并作为双重保险确保该元素也就绪。你觉得最后两行代码——从一个大数组中简单地提取一个 6 x 1024 的数组——会花费多长时间?几十分之一秒似乎是比较合理的。
但刚才在我的机器上运行时,结果有些出乎意料:刚好超过 5 秒。如果你紧接着尝试获取 array[1],仍然需要大约 1.2 秒。对 array 的后续查找持续需要超过 1 秒的时间——因此,虽然最初较大的数值可能与设置有关(也许是内部的东西正在进行 JIT 编译),但这显然不是全部原因。某种原因导致这些看似简单的数组查找所花费的时间比预期的要长得多。
让我们深入探究一下。
一些背景知识
首先,你到底为什么会想要用 jax.default_device 上下文管理器来进行这种略显奇怪的操作,而不是直接告诉 randint 你想使用哪个设备(例如通过 out_sharding)?
我正在编写一些 LLM 训练代码,并希望加载我的训练数据集。我不想将其加载到 GPU 的 VRAM 中——那会浪费宝贵的 GPU 资源——所以我需要将它放在 CPU 端的内存中。我使用的是 Safetensors,它会将数据加载到系统的默认设备上。因此,我需要临时覆盖该设置,以确保数据集被加载到我想要的设备上。
我最初是在训练循环中尝试遍历生成的数组时发现这个问题的;上面的代码是该过程的简化版本——该问题的最小复现示例。而且这是一个严重的问题!如果仅仅是为了给模型准备 6,144 个 token,每次迭代就有 1.2 秒的开销,那么仅由于这一开销,JAX 的训练速度最高只能达到每秒约 5,000 个 token——真正的正向和反向传播加上优化器步骤显然会让速度变得更慢。作为对比,我的 PyTorch 训练循环在相同硬件上达到了近 20,000 tokens/秒的速度:这包含了从获取训练数据、将其放到 GPU 上,再到执行实际训练的所有步骤。
调试
那么,让我们再看看那段代码。我们明确地在 CPU 上创建了变量 array,而且如果你打印 array.device,它会显示 CpuDevice(id=0)。但如果你打印该元素的设备,你会得到 CudaDevice(id=0)。更糟糕的是,如果你在代码运行时观察 nvtop,一旦执行到数组查找操作,它就会开始使用 GPU——每一次查找都会导致 GPU 使用率激增。
那么,到底是怎么回事呢?我们要求 JAX 将数组放在 CPU 上,但它现在却在进行 GPU 运算,并将元素放在了那里。
问题在于,当你使用 default_device 上下文管理器创建数组时,它会被放置在指定的设备上,但并没有提交到该设备。如果一个数组没有提交到其所在的设备,那么 JAX 就会随意将其移动到其他设备上。
为了将数组提交到某个设备,你需要使用 jax.device_put 并明确指出你希望它位于哪个设备上。运行相同的代码,但替换为以下代码:
array = jax.device_put(array, cpu0)……就在对数组进行查找之前,这些数字发生了剧烈变化;在我的机器上,第一次查找大约耗时 0.95s,第二次 0.0002s,随后的耗时则不到 0.0001s。
一些更详细的测试
我决定对此进行深入测试,并编写了这个脚本。如果不带 --commit 命令行标志运行,它会创建数组,然后遍历前十个元素,测量获取每个元素所需的时间。刚刚运行的结果如下:
加上 --commit 标志后,它会使用 device_put 将数组显式提交到 CPU。运行结果如下:
不过,这并没有完全覆盖我的用例——我想知道,如果缓慢的操作是将数据放到 GPU 上呢?该脚本还有一个 --put_items_to_gpu 标志来实现这一点——在获取每个元素后,它会使用 device_put。加上该标志后:
所以,仍然有一点启动开销——也许 JAX 需要对其内部的一些东西进行 JIT——但在此之后速度就非常理想了。提交起作用了!
总结一下
我仍在构建关于 JAX 工作原理的思维模型,而确切地弄清楚这里到底发生了什么确实有点棘手。已提交(committed)和未提交(uncommitted)数组之间的区别似乎很明确:前者绑定到特定设备,而 JAX 会根据需要移动后者。
它想把元素移动到 GPU 上也有一定道理;毕竟,它是默认设备。但我不太明白的是,与手动获取元素然后再放到那里的过程相比,为什么它会这么慢。
假设:数组位于 CPU 的 RAM 中,但未在那里提交。我们请求该数组中的一个元素,也许 JAX 希望它位于默认设备 GPU 上。所以它把整个“父”数组移动到那里,提取该元素,然后将其返回。接着下一次我们请求下一个元素时,它会再次做同样的事情。
合理吗?也许吧,但这听起来确实有点病态!
无论如何,最终我得出了自己的一条可靠的新启发式规则:如果你希望某些数据确实位于某个特定设备上,请确保使用 device_put 将其固定在那里。这样你就不会遇到类似这样的提交问题了。
需要完整排版与评论请前往来源站点阅读。