Shortcuts

KVCache

class torchtune.modules.KVCache(batch_size: int, max_seq_len: int, num_kv_heads: int, head_dim: int, dtype: dtype)[source]

独立的 nn.Module 包含一个 kv-cache,用于在推理过程中缓存过去的键和值。

Parameters:
  • batch_size (int) – 模型将运行的批量大小

  • max_seq_len (int) – 模型将运行的最大序列长度

  • num_kv_heads (int) – 键/值头的数量。

  • head_dim (int) – 每个注意力头的嵌入维度

  • dtype (torch.dpython:type) – 缓存的dtype

reset() None[source]

将缓存重置为零。

update(k_val: Tensor, v_val: Tensor) Tuple[Tensor, Tensor][source]

使用新的k_valv_val更新KV缓存,并返回更新后的缓存。

注意

在更新KV缓存时,假设后续更新应按连续序列位置更新键值位置。如果您希望更新已填充的缓存值,请使用.reset(),这将把缓存重置到第零个位置。

示例

>>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
>>> cache.update(keys, values)
>>> # now positions 0 through 7 are filled
>>> cache.size
>>> 8
>>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32))
>>> cache.update(keys, values)
>>> # this will fill at position 8
>>> cache.size
>>> 9
Parameters:
  • k_val (torch.Tensor) – 当前键张量,形状为 [B, H, S, D]

  • v_val (torch.Tensor) – 当前值张量,形状为 [B, H, S, D]

Returns:

分别更新了键和值的缓存张量。

Return type:

元组[torch.Tensor, torch.Tensor]

Raises:
  • AssertionError – 如果k_val的序列长度超过了最大缓存序列长度。

  • ValueError – 如果新键(或值)张量的批量大小大于缓存设置期间使用的批量大小。