KV 缓存与注意力 O(n²)
大模型从左到右一个词一个词地生成,每生成一个新词,它都要回头“看”前面所有词——这就是注意力。麻烦在于:序列越长,要看的越多。如果每生成一步都把前面所有词重新算一遍“键(Key)和值(Value)”,那么生成 n 个词的总计算量是 1+2+…+n ≈ n²/2,随长度平方级增长,长上下文因此又慢又贵。KV 缓存的办法很简单:每个词的 K、V 只在它第一次出现时算一次、存起来反复用,后面就不重算了。于是总量降到约 n(线性)。点“生成下一个词”,对比开/关缓存的累计计算量。
下面的三角格子表示“在第几步、给第几个词算 K/V”。蓝=这步新算的,浅青=从缓存直接取的。开缓存时只有对角线要算,关缓存时整片三角都要重算。
注意力 ~ O(n²)
每步都重算前面所有词,生成 n 个词总量约 n²/2,随长度平方增长——长上下文贵在这里。
KV 缓存 ~ O(n)
每个词的 K/V 只算一次存起来,后续直接取,总量降到线性,生成快得多。
代价是显存
缓存要占显存,序列越长占得越多——这也是长上下文的另一道坎。