推测性采样

关于推测性采样

推测采样(也称为推测解码)是一组旨在允许每次前向传递迭代生成多个令牌的技术。这可以在由于批量较小而导致GPU未充分利用的情况下减少每个令牌的平均延迟。

推测性采样涉及预测一系列未来的标记,称为草稿标记,使用一种比重复执行目标大型语言模型(LLM)更高效的方法。然后,这些草稿标记通过目标LLM在一次前向传递中集体验证。其基本假设有两个:

  1. 同时处理多个草稿令牌将像处理单个令牌一样迅速

  2. 在完整生成过程中,多个草稿令牌将被成功验证

如果第一个假设成立,推测解码的延迟不会比标准方法差。如果第二个假设成立,输出令牌的生成在每次前向传递中统计上会超过一个令牌。这两个假设的结合使得推测解码能够减少延迟。

TensorRT-LLM 支持几种生成草稿令牌的方法,包括:

  1. 利用一个较小的辅助模型,称为草稿模型方法。更多信息,请参考Fast Inference from Transformers via Speculative Decoding paper

  2. 实现额外的语言模型头,用于预测未来位置的标记:

    1. Medusa: 具有多重解码头的简单LLM推理加速框架论文.

    2. 用于大型语言模型中快速推测解码的循环草稿器

    3. EAGLE: 推测性采样需要重新思考特征不确定性.

  3. 利用提示令牌作为草稿令牌。更多信息,请参考Prompt Lookup Decoding

  4. 利用类似Jacobi的解码方法来预测和验证草稿令牌,使用相同的模型,无需额外的微调。请参考Break the Sequential Dependency of LLM Inference Using Lookahead Decoding

性能改进

值得注意的是,推测解码技术的有效性高度依赖于手头的具体任务。例如,在代码补全场景中预测后续标记可能比生成文章摘要更简单。

此外,当将Medusa与标准的PyTorch模型实现集成时,由于这些实现可能不如TensorRT-LLM那样精细调整,潜在的时间节省更为显著。

草稿-目标-模型

草案-目标-模型涉及使用两个独立训练但共享相同词汇的模型:一个较小的草案模型和一个较大的目标模型。例如,GPT 125M / 6.7B 模型可以作为草案/目标模型。

目前,在TensorRT-LLM中使用Draft-Target-Model有两种方式。第一种是在Triton中使用TensorRT-LLM-BLS,更多信息和详细步骤可以在本文档中找到。第二种是直接在TensorRT-LLM中使用,步骤可以在examples/draft_target_model/README.md中找到,代码可以在examples/prompt_lookup/run_dtm_pld.py中找到。

草稿和目标模型的管理通过两个独立的GptManager实例进行。 有效地协调草稿和目标模型之间的交互是至关重要的。 最初,查询草稿模型以生成最多K个草稿令牌。 然后,这些令牌被转发到目标模型进行验证。 验证后,目标模型可能会返回最多K+1个令牌。 随后,更新了已接受令牌的提示被发送回草稿模型,以启动新草稿令牌的生成。 这个迭代过程会持续进行,直到满足预定义的停止条件。 可以在TensorRT-LLM Triton后端中找到这种编排过程的示例。

在Inflight Fused Batching (IFB)框架中配置和执行Draft模型的过程与IFB中的任何其他模型相同。maxNewTokens参数应设置为Draft模型查询的LlmRequest中的草稿令牌数量。

在构建目标模型时,需要为trtllm-build命令指定--max_draft_len --speculative_decoding_mode draft_tokens_external选项。 在IFB中的目标模型推理阶段,maxNewTokens应设置为1, 并且必须在目标模型查询的LlmRequestdraftTokens字段中设置草稿令牌。

注意: 为了提高性能,特别是由于对Draft和Target模型的重复查询请求共享一个共同的前缀, 建议为这两个模型启用KV缓存重用。 这可以通过在trtllm-build命令中添加--use_paged_context_fmha=enable标志 并在KVCacheConfig中设置enableBlockReuse=true来实现。

使用Draft-Target-Model方法与Triton推理服务器

此示例仅适用于Draft-Target-Model模型方法。对于所有其他推测性解码模型,您可以像标准非推测性自回归模型一样在Triton服务器中部署它们。

  • 自TensorRT-LLM-0.7.0起支持草稿模型方法(使用两个独立的Tritonserver分别维护草稿和目标模型),但在TensorRT-LLM-0.10.0中进行了显著优化(使用一个Tritonserver与业务逻辑脚本,BLS)。

  • 带有BLS的Draft模型的源文件可以在这里找到here

  • 此示例基于 TensorRT-LLM-0.10.0 和 TRTLLM-backend-0.10.0,使用 docker 镜像 nvcr.io/nvidia/tritonserver:24.05-trtllm-py3

  • 在这个例子中,Llama-7B-hf 和 Llama-30B-hf 分别用作草稿和目标模型,假设模型仓库的路径是 DRAFT_MODEL_PATHTARGET_MODEL_PATH

  • 在此示例中,草稿令牌的最大数量设置为10。

  1. 为推理准备TensorRT引擎

    • 以下是构建FP16或FP8的草稿/目标引擎的命令。支持所有数据类型(草稿-FP16/FP8 + 目标-FP16/FP8)的组合。

    • --remove_input_padding=enable --paged_kv_cache=enable 是 inflight-batching 所必需的。

    • --context_fmha=enable --use_paged_context_fmha=enable 是可选的,但建议使用以提高性能。

    • --gather_generation_logits 如果使用生成logits来选择目标模型中的token,则是必要的。

    • --tp_size 可以在使用TP模式时修改设置,适用于草稿/目标模型。

    • --max_batch_size 在一般情况下可以接受大于1的值,但在此示例中我们使用1。

    export MAX_DRAFT_LENGTH=10
    export COMMON_COMMAND="--max_batch_size=1 --max_input_len=2048 --max_seq_len=3072 --gpt_attention_plugin=float16 --gemm_plugin=float16 --remove_input_padding=enable --paged_kv_cache=enable --context_fmha=enable --use_paged_context_fmha=enable --gather_generation_logits"
    export DRAFT_COMMAND_FP16="$COMMON_COMMAND"
    export TARGET_COMMAND_FP16="$DRAFT_COMMAND_FP16 --max_draft_len=$MAX_DRAFT_LENGTH --speculative_decoding_mode draft_tokens_external"
    export DRAFT_COMMAND_FP8="$COMMON_COMMAND --use_fp8_context_fmha=enable"
    export TARGET_COMMAND_FP8="$DRAFT_COMMAND_FP8 --max_draft_len=$MAX_DRAFT_LENGTH --speculative_decoding_mode draft_tokens_external"
    
    # Build checkpoints and engines in tensorrt_llm/examples/llama/
    # FP16 mode
    export DRAFT_NAME=llama-7b-fp16-tp1
    export TARGET_NAME=llama-30b-fp16-tp1
    python3 convert_checkpoint.py --model_dir=$DRAFT_MODEL_PATH --output_dir=ckpt/$DRAFT_NAME --tp_size=1
    python3 convert_checkpoint.py --model_dir=$TARGET_MODEL_PATH --output_dir=ckpt/$TARGET_NAME --tp_size=1
    trtllm-build --checkpoint_dir=ckpt/$DRAFT_NAME --output_dir=engine/draft/$DRAFT_NAME $DRAFT_COMMAND_FP16
    trtllm-build --checkpoint_dir=ckpt/$TARGET_NAME --output_dir=engine/target/$TARGET_NAME $TARGET_COMMAND_FP16
    export DRAFT_ENGINE_PATH=$(pwd)/engine/draft/$DRAFT_NAME
    export TARGET_ENGINE_PATH=$(pwd)/engine/target/$TARGET_NAME
    
    # FP8 mode
    export DRAFT_NAME=llama-7b-fp8-tp1
    export TARGET_NAME=llama-30b-fp8-tp1
    python3 ../quantization/quantize.py --model_dir=$DRAFT_MODEL_PATH --dtype float16 --qformat fp8 --kv_cache_dtype fp8 --output_dir=ckpt/$DRAFT_NAME --tp_size=1
    python3 ../quantization/quantize.py --model_dir=$TARGET_MODEL_PATH --dtype float16 --qformat fp8 --kv_cache_dtype fp8 --output_dir=ckpt/$TARGET_NAME --tp_size=1
    trtllm-build --checkpoint_dir=ckpt/$DRAFT_NAME --output_dir=engine/draft/$DRAFT_NAME $DRAFT_COMMAND_FP8
    trtllm-build --checkpoint_dir=ckpt/$TARGET_NAME --output_dir=engine/target/$TARGET_NAME $TARGET_COMMAND_FP8
    export DRAFT_ENGINE_PATH=$(pwd)/engine/draft/$DRAFT_NAME
    export TARGET_ENGINE_PATH=$(pwd)/engine/target/$TARGET_NAME
    
  2. 编辑Triton配置

    • 如果草稿模型和目标模型可以放在一个GPU中(例如,llama-7B-FP8 + llama-30B-FP8,总共40GiB在一个H100-80GiB GPU中),DRAFT_GPU_DEVICE_IDSTARGET_GPU_DEVICE_IDS 可以是相同的,例如 0。这比放在两个独立的GPU上表现更好。

    • 否则,草稿模型和目标模型可以放置在不同的GPU中,例如DRAFT_GPU_DEVICE_IDS="0"TARGET_GPU_DEVICE_IDS="1"

    • 此外,如果使用TP模式,GPU_DEVICE_IDS的值可以是一个列表,例如DRAFT_GPU_DEVICE_IDS="0"TARGET_GPU_DEVICE_IDS="1,2,3,4"

    • 有关使用Tritonserver启动模型的更多配置,请访问TensorRT-LLM Backed repo

    ACCUMULATE_TOKEN="false"
    BACKEND="tensorrtllm"
    BATCH_SCHEDULER_POLICY="guaranteed_no_evict"
    BATCHING_STRATEGY="inflight_fused_batching"
    BLS_INSTANCE_COUNT="1"
    DECODING_MODE="top_k_top_p"
    DECOUPLED_MODE="False"
    DRAFT_GPU_DEVICE_IDS="0"
    E2E_MODEL_NAME="ensemble"
    ENABLE_KV_CACHE_REUSE="true"
    ENGINE_PATH=$TARGET_ENGINE_PATH
    EXCLUDE_INPUT_IN_OUTPUT="false"
    KV_CACHE_FREE_GPU_MEM_FRACTION="0.8"
    MAX_ATTENTION_WINDOW_SIZE=""
    MAX_BEAM_WIDTH="1"
    MAX_QUEUE_DELAY_MICROSECONDS="0"
    MAX_TOKENS_IN_KV_CACHE=""
    NORMALIZE_LOG_PROBS="true"
    POSTPROCESSING_INSTANCE_COUNT="1"
    PREPROCESSING_INSTANCE_COUNT="1"
    TARGET_GPU_DEVICE_IDS="1"
    TENSORRT_LLM_DRAFT_MODEL_NAME="tensorrt_llm_draft"
    TENSORRT_LLM_MODEL_NAME="tensorrt_llm"
    TOKENIZER_PATH=$DRAFT_MODEL_PATH
    TOKENIZER_TYPE=llama
    TRITON_GRPC_PORT="8001"
    TRITON_HTTP_PORT="8000"
    TRITON_MAX_BATCH_SIZE="4"
    TRITON_METRICS_PORT="8002"
    TRITON_REPO="triton_repo"
    USE_DRAFT_LOGITS="false"
    
    # Make a copy of triton repo and replace the fields in the configuration files
    cd /tensorrtllm_backend/
    apt-get update && apt-get install -y build-essential cmake git-lfs
    pip3 install git-lfs tritonclient grpcio
    rm -rf ${TRITON_REPO}
    cp -R all_models/inflight_batcher_llm ${TRITON_REPO}
    python3 tools/fill_template.py -i ${TRITON_REPO}/ensemble/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE}
    python3 tools/fill_template.py -i ${TRITON_REPO}/preprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_PATH},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},preprocessing_instance_count:${PREPROCESSING_INSTANCE_COUNT}
    python3 tools/fill_template.py -i ${TRITON_REPO}/postprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_PATH},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},postprocessing_instance_count:${POSTPROCESSING_INSTANCE_COUNT}
    python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},accumulate_tokens:${ACCUMULATE_TOKEN},bls_instance_count:${BLS_INSTANCE_COUNT},tensorrt_llm_model_name:${TENSORRT_LLM_MODEL_NAME},tensorrt_llm_draft_model_name:${TENSORRT_LLM_DRAFT_MODEL_NAME}
    
    # Make a copy of tensorrt_llm as configurations of draft / target models.
    cp -R ${TRITON_REPO}/tensorrt_llm ${TRITON_REPO}/tensorrt_llm_draft
    sed -i 's/name: "tensorrt_llm"/name: "tensorrt_llm_draft"/g' ${TRITON_REPO}/tensorrt_llm_draft/config.pbtxt
    python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm/config.pbtxt          triton_backend:${BACKEND},engine_dir:${ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${TARGET_GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},encoder_input_features_data_type:TYPE_FP16
    python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_draft/config.pbtxt    triton_backend:${BACKEND},engine_dir:${DRAFT_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${DRAFT_GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE}
    
  3. 启动 Triton 服务器

    • --multi-model 如果目标模型使用TP模式,则是必要的。

    python3 scripts/launch_triton_server.py \
        --model_repo=${TRITON_REPO} \
        --tensorrt_llm_model_name "tensorrt_llm,tensorrt_llm_draft" \
        --multi-model \
        --log &
    
    • 详细日志将被写入文件 triton_log.txt。如果您在文件中看到以下输出,则表示 Triton 服务器已成功启动:

    Started HTTPService at 0.0.0.0:8000
    Started GRPCInferenceService at 0.0.0.0:8001
    Started Metrics Service at 0.0.0.0:8002
    
  4. 发送请求

    • 准备一个包含输入数据的JSON文件 input_data.json,如下所示(更多的请求是可以接受的)。

    [
        {
            "input": "James Best, best known for his ",
            "instruction": "Continue writing the following story:",
            "output": "                                                                "
        }
    ]
    
    • 使用以下命令启动推理请求。

    • --num-draft-tokens 可以根据运行时草稿长度进行修改,本示例中使用的是4。

    python3 tools/inflight_batcher_llm/speculative_decoding_test.py \
        --max-input-len 2048 \
        --dataset=input_data.json \
        --url-target=localhost:8001 \
        --url-draft=localhost:8001 \
        --draft-tensorrt-llm-model-name="${TENSORRT_LLM_DRAFT_MODEL_NAME}" \
        --target-tensorrt-llm-model-name="${TENSORRT_LLM_MODEL_NAME}" \
        --bls-speculative-tensorrt-llm-model-name="tensorrt_llm_bls" \
        --execute-bls-speculative-decoding \
        --disable-output-comparison \
        --num-draft-tokens=4 \
        --verbose
    
  5. "use_draft_logits": True时,启用快速logits D2D传输

    • 从草稿logits获取调整后的logits分布是Fast Inference from Transformers via Speculative Decoding paper中提出的一种方法。快速logits功能通过隐藏从草稿引擎到目标引擎的logits传输延迟来提高性能(TPS)。

    • TensorRT-LLM-0.15.0 新支持了快速 logits 功能。

    • 修改participant_ids条目在tensorrt_llm/config.pbtxttensorrt_llm_draft/config.pbtxt中为合适的MPI等级。通常在这种设置中,等级0保留给协调器等级;等级1用于草稿引擎;其余等级用于目标引擎。在这个例子中,particpant_ids可以设置为如下片段。同样的逻辑也适用于TP>1的目标引擎。

    ### In tensorrt_llm_draft/config.pbtxt
    parameters: {
        key: "gpu_device_ids"
        value: {
            string_value: "0"
        }
    }
    parameters: {
        key: "participant_ids"
        value: {
            string_value: "1"
        }
    }
    ### In tensorrt_llm/config.pbtxt
    parameters: {
        key: "gpu_device_ids"
        value: {
            string_value: "1"
        }
    }
    parameters: {
        key: "participant_ids"
        value: {
            string_value: "2"
        }
    }
    
    • tensorrt_llm/config.pbtxttensorrt_llm_draft/config.pbtxt 中启用 speculative_decoding_fast_logits

    parameters: {
        key: "speculative_decoding_fast_logits"
        value: {
            string_value: "1"
        }
    }
    
    • 快速logits功能要求Tritonserver在协调器模式下启动,并使用--disable-spawn-process。有关更多信息,请参阅模型配置--world_size必须设置为1(协调器等级0)+ 1(草案引擎等级)+ 1(目标引擎等级)。

    python3 scripts/launch_triton_server.py \
        --model_repo=$TRITON_REPO \
        --tensorrt_llm_model_name "tensorrt_llm,tensorrt_llm_draft" \
        --multi-model \
        --disable-spawn-processes \
        --world_size=3 --log &
    
    • 发送请求到tritonserver BLS API时使用use_draft_logits

    curl -X POST "http://localhost:8000/v2/models/tensorrt_llm_bls/generate" \
        -H "Content-Type: application/json" \
        -d '{
            "text_input": "Continue writing the following story: James Best, best known for his",
            "max_tokens": 128,
            "num_draft_tokens": 10,
            "use_draft_logits": true,
            "stream": false
            }'
    
    • 启用快速logits并遵循模型配置中的优化建议后,与使用Llama 3.2 1B草稿和Llama 3.1 70B目标的自动回归解码相比,使用草稿logits的推测解码在BS1中实现了2.x的吞吐量,在BS16中实现了1.x的吞吐量。

  6. 在完成推理后杀死Tritonserver

    pkill -9 -f trtllmExecutorWorker
    pkill -9 -f tritonserver
    

提示-查找-解码

请参阅文档在examples/prompt_lookup/README.md,代码可以在examples/prompt_lookup/run_dtm_pld.py中找到。

美杜莎

这种方法利用单一模型来生成和验证草稿令牌。 它通过添加多个额外的语言模型头(称为Medusa头)来增强现有模型。 这些额外的头被训练来预测未来的令牌,而基础模型保持不变。 具体来说,第一个Medusa头负责预测紧接着的下一个令牌, 第二个头预测之后的令牌,依此类推。 使用K个Medusa头,模型可以预测多达K个令牌。 在迭代i期间由Medusa头生成的草稿令牌 然后在后续迭代i+1中被验证并可能被接受。

当每个头部使用多个令牌时,Medusa策略的真正潜力得以实现,采用TopK方法创建多个潜在路径,基本上形成一棵树,而不是像Draft模型方法中看到的单一线性路径。为了减少冗余计算,许多这些路径(通常共享共同的前缀)被合并为单一路径。这是通过应用带有表示各种路径的稀疏掩码的注意力机制来实现的。由Medusa树形成的稀疏掩码将在后面详细描述。

通过同时验证多条路径,尽管以额外的计算努力为代价,每次迭代接受多个令牌的可能性增加了。

重要的是要认识到,随着潜在路径的数量随着K呈指数增长,没有必要探索或验证所有路径。管理这种复杂性的推荐策略是通过仅关注具有较高概率标记的路径来修剪树。

你必须在你想探索的树的广度和深度与更大的树对你特定应用整体性能的影响之间找到平衡。

在Medusa的TensorRT-LLM实现中,树的配置是一个运行时参数。 这种灵活性允许您进行实验并确定适合您用例的最佳树结构, 然后可以在生产环境中使用。

美杜莎树

考虑以下图表,它说明了基础模型最后一层的隐藏状态如何传递到基础模型的语言模型(LM)头部和四个Medusa头部(MHs)。

Example Medusa Tree

在这个例子中:

  1. 令牌 l0 表示模型生成的实际令牌。 所有其他令牌,表示为 phk,是来自MHs的预测, 其中 h 表示Medusa头索引(基于1)和 k 表示TopK选择索引(基于0)。

  2. 使用了四个MH,这意味着模型正在预测四个未来的标记。

  3. 前两个MH使用Top-2预测,而后两个使用Top-1。 例如,p10p11 是第一个Medusa Head(MH1)的最高和第二高预测。

  4. 总共探索了四条路径,这比使用完整二叉树(假设所有MH的Top-2预测)将检查的16条路径要少。

  5. 由于其中一些路径可能被接受,有十个潜在的候选者,称为medusa_choices。 每一步可以接受的令牌数量,包括真实令牌, 范围从1(如果所有Medusa预测都不正确)到5(如果所有预测都正确)。

在生成阶段,模型接收10个标记的输入,这些标记对应于每个候选路径的最后标记,而不仅仅是一个。

在TensorRT-LLM中,您可以选择通过提供所有Medusa选项或仅指定唯一路径来定义此类树。

  • 由于每个候选/路径都以真实标记(l0)开头,因此无需单独指定。对于预测的标记,只需要TopK索引。

  • 例如,要指定路径 l0p10p21p30, 可以使用 [0,1,0]。而要指定路径 l0p11p20, 可以使用 [1,0]

  • 要在示例中指定所有4条路径,请使用medusa_choices=[[0,0,0,0], [0,1,0], [1,0], [1,1]]

  • 也可以像Medusa仓库那样明确指定所有候选。例如,medusa_choices=[[0], [0,0], [0,0,0], [0,0,0,0], [0,1], [0,1,0], [1], [1,0], [1,1]]。请注意,当明确指定所有候选时,我们不包含空的[]候选,这种情况只接受真实标记,即所有来自MHs的预测都是错误的。因此,只指定了9个候选。

目前仅在Python运行时支持指定仅路径而不是所有选择。

使用Medusa与TensorRT-LLM

有关使用Python运行时构建和执行Medusa的指导,请参阅Medusa README。当使用C++ API的Inflight Fused Batching (IFB)时,必须在模型配置中明确定义medusa_choices。有关详细说明,请参阅TensorRT-LLM后端中的模型配置以获取更多详细信息。

限制

  • TensorRT-LLM 仅支持 Vicuna(微调的 LLaMA)的 Medusa。 然而,与任何新模型类似,您可以遵循相同的方法定义自己的 Medusa 模型并使用 TensorRT-LLM 进行部署。

  • 我们仅在验证阶段匹配标记,即medusa_temperature=0

  • Beam搜索与Medusa兼容。

ReDrafter

ReDrafter方法通过使用同一模型预测和验证令牌来增强单模型Medusa方法。然而,与Medusa不同,它使用循环预测器预测草稿令牌,其中每个草稿令牌依赖于前一个。此方法还允许使用束搜索来识别更突出的草稿令牌。更多详情,请阅读ReDrafter论文

TensorRT-LLM 实现了 ReDrafter 模型,使得 logits 预测、beam search 和 draft token 接受都在 TensorRT 引擎内部执行。这与标准模型推理不同,后者仅在引擎外部预测 logits 并执行解码。由于引擎预测的是显式的 draft tokens 而不是从 logits 解码的隐式 tokens,我们将这种推测解码方法分类为 explicit_draft_tokens。请访问 ReDrafter README 获取有关构建和运行模型的信息。ReDrafter 支持 Inflight Fused Batching 运行时和 Python 静态批处理运行时。

EAGLE

EAGLE方法通过使用相同的模型预测和验证令牌来增强单模型Medusa方法。与ReDrafter类似,它使用循环预测器预测草稿令牌,其中每个草稿令牌依赖于前一个令牌。然而,与ReDrafter不同的是,它使用单层变压器模型从先前的隐藏状态和解码的令牌中预测草稿令牌。在EAGLE-1中,解码树需要在解码期间已知。在EAGLE-2中,该树通过沿波束搜索最可能的假设在执行期间组装。

与ReDrafter类似,TensorRT-LLM实现了EAGLE模型,使得logits预测、草稿令牌接受和草稿令牌生成都在TensorRT引擎内部执行。仅支持EAGLE-1。请访问EAGLE README以获取有关构建和运行模型的信息。

前瞻解码

前瞻解码算法通过同一模型内的两个并行计算分支进行操作:一个前瞻分支使用固定大小的2D窗口生成n-gram,另一个验证分支验证有潜力的n-gram候选。这种方法消除了额外模型训练或微调的必要性,并且可以应用于任何自回归模型。有关构建和运行模型的信息,请参阅前瞻解码README