# Python Client

The LoRAX Python client provides a convenient way to interact with a LoRAX instance running in your environment.
## Installation

```shell
pip install lorax-client
```
## Usage
```python
from lorax import Client

endpoint_url = "http://127.0.0.1:8080"
client = Client(endpoint_url)

text = client.generate("Why is the sky blue?", adapter_id="some/adapter").generated_text
print(text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
    if not response.token.special:
        text += response.token.text
print(text)
# ' Rayleigh scattering'
```
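The streaming loop above concatenates token text while skipping special tokens (such as an end-of-sequence marker). As a minimal self-contained sketch of that accumulation pattern, with a stand-in token stream in place of a running LoRAX server (the `Token` dataclass here is hypothetical, standing in for the `.token` attribute on the client's streamed responses):

```python
from dataclasses import dataclass


# Hypothetical stand-in for the token objects carried on each
# streamed response; the real client yields responses whose
# .token has .text and .special fields.
@dataclass
class Token:
    text: str
    special: bool


def fake_stream():
    # A made-up stream: two content tokens, then an EOS-style special token
    yield Token(" Rayleigh", special=False)
    yield Token(" scattering", special=False)
    yield Token("</s>", special=True)


text = ""
for token in fake_stream():
    if not token.special:
        text += token.text
print(text)
# ' Rayleigh scattering'
```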
Or with the async client:
```python
from lorax import AsyncClient

endpoint_url = "http://127.0.0.1:8080"
client = AsyncClient(endpoint_url)

response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
print(response.generated_text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
    if not response.token.special:
        text += response.token.text
print(text)
# ' Rayleigh scattering'
```
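The bare `await` above assumes an environment that already has a running event loop (for example, a Jupyter notebook). In a plain Python script, the calls need to be wrapped in a coroutine driven by `asyncio.run`. A minimal sketch of that wrapper, using a hypothetical stand-in coroutine in place of a live `client.generate` call:

```python
import asyncio


async def fake_generate(prompt: str) -> str:
    # Hypothetical stand-in for `await client.generate(...)`;
    # a real call would require a running LoRAX server.
    await asyncio.sleep(0)
    return f"(response to {prompt!r})"


async def main() -> str:
    # In a real script you would construct the client here:
    #   client = AsyncClient(endpoint_url)
    #   response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
    return await fake_generate("Why is the sky blue?")


result = asyncio.run(main())
print(result)
```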
See the API reference for full details.
## Batch Inference

In some cases you may have a list of prompts that you want to process together as a single batch. Rather than handling each prompt one at a time, you can use the AsyncClient together with LoRAX's native request parallelism to submit all the prompts at once and await the results:
```python
import asyncio
import time

from lorax import AsyncClient

# Batch of prompts to submit
prompts = [
    "The quick brown fox",
    "The rain in Spain",
    "What comes up",
]

# Initialize the async client
endpoint_url = "http://127.0.0.1:8080"
async_client = AsyncClient(endpoint_url)

# Submit all prompts and do not block on the response
t0 = time.time()
futures = []
for prompt in prompts:
    resp = async_client.generate(prompt, max_new_tokens=64)
    futures.append(resp)

# Await the completion of all the prompt requests
responses = await asyncio.gather(*futures)

# Print responses
# Responses will always come back in the same order as the original list
for resp in responses:
    print(resp.generated_text)

# Print duration to process all requests in batch
print("duration (s):", time.time() - t0)
```
Output:

```
duration (s): 2.9093329906463623
```
Compare this with the duration when submitting the prompts one at a time. For these 3 prompts, you should find that async processing is roughly 2.5x to 3x faster than serial processing:
```python
from lorax import Client

client = Client(endpoint_url)

t0 = time.time()
responses = []
for prompt in prompts:
    resp = client.generate(prompt, max_new_tokens=64)
    responses.append(resp)

for resp in responses:
    print(resp.generated_text)

print("duration (s):", time.time() - t0)
```
Output:

```
duration (s): 8.385080099105835
```
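The speedup is easy to reproduce without a server: each request spends most of its time waiting on network I/O, and `asyncio.gather` overlaps those waits, so the total time approaches the latency of the slowest single request rather than the sum of all of them. A self-contained sketch of the same pattern, using `asyncio.sleep` as a stand-in for per-request latency:

```python
import asyncio
import time


async def fake_generate(prompt: str, latency: float = 0.1) -> str:
    # Stand-in for AsyncClient.generate: wait, then "respond"
    await asyncio.sleep(latency)
    return f"response to {prompt!r}"


async def main():
    prompts = ["The quick brown fox", "The rain in Spain", "What comes up"]

    # Concurrent: all three waits overlap, so total time ~= one latency
    t0 = time.time()
    concurrent = await asyncio.gather(*(fake_generate(p) for p in prompts))
    concurrent_s = time.time() - t0

    # Sequential: the waits add up, so total time ~= sum of latencies
    t0 = time.time()
    sequential = [await fake_generate(p) for p in prompts]
    sequential_s = time.time() - t0

    # gather preserves the order of the input coroutines
    assert concurrent == sequential
    return concurrent_s, sequential_s


concurrent_s, sequential_s = asyncio.run(main())
print(f"concurrent: {concurrent_s:.2f}s, sequential: {sequential_s:.2f}s")
```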
## Predibase Inference Endpoints

The LoRAX client can also connect to LoRAX endpoints hosted on Predibase (including Predibase's serverless endpoints). You only need to make the following changes to the examples above:
- Change the `endpoint_url` to match the endpoint of your chosen Predibase LLM.
- Provide your Predibase API token in the `headers` passed to the client.
Example:
```python
from lorax import Client

# You can get your Predibase API token by going to Settings > My Profile > Generate API Token
# You can get your Predibase Tenant short code by going to Settings > My Profile > Overview > Tenant ID
endpoint_url = f"https://serving.app.predibase.com/{predibase_tenant_short_code}/deployments/v2/llms/{llm_deployment_name}"
headers = {
    "Authorization": f"Bearer {api_token}",
}

client = Client(endpoint_url, headers=headers)

# same as above from here ...
response = client.generate("Why is the sky blue?", adapter_id=f"{model_repo}/{model_version}")
```
Note that by default, Predibase uses its internal model repository as the default `adapter_source`. To use an adapter from Hugging Face:
```python
response = client.generate("Why is the sky blue?", adapter_id="some/adapter", adapter_source="hub")
```