返回 2026-04-18
🤖 AI / ML

4 位浮点数 FP44-bit floating point FP4

johndcook.com·2026-04-18

本文介绍了 FP4 这一新兴的 4 位浮点格式,用于高效压缩大语言模型的参数。FP4 是 bitsandbytes 库支持的一种低精度数据类型,旨在在保持模型性能的同时大幅减少内存占用。与传统的 32 位或 64 位浮点数相比,FP4 将每个参数压缩至四分之一比特宽度,适用于边缘设备和资源受限环境。文章还讨论了 FP4 在训练与推理中的潜在优势及当前局限性。

John Cook

在古代,浮点数以32位存储。后来,在某个时间点,64位成为标准。C语言保留了这一古老传统,用 float 表示32位浮点数,用 double 表示位数翻倍的浮点数。Python则直接用 float 表示最常见的浮点数格式,即C语言所称的 double。

程序员们很高兴从32位浮点数转向64位浮点数。精度更高总归没有坏处,而且许多数值问题在从32位升级到64位后得以解决。(但并非全部如此。这一点我曾多次提及。)

神经网络带来了意想不到的需求:需要更低精度的浮点数。这些网络拥有海量参数,相比之下,将更多参数塞进内存比追求高精度更重要。开发者不再满足于双精度(64位),转而需要半精度(16位),甚至更低,如FP8(8位)或FP4(4位)。本文将探讨4位浮点数。

既然不需要太高精度,为什么还要用浮点数?为何不用整数?例如,用4位可以表示整数0、1、2……15。也可以引入一个偏置值,比如每个值减7,这样这4位就能表示-7到8。事实证明,具备更宽的动态范围是有用的。

FP4格式的带符号4位浮点数使用第一位表示符号。问题是其余三位该如何分配。记法 ExMy 表示有 x 个指数位和 y 个尾数位。对于带符号的4位数而言:

x + y = 3

但在其他情况下,这个和可能更大。例如,对于8位带符号浮点数,x + y = 7。

对于4位带符号浮点数,有四种可能:E3M0、E2M1、E1M2 和 E0M3。它们都在某些地方被使用,其中 E2M1 最常见,并被NVIDIA硬件支持。

一个具有符号位 s、指数 e 和尾数 m 的数的值为:

(−1)^s × 2^(e−b) × (1 + m/2)

其中 b 是偏置值。偏置的目的在于允许正负指数,而无需对 e 使用带符号的数。例如,若 b = 1 且 e = 1、2 或 3,则指数部分 2^(e−b) 可表示 1、2 或 4。

偏置影响数值的范围,但不影响其相对间距。无论偏置 b 取何值,E3M0 格式全是指数位、无尾数位,因此其可能值在对数尺度上均匀分布。E0M3 格式全是尾数位,其值在**线性**尺度上均匀分布。E1M2 和 E2M1 格式在**对数和线性**尺度上都是不均匀分布的。

当 e = 0 时,上述表达式有一个例外:此时 m = 0 表示 0,m = 1 表示 ½。

数值表

由于只有16种可能的FP4数,因此可以列出全部内容。以下是 E2M1 格式的表。

Bits s exp m  Value
-------------------
0000 0  00 0     +0
0001 0  00 1   +0.5
0010 0  01 0     +1
0011 0  01 1   +1.5
0100 0  10 0     +2
0101 0  10 1     +3
0110 0  11 0     +4
0111 0  11 1     +6
1000 1  00 0     -0
1001 1  00 1   -0.5
1010 1  01 0     -1
1011 1  01 1   -1.5
1100 1  10 0     -2
1101 1  10 1     -3
1110 1  11 0     -4
1111 1  11 1     -6

注意,即便在这个极小的浮点格式中,也存在两个零:+0 和 -0,就像高精度浮点数一样。更多细节见此处。

Pychop库

Python库Pychop模拟了多种低精度浮点格式。以下是用Pychop生成上述表格的代码。

import pychop

# Pull the format metadata from Pychop.
spec = pychop.MX_FORMATS["mxfp4_e2m1"]
assert (spec.exp_bits, spec.sig_bits) == (2, 1)

def e2m1_value(s: int, e: int, m: int) -> float:
    sign = -1.0 if s else 1.0

    # Subnormal / zero
    if e == 0:
        return sign * (m / 2.0)

    # Normal
    return sign * (2.0 ** (e - 1)) * (1.0 + m / 2.0)

def display_value(bits: int, x: float) -> str:
    if bits == 0b0000:
        return "+0"
    if bits == 0b1000:
        return "-0"
    return f"{x:+g}"

rows = []
for bits in range(16):
    s = (bits >> 3) & 0b1
    e = (bits >> 1) & 0b11
    m = bits & 0b1
    x = e2m1_value(s, e, m)

    rows.append(
        {
            "Bits": f"{bits:04b}",
            "s": s,
            "exp_bits": f"{e:02b}",
            "m": m,
            "Value": display_value(bits, x),
        }
    )

# Pretty-print the table.
header = f"{'Bits':<4} {'s':>1} {'exp':>3} {'m':>1} {'Value':>6}"
print(header)
print("-" * len(header))
for row in rows:
    print(
        f"{row['Bits']:<4} " f"{row['s']:>1} "
        f"{row['exp_bits']:>3} "
        f"{row['m']:>1} "
        f"{row['Value']:>6}"
    )

其他格式

FP4 并非唯一的4位浮点格式。实际使用的格式数量出人意料地多。我将在下一篇帖子中介绍另一种格式。

更新:请参见下一篇帖子,其中讨论了 NF4 格式,该格式的表示数字更贴近 LLM 权重的分布。

需要完整排版与评论请前往来源站点阅读。