返回 2026-06-06
🤖 AI / ML

在 Flax 中使用 SafetensorsUsing Safetensors with Flax

gilesthomas.com·2026-06-04

将 PyTorch 编写的大语言模型(LLM)代码迁移到 JAX/Flax 框架时,模型检查点的存储格式选择是一个关键环节。作者分享了在 Flax 中使用 Safetensors 格式保存模型权重的经验和技巧。Safetensors 相比传统格式更安全、加载速度更快。文章详细记录了在解决兼容性问题过程中发现的具体实现方法。这为其他进行类似框架迁移的开发者提供了直接的参考。

Giles Thomas

归档

分类

友情链接

我正在将我的 PyTorch LLM 代码移植到 JAX,并使用 Flax 作为神经网络层。出于各种原因,我想使用 Safetensors 来存储模型的检查点。花了我一点时间才让它跑通;以下是我学到的诀窍。

如果你查看 Safetensors 的文档,你会发现它并没有提到 JAX 的实现——实际上,在我写这篇文章的时候,搜索“safetensors jax”会给你一个指向 Alvaro Bartolome 的 GitHub 仓库的链接——该仓库最后一次更新是在 2023 年。

然而,如果你仔细查看文档,它们确实提供了一个指向 Flax API 的链接。我觉得这个命名有些不妥,因为它实际上是一个 JAX API。在源代码中(同样,截至撰写本文时)没有任何对 Flax 的引用——全都是 JAX 代码。事实上,Bartolome 的库在底层使用的正是它。

不过有一个问题。该 API 只能处理简单的单层字典,其中字符串直接映射到 JAX 数组。例如,save_file 函数具有这样的签名:

def save_file(
    tensors: Dict[str, Array],
    filename: Union[str, os.PathLike],
    metadata: Optional[Dict[str, str]] = None,
) -> None

如果你不够小心,这可能会引发问题。如果你查看 Flax 关于检查点的文档,它会建议你使用 Orbax 1,后者有自己的 API 和文件格式,但文档接着说道:

在与检查点库(如 Orbax)交互时,你可能更倾向于使用 Python 的内置容器类型。在这种情况下,你可以使用 nnx.State.to_pure_dict 和 nnx.State.replace_by_pure_dict API,将 nnx.State 与纯嵌套字典进行相互转换。

我起初想当然地把这两件事(文档的提示和 Safetensors 基于字典的 API)联系在了一起,并尝试将其中一个“纯”字典喂给 Safetensors。结果我得到了一个非常令人困惑的错误:

SafetensorError: dtype object is not covered

我们值得深入探究一下为什么会发生这种情况。

问题在于,尽管 Safetensors 期望得到一个字符串到张量映射的字典,但它并没有检查实际收到的究竟是不是这种结构。而且,虽然 nnx.State.to_pure_dict 生成的字典是“纯”的,但它们也是嵌套的(正如文档所说!)。即使对于我正在使用的简单模型,我也得到了像这样的结构:

{
    'output_head': {
        'kernel': Array([...], dtype=float32)
    },
    'token_embedding': {
        'embedding': Array([...], dtype=float32)
    }
}

所以,我们的结构是字符串映射到字典,而这些字典又从字符串映射到 JAX 数组。更复杂的模型将具有更深的字典结构。

现在,在 Safetensors 的内部,Flax/JAX API 只是一个简单的包装器。它会遍历所提供字典中的键,并尝试将它们各自的值转换为 NumPy 数组。它通过将这些值传入 NumPy 的 asarray 函数来实现这一点,该函数接受列表、元组和 NumPy 数组等类型,并将它们转换为数组。JAX 自己的 Array 类暴露了该函数能识别的接口,因此它们可以毫无困难地进行转换。

一旦完成这一步,它就会将结果传递给底层的 Rust 实现,由后者实际将所有内容转换为 Safetensors 格式。

但由于 Safetensors 没有检查类型,在我的情况下,它在遍历字典的顶层时,试图将值转换为 NumPy 数组,结果得到了类似这样的东西:

{
    'output_head': numpy.array({'kernel': Array([...], dtype=float32)}, dtype=object),
    'token_embedding': numpy.array({'embedding': Array([...], dtype=float32)}, dtype=object)
}

也就是说——因为它假设顶层字典中的值都是 JAX 数组,所以它不加判断地尝试将它们转换为 NumPy 数组。但它们实际上是字典(恰好是从字符串映射到数组)——如果你让 asarray 基于一个随机对象来创建数组,它会照做不误,将该对象包装在一个 NumPy 数组中,其 dtype 为 object。

当它随后被送入尝试写入文件的底层 Rust 代码时,遇到了它无法处理的 NumPy 数组,其 dtype 为 object —— 因此导致了那个错误:

SafetensorError: dtype object is not covered

通读代码后就会发现这一切都很合理,但我之前确实困惑了一阵子!

我想这一切可能就是 Bartolome 创建他的 GitHub 仓库的原因。在 README 中,他是这么说的:

HuggingFace 目前没有计划扩展 safetensors 以支持张量以外的任何内容(例如 FrozenDicts),请参阅他们在 huggingface/safetensors/discussions/138 中的回复。因此,创建 safejax 的动机是为了轻松提供一种序列化 FrozenDicts 的方法,并使用 safetensors 作为张量存储格式。

但是,你并不需要使用那个库来序列化简单的 Flax 模型。

想想 PyTorch 模型是如何序列化到 Safetensors 的;我的 LLM 中包含名为 out_head.weight、pos_emb.weight 和 trf_blocks.0.att.out_proj.weight 的键。它们是“扁平化”的字典,将字符串映射到 PyTorch Tensors,类似于 Safetensors 对这些 Flax 模型的要求,只不过它们使用点号来分隔不同的层级,列表项使用整数,字段名使用字符串。

看看我为模型准备的纯字典结构:

{
    'output_head': {
        'kernel': Array([...], dtype=float32)
    },
    'token_embedding': {
        'embedding': Array([...], dtype=float32)
    }
}

……你可以看到,可以通过遍历字典结构来生成诸如 output_head.kernel 和 token_embedding.embedding 这样的键。这编码实现起来相当简单。

但是——正如 Adithya Dsilva 在 GitHub 上指出的那样——你可以使用 nnx.to_flat_state 更快地实现这一目标。它会返回一个类似这样的(非字典)结构:

FlatState([
  (('output_head', 'kernel'), Param( # 786,432 (3.1 MB)
    value=Array([[ 2.3581974e-02,  3.0957451e-02, -3.5088759e-02, ...,
            -4.5880198e-02,  5.3717274e-02, -2.6590331e-02],
           ...,
           [-9.6302675e-03, -3.3276502e-02,  5.7173111e-02, ...,
            -7.9063717e-03,  2.0532632e-02,  5.4753982e-02]], dtype=float32)
  )),
  (('token_embedding', 'embedding'), Param( # 786,432 (3.1 MB)
    value=Array([[ 0.00273973, -0.01754938,  0.04656043, ..., -0.04276522,
            -0.03986642, -0.00781331],
           ...,
           [ 0.01421758, -0.0219186 , -0.01701825, ..., -0.00793659,
             0.00500103,  0.03839901]], dtype=float32)
  ))
])

如果你遍历那个 FlatState,会得到一些元组,其第一个元素是字符串元组(如 ('output_head', 'kernel')),第二个元素是封装了 JAX Array 的 Param 对象。这些元组反映了 PyTorch 风格的 Safetensors 文件里以点号分隔的字符串格式。

Param 对象还实现了 asarray 能够理解的接口,因此你可以快速轻松地将 FlatState 转换为适用于 Safetensors 的常规字典:

    from safetensors.flax import save_file

    ...

    model_state = nnx.state(model)
    flat_state = nnx.to_flat_state(model_state)
    simple_dict = {}
    for tuple_key, param in flat_state:
        key = ".".join(str(key) for key in tuple_key)
        simple_dict[key] = param

    save_file(simple_dict, "model.safetensors")

(你需要用 str 包装 key,因为如果你的模型中包含 nnx.Sequential,元组中的项将会是整数索引而不是字符串)。

反向操作也相当容易;给定一个模型,你可以像这样将保存的检查点加载到其中(因为 from_flat_state 接受原始 JAX Arrays 来代替显式 Params):

    from safetensors.flax import load_file

    ...

    simple_dict = load_file("model.safetensors")

    dict_flat_state = {}
    for key, array in simple_dict.items():
        elements = key.split(".")
        list_key = []
        for element in elements:
            try:
                list_key.append(int(element))
            except ValueError:
                list_key.append(element)
        dict_flat_state[tuple(list_key)] = array

    new_flat_state = nnx.from_flat_state(dict_flat_state)
    nnx.update(model, new_flat_state)

虽然比我希望的理想情况稍显繁琐,但考虑到它可以被封装到通用的 save_checkpoint/load_checkpoint 函数中,这也不算什么大问题。

希望这能对其他遇到这个问题的人有所帮助!

  • 我开始对所有这些名字以 -ax 结尾的库感到有点应接不暇了。这让我想起了 Asterix 村庄里那些角色的名字…… ↩
  • 需要完整排版与评论请前往来源站点阅读。