4 位浮点数 FP44-bit floating point FP4
本文介绍了 FP4 这一新兴的 4 位浮点格式,用于高效压缩大语言模型的参数。FP4 是 bitsandbytes 库支持的一种低精度数据类型,旨在在保持模型性能的同时大幅减少内存占用。与传统的 32 位或 64 位浮点数相比,FP4 将每个参数压缩至四分之一比特宽度,适用于边缘设备和资源受限环境。文章还讨论了 FP4 在训练与推理中的潜在优势及当前局限性。
John Cook
在古代,浮点数以32位存储。后来,在某个时间点,64位成为标准。C语言保留了这一古老传统,用 float 表示32位浮点数,用 double 表示位数翻倍的浮点数。Python则直接用 float 表示最常见的浮点数格式,即C语言所称的 double。
程序员们很高兴从32位浮点数转向64位浮点数。精度更高总归没有坏处,而且许多数值问题在从32位升级到64位后得以解决。(但并非全部如此。这一点我曾多次提及。)
神经网络带来了意想不到的需求:需要更低精度的浮点数。这些网络拥有海量参数,相比之下,将更多参数塞进内存比追求高精度更重要。开发者不再满足于双精度(64位),转而需要半精度(16位),甚至更低,如FP8(8位)或FP4(4位)。本文将探讨4位浮点数。
既然不需要太高精度,为什么还要用浮点数?为何不用整数?例如,用4位可以表示整数0、1、2……15。也可以引入一个偏置值,比如每个值减7,这样这4位就能表示-7到8。事实证明,具备更宽的动态范围是有用的。
FP4格式的带符号4位浮点数使用第一位表示符号。问题是其余三位该如何分配。记法 ExMy 表示有 x 个指数位和 y 个尾数位。对于带符号的4位数而言:
x + y = 3
但在其他情况下,这个和可能更大。例如,对于8位带符号浮点数,x + y = 7。
对于4位带符号浮点数,有四种可能:E3M0、E2M1、E1M2 和 E0M3。它们都在某些地方被使用,其中 E2M1 最常见,并被NVIDIA硬件支持。
一个具有符号位 s、指数 e 和尾数 m 的数的值为:
(−1)^s × 2^(e−b) × (1 + m/2)
其中 b 是偏置值。偏置的目的在于允许正负指数,而无需对 e 使用带符号的数。例如,若 b = 1 且 e = 1、2 或 3,则指数部分 2^(e−b) 可表示 1、2 或 4。
偏置影响数值的范围,但不影响其相对间距。无论偏置 b 取何值,E3M0 格式全是指数位、无尾数位,因此其可能值在对数尺度上均匀分布。E0M3 格式全是尾数位,其值在**线性**尺度上均匀分布。E1M2 和 E2M1 格式在**对数和线性**尺度上都是不均匀分布的。
当 e = 0 时,上述表达式有一个例外:此时 m = 0 表示 0,m = 1 表示 ½。
数值表
由于只有16种可能的FP4数,因此可以列出全部内容。以下是 E2M1 格式的表。
Bits s exp m Value
-------------------
0000 0 00 0 +0
0001 0 00 1 +0.5
0010 0 01 0 +1
0011 0 01 1 +1.5
0100 0 10 0 +2
0101 0 10 1 +3
0110 0 11 0 +4
0111 0 11 1 +6
1000 1 00 0 -0
1001 1 00 1 -0.5
1010 1 01 0 -1
1011 1 01 1 -1.5
1100 1 10 0 -2
1101 1 10 1 -3
1110 1 11 0 -4
1111 1 11 1 -6注意,即便在这个极小的浮点格式中,也存在两个零:+0 和 -0,就像高精度浮点数一样。更多细节见此处。
Pychop库
Python库Pychop模拟了多种低精度浮点格式。以下是用Pychop生成上述表格的代码。
import pychop
# Pull the format metadata from Pychop.
spec = pychop.MX_FORMATS["mxfp4_e2m1"]
assert (spec.exp_bits, spec.sig_bits) == (2, 1)
def e2m1_value(s: int, e: int, m: int) -> float:
sign = -1.0 if s else 1.0
# Subnormal / zero
if e == 0:
return sign * (m / 2.0)
# Normal
return sign * (2.0 ** (e - 1)) * (1.0 + m / 2.0)
def display_value(bits: int, x: float) -> str:
if bits == 0b0000:
return "+0"
if bits == 0b1000:
return "-0"
return f"{x:+g}"
rows = []
for bits in range(16):
s = (bits >> 3) & 0b1
e = (bits >> 1) & 0b11
m = bits & 0b1
x = e2m1_value(s, e, m)
rows.append(
{
"Bits": f"{bits:04b}",
"s": s,
"exp_bits": f"{e:02b}",
"m": m,
"Value": display_value(bits, x),
}
)
# Pretty-print the table.
header = f"{'Bits':<4} {'s':>1} {'exp':>3} {'m':>1} {'Value':>6}"
print(header)
print("-" * len(header))
for row in rows:
print(
f"{row['Bits']:<4} " f"{row['s']:>1} "
f"{row['exp_bits']:>3} "
f"{row['m']:>1} "
f"{row['Value']:>6}"
)其他格式
FP4 并非唯一的4位浮点格式。实际使用的格式数量出人意料地多。我将在下一篇帖子中介绍另一种格式。
更新:请参见下一篇帖子,其中讨论了 NF4 格式,该格式的表示数字更贴近 LLM 权重的分布。
需要完整排版与评论请前往来源站点阅读。