
Python Client

The LoRAX Python client provides a convenient way to interact with a LoRAX instance running in your environment.

Installation

pip install lorax-client

Usage

from lorax import Client

endpoint_url = "http://127.0.0.1:8080"

client = Client(endpoint_url)
text = client.generate("Why is the sky blue?", adapter_id="some/adapter").generated_text
print(text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
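
generate also accepts per-request decoding parameters. A minimal sketch, assuming max_new_tokens (which appears in the batching example below) and a temperature parameter on the generate signature; consult the API reference for the full list of supported options:

from lorax import Client

client = Client("http://127.0.0.1:8080")

# Per-request decoding controls (max_new_tokens is used later in this
# guide; temperature is assumed here -- check the API reference)
response = client.generate(
    "Why is the sky blue?",
    adapter_id="some/adapter",
    max_new_tokens=64,
    temperature=0.7,
)
print(response.generated_text)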

Or you can use the asynchronous client:

from lorax import AsyncClient

endpoint_url = "http://127.0.0.1:8080"

client = AsyncClient(endpoint_url)
response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
print(response.generated_text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
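
The await and async for calls above assume an already-running event loop, as in a Jupyter notebook. In a standalone script, wrap the calls in a coroutine and drive it with asyncio.run:

import asyncio
from lorax import AsyncClient

async def main():
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
    print(response.generated_text)

# asyncio.run starts an event loop and runs the coroutine to completion
asyncio.run(main())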

See the API reference for full details.
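
Beyond generated_text, the returned response carries per-request generation metadata. A minimal sketch, assuming a details field with finish_reason and generated_tokens (field names borrowed from the text-generation client this library is modeled on; the API reference has the exact schema):

from lorax import Client

client = Client("http://127.0.0.1:8080")
response = client.generate("Why is the sky blue?", adapter_id="some/adapter")

# Generation metadata (field names assumed; see the API reference)
print(response.details.finish_reason)     # why generation stopped
print(response.details.generated_tokens)  # number of tokens produced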

Batch Inference

In some cases, you may have a list of prompts that you want to process all together as a single batch ("batch processing").

Rather than processing each prompt one by one, you can take advantage of the AsyncClient and LoRAX's native parallelism to submit all of the prompts at once and await the results:

import asyncio
import time
from lorax import AsyncClient

# Batch of prompts to submit
prompts = [
    "The quick brown fox",
    "The rain in Spain",
    "What comes up",
]

# Initialize the async client
endpoint_url = "http://127.0.0.1:8080"
async_client = AsyncClient(endpoint_url)

# Create a generation coroutine for each prompt; the requests run
# concurrently once awaited via asyncio.gather below
t0 = time.time()
futures = []
for prompt in prompts:
    resp = async_client.generate(prompt, max_new_tokens=64)
    futures.append(resp)

# Await the completion of all the prompt requests
responses = await asyncio.gather(*futures)

# Print responses
# Responses will always come back in the same order as the original list
for resp in responses:
    print(resp.generated_text)

# Print duration to process all requests in batch
print("duration (s):", time.time() - t0)

Output:

duration (s): 2.9093329906463623

Compare this against the duration of submitting each request one at a time. You should find that, for 3 prompts, the async duration is about 2.5x to 3x faster than processing them serially:

from lorax import Client

client = Client(endpoint_url)

# Submit the prompts one at a time, blocking on each response
t0 = time.time()
responses = []
for prompt in prompts:
    resp = client.generate(prompt, max_new_tokens=64)
    responses.append(resp)

for resp in responses:
    print(resp.generated_text)

print("duration (s):", time.time() - t0)

Output:

duration (s): 8.385080099105835
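
If your batch is large, you may want to bound the number of requests in flight rather than submitting everything at once. A minimal sketch using asyncio.Semaphore with the prompts and async_client from above (the limit of 8 is an arbitrary choice):

import asyncio

# Allow at most 8 requests in flight at a time (arbitrary limit)
semaphore = asyncio.Semaphore(8)

async def generate_bounded(prompt):
    # Hold a semaphore slot for the duration of the request
    async with semaphore:
        return await async_client.generate(prompt, max_new_tokens=64)

responses = await asyncio.gather(*(generate_bounded(p) for p in prompts))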

Predibase Inference Endpoints

The LoRAX client can also be used to connect to LoRAX endpoints hosted by Predibase (including Predibase's serverless endpoints).

The only changes needed to the examples above are:

  1. Change the endpoint_url to match the endpoint of your Predibase LLM of choice.
  2. Provide your Predibase API token in the headers passed to the client.

Example:

from lorax import Client

# You can get your Predibase API token by going to Settings > My Profile > Generate API Token
# You can get your Predibase Tenant short code by going to Settings > My Profile > Overview > Tenant ID
endpoint_url = f"https://serving.app.predibase.com/{predibase_tenant_short_code}/deployments/v2/llms/{llm_deployment_name}"
headers = {
    "Authorization": f"Bearer {api_token}"
}

client = Client(endpoint_url, headers=headers)

# same as above from here ...
response = client.generate("Why is the sky blue?", adapter_id=f"{model_repo}/{model_version}")

Note that by default, Predibase uses its internal model repository as the default adapter_source. To use an adapter from Huggingface:

response = client.generate("Why is the sky blue?", adapter_id="some/adapter", adapter_source="hub")