OpenAI APIs - Embedding#

SGLang provides OpenAI-compatible APIs to enable a smooth transition from OpenAI services to self-hosted local models. A complete reference for the API is available in the OpenAI API Reference.

This tutorial covers the embedding APIs for embedding models, such as Alibaba-NLP/gte-Qwen2-7B-instruct.

Launch a Server#

Launch the server in your terminal and wait for it to initialize. Remember to add --is-embedding to the command.

[1]:
from sglang.utils import (
    execute_shell_command,
    wait_for_server,
    terminate_process,
    print_highlight,
)

embedding_process = execute_shell_command(
    """
python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \
    --port 30000 --host 0.0.0.0 --is-embedding
"""
)

wait_for_server("http://localhost:30000")
[2025-01-03 02:31:02] server_args=ServerArgs(model_path='Alibaba-NLP/gte-Qwen2-7B-instruct', tokenizer_path='Alibaba-NLP/gte-Qwen2-7B-instruct', tokenizer_mode='auto', load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, context_length=None, device='cuda', served_model_name='Alibaba-NLP/gte-Qwen2-7B-instruct', chat_template=None, is_embedding=True, revision=None, skip_tokenizer_init=False, return_token_ids=False, host='0.0.0.0', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=False, tp_size=1, stream_interval=1, random_seed=606729663, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_draft_model_path=None, speculative_algorithm=None, speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)
[2025-01-03 02:31:07] Downcasting torch.float32 to torch.float16.
[2025-01-03 02:31:15 TP0] Downcasting torch.float32 to torch.float16.
[2025-01-03 02:31:15 TP0] Overlap scheduler is disabled for embedding models.
[2025-01-03 02:31:15 TP0] Downcasting torch.float32 to torch.float16.
[2025-01-03 02:31:15 TP0] Init torch distributed begin.
[2025-01-03 02:31:16 TP0] Load weight begin. avail mem=78.81 GB
[2025-01-03 02:31:16 TP0] Ignore import error when loading sglang.srt.models.grok. unsupported operand type(s) for |: 'type' and 'NoneType'
[2025-01-03 02:31:16 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/7 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  14% Completed | 1/7 [00:00<00:04,  1.31it/s]
Loading safetensors checkpoint shards:  29% Completed | 2/7 [00:01<00:05,  1.04s/it]
Loading safetensors checkpoint shards:  43% Completed | 3/7 [00:03<00:05,  1.36s/it]
Loading safetensors checkpoint shards:  57% Completed | 4/7 [00:05<00:04,  1.55s/it]
Loading safetensors checkpoint shards:  71% Completed | 5/7 [00:07<00:03,  1.65s/it]
Loading safetensors checkpoint shards:  86% Completed | 6/7 [00:09<00:01,  1.68s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:11<00:00,  1.75s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:11<00:00,  1.58s/it]

[2025-01-03 02:31:28 TP0] Load weight end. type=Qwen2ForCausalLM, dtype=torch.float16, avail mem=64.40 GB
[2025-01-03 02:31:28 TP0] Memory pool end. avail mem=7.42 GB
[2025-01-03 02:31:28 TP0] max_total_num_tokens=1028801, max_prefill_tokens=16384, max_running_requests=4019, context_len=131072
[2025-01-03 02:31:28] INFO:     Started server process [4034424]
[2025-01-03 02:31:28] INFO:     Waiting for application startup.
[2025-01-03 02:31:28] INFO:     Application startup complete.
[2025-01-03 02:31:28] INFO:     Uvicorn running on http://0.0.0.0:30000 (Press CTRL+C to quit)
[2025-01-03 02:31:29] INFO:     127.0.0.1:43060 - "GET /v1/models HTTP/1.1" 200 OK
[2025-01-03 02:31:29] INFO:     127.0.0.1:43076 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-01-03 02:31:29 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-03 02:31:30] INFO:     127.0.0.1:43086 - "POST /encode HTTP/1.1" 200 OK
[2025-01-03 02:31:30] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.

Using cURL#

[2]:
import subprocess, json

text = "Once upon a time"

curl_text = f"""curl -s http://localhost:30000/v1/embeddings \
  -d '{{"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "input": "{text}"}}'"""

text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))["data"][0][
    "embedding"
]

print_highlight(f"Text embedding (first 10): {text_embedding[:10]}")
[2025-01-03 02:31:34 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-03 02:31:34] INFO:     127.0.0.1:43088 - "POST /v1/embeddings HTTP/1.1" 200 OK
Text embedding (first 10): [0.00830841064453125, 0.0006804466247558594, -0.00807952880859375, -0.000682830810546875, 0.01438140869140625, -0.009002685546875, 0.01239013671875, 0.0020999908447265625, 0.006214141845703125, -0.0030345916748046875]
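
The input field also accepts a list of strings, so several texts can be embedded in a single request. A minimal sketch, assuming the server follows the OpenAI embeddings schema for batched input (the second text is illustrative):

curl_batch = """curl -s http://localhost:30000/v1/embeddings \
  -d '{"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "input": ["Once upon a time", "Long, long ago"]}'"""

# One entry in "data" per input text, returned in the same order.
batch_data = json.loads(subprocess.check_output(curl_batch, shell=True))["data"]
print_highlight(f"Number of embeddings: {len(batch_data)}")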

Using Python Requests#

[3]:
import requests

text = "Once upon a time"

response = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "input": text},
)

text_embedding = response.json()["data"][0]["embedding"]

print_highlight(f"Text embedding (first 10): {text_embedding[:10]}")
[2025-01-03 02:31:34 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 3, cache hit rate: 21.43%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-03 02:31:34] INFO:     127.0.0.1:43098 - "POST /v1/embeddings HTTP/1.1" 200 OK
Text embedding (first 10): [0.00830078125, 0.0006747245788574219, -0.00807952880859375, -0.000682830810546875, 0.01438140869140625, -0.009002685546875, 0.01239013671875, 0.0020961761474609375, 0.006198883056640625, -0.003025054931640625]
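
Embeddings from the same model can be compared directly, typically with cosine similarity. The following sketch is not part of the original tutorial; it reuses the /v1/embeddings endpoint above, and the helper names are its own:

import math
import requests

def get_embedding(text):
    # Same endpoint and model as the example above.
    resp = requests.post(
        "http://localhost:30000/v1/embeddings",
        json={"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "input": text},
    )
    return resp.json()["data"][0]["embedding"]

def cosine_similarity(a, b):
    # Dot product divided by the product of the vector norms.
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b)))

sim = cosine_similarity(get_embedding("Once upon a time"), get_embedding("Long, long ago"))
print_highlight(f"Cosine similarity: {sim:.4f}")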

Using the OpenAI Python Client#

[4]:
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

# Text embedding example
response = client.embeddings.create(
    model="Alibaba-NLP/gte-Qwen2-7B-instruct",
    input=text,
)

embedding = response.data[0].embedding[:10]
print_highlight(f"Text embedding (first 10): {embedding}")
[2025-01-03 02:31:34 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 3, cache hit rate: 33.33%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-03 02:31:34] INFO:     127.0.0.1:43102 - "POST /v1/embeddings HTTP/1.1" 200 OK
Text embedding (first 10): [0.00830078125, 0.0006747245788574219, -0.00807952880859375, -0.000682830810546875, 0.01438140869140625, -0.009002685546875, 0.01239013671875, 0.0020961761474609375, 0.006198883056640625, -0.003025054931640625]
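
The OpenAI client also accepts a list of texts; each entry in response.data carries an index matching its position in the request. A short sketch along those lines (the second text is illustrative):

# Batch embedding example: embed several texts in one call.
batch_response = client.embeddings.create(
    model="Alibaba-NLP/gte-Qwen2-7B-instruct",
    input=["Once upon a time", "Long, long ago"],
)

for item in batch_response.data:
    print_highlight(f"Input {item.index}, embedding (first 3): {item.embedding[:3]}")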

Using Input IDs#

SGLang also supports input_ids as input to get the embedding.

[5]:
import json
import os
from transformers import AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"

tokenizer = AutoTokenizer.from_pretrained("Alibaba-NLP/gte-Qwen2-7B-instruct")
input_ids = tokenizer.encode(text)

curl_ids = f"""curl -s http://localhost:30000/v1/embeddings \
  -d '{{"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "input": {json.dumps(input_ids)}}}'"""

input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))["data"][
    0
]["embedding"]

print_highlight(f"Input IDs embedding (first 10): {input_ids_embedding[:10]}")
[2025-01-03 02:31:41 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 3, cache hit rate: 40.91%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-03 02:31:41] INFO:     127.0.0.1:35590 - "POST /v1/embeddings HTTP/1.1" 200 OK
Input IDs embedding (first 10): [0.00830078125, 0.0006747245788574219, -0.00807952880859375, -0.000682830810546875, 0.01438140869140625, -0.009002685546875, 0.01239013671875, 0.0020961761474609375, 0.006198883056640625, -0.003025054931640625]
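The same token-ID input also works from the Python clients. A sketch mirroring the cURL call above with requests:

response = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "input": input_ids},
)

input_ids_embedding = response.json()["data"][0]["embedding"]
print_highlight(f"Input IDs embedding (first 10): {input_ids_embedding[:10]}")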
[6]:
terminate_process(embedding_process)