你正在OpenAI进行AI工程师面试。 面试官问: “我们的GPT模型在42秒内生成100个token。 你如何让它快5倍?” 你: “我会分配更多的GPU以加快生成速度。” 面试结束。 你错过了什么:
真正的瓶颈不是计算,而是冗余计算。 没有 KV 缓存,你的模型会为每个 token 重新计算键和值,重复工作。 - 使用 KV 缓存 → 9 秒 - 不使用 KV 缓存 → 42 秒(约慢 5 倍) 让我们深入了解它是如何工作的!
要理解KV缓存,我们必须知道LLM是如何输出令牌的。 - Transformer为所有令牌生成隐藏状态。 - 隐藏状态被投影到词汇空间。 - 最后一个令牌的Logits用于生成下一个令牌。 - 对后续令牌重复此过程。 查看这个👇
因此,要生成一个新令牌,我们只需要最近令牌的隐藏状态。 不需要其他任何隐藏状态。 接下来,让我们看看最后的隐藏状态是如何在变换器层中通过注意力机制计算的。
在注意力机制中: 查询-键-值的最后一行涉及: - 最后一个查询向量。 - 所有键向量。 此外,最终注意力结果的最后一行涉及: - 最后一个查询向量。 - 所有键和值向量。 查看这个视觉图以更好地理解:
上述见解表明,要生成一个新令牌,网络中的每个注意力操作只需要: - 最后一个令牌的查询向量。 - 所有的键值向量。 但是,这里还有一个更重要的见解。
当我们生成新代币时: - 所有先前代币使用的KV向量都不会改变。 因此,我们只需要为前一步生成的代币生成一个KV向量。 其余的KV向量可以从缓存中检索,以节省计算和时间。
这被称为 KV 缓存! 重申一下,缓存所有上下文标记的 KV 向量,而不是冗余地计算它们。 生成一个标记: - 为前一步生成的标记生成 QKV 向量。 - 从缓存中获取所有其他 KV 向量。 - 计算注意力。 查看这个👇
KV 缓存通过在生成令牌之前计算提示的 KV 缓存来加速推理。 这正是 ChatGPT 生成第一个令牌所需时间比生成其余令牌更长的原因。 这个延迟被称为首次令牌时间(TTFT)。 改善 TTFT 是另一个话题!
69.16K