返回 2026-06-07
🤖 AI / ML

JAX 后端与设备JAX backends and devices

gilesthomas.com·2026-06-05

作者在将基于 PyTorch 的大语言模型(LLM)代码移植到 JAX 框架的过程中,详细探讨了 JAX 在处理后端和设备分配时的机制。文章以加载一个包含 10,248,871,837 个 16 位无符号整数、体积超过 19GiB 的庞大训练数据集(`gpjt/fineweb-gpt2-tokens`)为例,展示了 JAX 处理大规模数据的实际表现。通过亲自编写和调试代码,作者澄清了 JAX 框架内部组件之间的协作方式。这为想要从 PyTorch 转向 JAX 的深度学习开发者提供了极具价值的实战经验。

Giles Thomas

归档

分类

友情链接

亲自用框架写代码,最能让人理清各个部分是如何拼在一起的!继续把我的 PyTorch LLM 代码移植到 JAX 的过程中,我想加载一个大型数据集:gpjt/fineweb-gpt2-tokens 的 train 分片中共有 10,248,871,837 个 16 位无符号整数。这差不多就是 19GiB 的数据。

from safetensors.flax import load_file
...
full_dataset = load_file(dataset_dir / f"train.safetensors")["tokens"]

运行时,我收到了一个 CUDA 显存不足(out-of-memory)的错误:

jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 19.09GiB.

这完全说得通!它尝试分配的显存正好就是我想加载的数据大小。我有一块 24 GiB 的 RTX 3090,但其中一部分已经被操作系统、各种应用程序以及代码在前面创建的模型占用了。

不过在 PyTorch 的世界里,我习惯的做法是数据默认加载到系统内存(RAM),只有在我明确要求时才会移到 GPU。而 JAX 显然默认就往 GPU 上加载。在这种情况下,我怎样才能阻止它这么做呢?加载到 GPU 的过程发生在 Safetensors 内部,那段代码我没法直接控制。

搞清楚怎么做之后,也让我对 JAX 有了更深的理解。

JAX 有一个看起来很相关的函数:jax.devices。先不看文档,直接试着运行一下。在安装了 jax[cuda13] 包的虚拟环境中,我得到如下结果:

In [1]: import jax

In [2]: all_devices = jax.devices()

In [3]: all_devices
Out[3]: [CudaDevice(id=0)]

这看起来有点奇怪!我确实有 CUDA 设备,但显然我也有 CPU。为什么它没显示出来?

在另一个只安装了 jax(没有 CUDA)的虚拟环境中运行同样的代码,得到的结果是:

In [1]: import jax

In [2]: all_devices = jax.devices()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

In [3]: all_devices
Out[3]: [CpuDevice(id=0)]

好吧,这次它确实识别到了。看来有必要去读读官方文档(RTFM)了。

jax.devices 的文档对此做了一些解释:

jax.devices(backend=None) 返回指定后端的所有设备列表。 ... 如果 backend 为 None,则返回默认后端的所有设备。默认后端通常是 'gpu' 或 'tpu'(如果可用),否则为 'cpu'。

明白了。所以 JAX 有多个后端——之所以这么叫,是因为它们是 XLA(JIT 背后的编译器)所针对的后端硬件类别。其中有一个默认后端,本质上就是在当前硬件配置和已安装的 JAX 组件中,所能使用的“最好”的那个。

当我安装了 CUDA 版本时,它把 gpu 设为默认后端;而没装的时候,就默认为 cpu(并给出了警告)。又因为它只显示默认后端的设备,所以当默认是 gpu 时,我就看不到 CPU 了。

不过,你可以通过 backend 参数来指定想要使用的后端,让我们回到装有 CUDA 的虚拟环境试试:

In [4]: jax.devices("cpu")
Out[4]: [CpuDevice(id=0)]

很好!那有没有办法列出所有可用的后端呢? apparently not -- the recommended way appears to be to try loading devices for the different possibilities, and catch RuntimeErrors to see which ones aren't available. Yuck. (参考译文:似乎没有——推荐做法似乎是尝试为不同的可能性加载设备,并捕获 RuntimeError 来判断哪些不可用。真让人无语。)

但这可能也不是什么大问题。在 PyTorch 的世界里,我非常习惯在代码开头附近写上这样的代码:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

……然后把模型移到设备上:

model.to(device)

……接着根据需要把数据移到模型所在的设备:

device = next(model.parameters()).device
...
inputs = inputs.to(device)

我真正想要的,其实本质上正是 JAX 的做法——让所有东西始终放在可用的最快设备上——但要有一些特定的例外。特别是引发这次探索的那个问题:我该如何把这个巨大的训练数据数组放在 CPU 的 RAM 里,而不是 GPU 的 VRAM 里?

我一开始走了点弯路。我发现 Safetensors FLAX API 中的 load_file 函数有一个 backend 参数,但这似乎更多是关于文件加载方式的——属于另一种意义上的 backend。而且无论如何,在 JAX 的世界里,backend 并不是我们需要的概念,因为 backend 仅指代类似 gpu 这样的一般性事物——而为了实现我们的目标,我们需要将数据加载到特定的 device 上。

经过一番查阅,我发现 JAX 有一个“默认设备(default device)”的概念,也就是在没有任何明确指示要将数据放在何处时,系统默认使用的设备。合乎情理的是,它会位于默认的 backend 上——实际上,它本质上看起来就是“jax.devices 为默认 backend 返回的设备列表中的第一个设备”。

有一个 jax_default_device 配置选项可以用来设置它;你通常会使用 jax.config.update 或环境变量来更改它。

但如果你只想暂时更改它该怎么办呢?我找到了关于 jax.default_device 的文档。

这文档着实让人有些摸不着头脑:

jax.default_device =<jax._src.config.State object> 这是 jax_default_device 配置选项的上下文管理器。为 JAX 操作配置默认设备。将其设置为 Device 对象(例如 jax.devices("cpu")[0]),即可将该 Device 作为 JAX 操作和 jit 编译的函数调用的默认设备(这对多设备计算无效,例如 pmapped 函数调用)。设置为 None 则使用系统默认设备。

开头附近的那个 = 让我绊了个跟头,因为我没看到下面紧接着的“Context manager(上下文管理器)”字样以及奇怪的 State 类型,然后尝试了这样做:

jax.default_device = jax.devices("cpu")[0]
full_dataset = load_file(dataset_dir / f"train.safetensors")["tokens"]
jax.default_device = None

不过,我还是遇到了 CUDA OOM(显存溢出),所以我重新阅读了文档,注意到了“context manager”部分,忍不住爆了句粗口,然后尝试了这样做:

with jax.default_device(jax.devices("cpu")[0]):
    full_dataset = load_file(dataset_dir / f"train.safetensors")["tokens"]

……这回奏效了。看来文档里的等号所代表的含义,与我们通常使用它的目的截然不同,而且他们决定不真正去记录这个上下文管理器的签名。哎,我猜写文档确实挺难的。

尽管如此,至少我现在有了解决方案。而且正如我之前所说,撇开对文档的抱怨不谈,代码的最终形态可能比 PyTorch 还要简单一点。我创建的对象的默认存放位置是我拥有的最快的硬件,这正是我想要的。而在极少数我不想使用它的情况下,也有一种相当简单的(既然我现在知道了)方法来指定我希望将对象放在哪里。

我觉得这算是个胜利 :-) 唯一需要记住的是,当我在训练循环中想使用那个内存中的张量(in-RAM tensor)的子集时,我需要将它们移动到 GPU 上。jax.device_put 看起来就是完成这项任务的正确工具。

需要完整排版与评论请前往来源站点阅读。