在 Flax 中使用 SafetensorsUsing Safetensors with Flax
将 PyTorch 编写的大语言模型(LLM)代码迁移到 JAX/Flax 框架时,模型检查点的存储格式选择是一个关键环节。作者分享了在 Flax 中使用 Safetensors 格式保存模型权重的经验和技巧。Safetensors 相比传统格式更安全、加载速度更快。文章详细记录了在解决兼容性问题过程中发现的具体实现方法。这为其他进行类似框架迁移的开发者提供了直接的参考。
Giles Thomas
归档
分类
友情链接
我正在将我的 PyTorch LLM 代码移植到 JAX,并使用 Flax 作为神经网络层。出于各种原因,我想使用 Safetensors 来存储模型的检查点。花了我一点时间才让它跑通;以下是我学到的诀窍。
如果你查看 Safetensors 的文档,你会发现它并没有提到 JAX 的实现——实际上,在我写这篇文章的时候,搜索“safetensors jax”会给你一个指向 Alvaro Bartolome 的 GitHub 仓库的链接——该仓库最后一次更新是在 2023 年。
然而,如果你仔细查看文档,它们确实提供了一个指向 Flax API 的链接。我觉得这个命名有些不妥,因为它实际上是一个 JAX API。在源代码中(同样,截至撰写本文时)没有任何对 Flax 的引用——全都是 JAX 代码。事实上,Bartolome 的库在底层使用的正是它。
不过有一个问题。该 API 只能处理简单的单层字典,其中字符串直接映射到 JAX 数组。例如,save_file 函数具有这样的签名:
def save_file(
tensors: Dict[str, Array],
filename: Union[str, os.PathLike],
metadata: Optional[Dict[str, str]] = None,
) -> None如果你不够小心,这可能会引发问题。如果你查看 Flax 关于检查点的文档,它会建议你使用 Orbax 1,后者有自己的 API 和文件格式,但文档接着说道:
在与检查点库(如 Orbax)交互时,你可能更倾向于使用 Python 的内置容器类型。在这种情况下,你可以使用 nnx.State.to_pure_dict 和 nnx.State.replace_by_pure_dict API,将 nnx.State 与纯嵌套字典进行相互转换。
我起初想当然地把这两件事(文档的提示和 Safetensors 基于字典的 API)联系在了一起,并尝试将其中一个“纯”字典喂给 Safetensors。结果我得到了一个非常令人困惑的错误:
SafetensorError: dtype object is not covered我们值得深入探究一下为什么会发生这种情况。
问题在于,尽管 Safetensors 期望得到一个字符串到张量映射的字典,但它并没有检查实际收到的究竟是不是这种结构。而且,虽然 nnx.State.to_pure_dict 生成的字典是“纯”的,但它们也是嵌套的(正如文档所说!)。即使对于我正在使用的简单模型,我也得到了像这样的结构:
{
'output_head': {
'kernel': Array([...], dtype=float32)
},
'token_embedding': {
'embedding': Array([...], dtype=float32)
}
}所以,我们的结构是字符串映射到字典,而这些字典又从字符串映射到 JAX 数组。更复杂的模型将具有更深的字典结构。
现在,在 Safetensors 的内部,Flax/JAX API 只是一个简单的包装器。它会遍历所提供字典中的键,并尝试将它们各自的值转换为 NumPy 数组。它通过将这些值传入 NumPy 的 asarray 函数来实现这一点,该函数接受列表、元组和 NumPy 数组等类型,并将它们转换为数组。JAX 自己的 Array 类暴露了该函数能识别的接口,因此它们可以毫无困难地进行转换。
一旦完成这一步,它就会将结果传递给底层的 Rust 实现,由后者实际将所有内容转换为 Safetensors 格式。
但由于 Safetensors 没有检查类型,在我的情况下,它在遍历字典的顶层时,试图将值转换为 NumPy 数组,结果得到了类似这样的东西:
{
'output_head': numpy.array({'kernel': Array([...], dtype=float32)}, dtype=object),
'token_embedding': numpy.array({'embedding': Array([...], dtype=float32)}, dtype=object)
}也就是说——因为它假设顶层字典中的值都是 JAX 数组,所以它不加判断地尝试将它们转换为 NumPy 数组。但它们实际上是字典(恰好是从字符串映射到数组)——如果你让 asarray 基于一个随机对象来创建数组,它会照做不误,将该对象包装在一个 NumPy 数组中,其 dtype 为 object。
当它随后被送入尝试写入文件的底层 Rust 代码时,遇到了它无法处理的 NumPy 数组,其 dtype 为 object —— 因此导致了那个错误:
SafetensorError: dtype object is not covered通读代码后就会发现这一切都很合理,但我之前确实困惑了一阵子!
我想这一切可能就是 Bartolome 创建他的 GitHub 仓库的原因。在 README 中,他是这么说的:
HuggingFace 目前没有计划扩展 safetensors 以支持张量以外的任何内容(例如 FrozenDicts),请参阅他们在 huggingface/safetensors/discussions/138 中的回复。因此,创建 safejax 的动机是为了轻松提供一种序列化 FrozenDicts 的方法,并使用 safetensors 作为张量存储格式。
但是,你并不需要使用那个库来序列化简单的 Flax 模型。
想想 PyTorch 模型是如何序列化到 Safetensors 的;我的 LLM 中包含名为 out_head.weight、pos_emb.weight 和 trf_blocks.0.att.out_proj.weight 的键。它们是“扁平化”的字典,将字符串映射到 PyTorch Tensors,类似于 Safetensors 对这些 Flax 模型的要求,只不过它们使用点号来分隔不同的层级,列表项使用整数,字段名使用字符串。
看看我为模型准备的纯字典结构:
{
'output_head': {
'kernel': Array([...], dtype=float32)
},
'token_embedding': {
'embedding': Array([...], dtype=float32)
}
}……你可以看到,可以通过遍历字典结构来生成诸如 output_head.kernel 和 token_embedding.embedding 这样的键。这编码实现起来相当简单。
但是——正如 Adithya Dsilva 在 GitHub 上指出的那样——你可以使用 nnx.to_flat_state 更快地实现这一目标。它会返回一个类似这样的(非字典)结构:
FlatState([
(('output_head', 'kernel'), Param( # 786,432 (3.1 MB)
value=Array([[ 2.3581974e-02, 3.0957451e-02, -3.5088759e-02, ...,
-4.5880198e-02, 5.3717274e-02, -2.6590331e-02],
...,
[-9.6302675e-03, -3.3276502e-02, 5.7173111e-02, ...,
-7.9063717e-03, 2.0532632e-02, 5.4753982e-02]], dtype=float32)
)),
(('token_embedding', 'embedding'), Param( # 786,432 (3.1 MB)
value=Array([[ 0.00273973, -0.01754938, 0.04656043, ..., -0.04276522,
-0.03986642, -0.00781331],
...,
[ 0.01421758, -0.0219186 , -0.01701825, ..., -0.00793659,
0.00500103, 0.03839901]], dtype=float32)
))
])如果你遍历那个 FlatState,会得到一些元组,其第一个元素是字符串元组(如 ('output_head', 'kernel')),第二个元素是封装了 JAX Array 的 Param 对象。这些元组反映了 PyTorch 风格的 Safetensors 文件里以点号分隔的字符串格式。
Param 对象还实现了 asarray 能够理解的接口,因此你可以快速轻松地将 FlatState 转换为适用于 Safetensors 的常规字典:
from safetensors.flax import save_file
...
model_state = nnx.state(model)
flat_state = nnx.to_flat_state(model_state)
simple_dict = {}
for tuple_key, param in flat_state:
key = ".".join(str(key) for key in tuple_key)
simple_dict[key] = param
save_file(simple_dict, "model.safetensors")(你需要用 str 包装 key,因为如果你的模型中包含 nnx.Sequential,元组中的项将会是整数索引而不是字符串)。
反向操作也相当容易;给定一个模型,你可以像这样将保存的检查点加载到其中(因为 from_flat_state 接受原始 JAX Arrays 来代替显式 Params):
from safetensors.flax import load_file
...
simple_dict = load_file("model.safetensors")
dict_flat_state = {}
for key, array in simple_dict.items():
elements = key.split(".")
list_key = []
for element in elements:
try:
list_key.append(int(element))
except ValueError:
list_key.append(element)
dict_flat_state[tuple(list_key)] = array
new_flat_state = nnx.from_flat_state(dict_flat_state)
nnx.update(model, new_flat_state)虽然比我希望的理想情况稍显繁琐,但考虑到它可以被封装到通用的 save_checkpoint/load_checkpoint 函数中,这也不算什么大问题。
希望这能对其他遇到这个问题的人有所帮助!
需要完整排版与评论请前往来源站点阅读。