跳至主要内容

插件介绍

插件是由TaskWeaver的代码解释器编排的基本单元。可以将插件视为大语言模型用来完成特定任务的工具。

在TaskWeaver中,每个插件都表示为一个Python函数,可以在生成的代码片段中调用。 一个具体例子是从数据库拉取数据并应用异常检测。生成的代码(简化版)如下所示:

df, data_description = sql_pull_data(query="pull data from time_series table")  
anomaly_df, anomaly_description = anomaly_detection(df, time_col_name="ts", value_col_name="val")

上面的代码片段调用了两个插件:sql_pull_dataanomaly_detectionsql_pull_data 插件 从数据库拉取数据,而 anomaly_detection 插件用于检测数据中的异常。

插件结构

一个插件包含两个文件:

  • 插件实现: 一个定义插件的Python文件
  • 插件模式: 一个yaml格式的文件,用于定义插件的模式

插件实现

插件功能需要用Python实现。 为了与TaskWeaver的编排协调配合,一个插件Python文件包含两个部分:

  • 插件功能实现代码
  • TaskWeaver插件装饰器

以下代码展示了异常检测插件的示例:

import pandas as pd
from pandas.api.types import is_numeric_dtype

from taskweaver.plugin import Plugin, register_plugin


@register_plugin
class AnomalyDetectionPlugin(Plugin):
def __call__(self, df: pd.DataFrame, time_col_name: str, value_col_name: str):

"""
anomaly_detection function identifies anomalies from an input dataframe of time series.
It will add a new column "Is_Anomaly", where each entry will be marked with "True" if the value is an anomaly
or "False" otherwise.

:param df: the input data, must be a dataframe
:param time_col_name: name of the column that contains the datetime
:param value_col_name: name of the column that contains the numeric values.
:return df: a new df that adds an additional "Is_Anomaly" column based on the input df.
:return description: the description about the anomaly detection results.
"""
try:
df[time_col_name] = pd.to_datetime(df[time_col_name])
except Exception:
print("Time column is not datetime")
return

if not is_numeric_dtype(df[value_col_name]):
try:
df[value_col_name] = df[value_col_name].astype(float)
except ValueError:
print("Value column is not numeric")
return

mean, std = df[value_col_name].mean(), df[value_col_name].std()
cutoff = std * 3
lower, upper = mean - cutoff, mean + cutoff
df["Is_Anomaly"] = df[value_col_name].apply(lambda x: x < lower or x > upper)
anomaly_count = df["Is_Anomaly"].sum()
description = "There are {} anomalies in the time series data".format(anomaly_count)

self.ctx.add_artifact(
name="anomaly_detection_results", # a brief description of the artifact
file_name="anomaly_detection_results.csv", # artifact file name
type="df", # artifact data type, support chart/df/file/txt/svg
val=df, # variable to be dumped
)

return df, description

您需要完成以下步骤来注册一个插件:

  1. 导入TaskWeaver插件装饰器 from taskWeaver.plugin import Plugin, register_plugin
  2. 创建你的插件类,继承自Plugin父类(例如AnomalyDetectionPlugin(Plugin)),并使用@register_plugin装饰器进行装饰
  3. 在插件类的__call__方法中实现你的插件功能。

我们在本教程中提供了一个开发新插件的示例流程。

tip

在插件实现中的一个良好实践是返回自然语言描述的结果。由于大语言模型仅理解自然语言,让模型理解执行结果非常重要。在上面的示例实现中,描述说明了检测到多少个异常。在其他情况下,比如加载CSV文件时,良好的描述可以是展示已加载数据的结构模式。大语言模型可以利用这个描述来规划后续步骤。

重要说明

  1. 如果您的插件功能依赖于额外的库或软件包,必须确保在使用前已安装这些依赖项。

  2. 如果您希望在插件实现中持久化中间结果(如数据、图表或提示),TaskWeaver提供了add_artifact API,允许您将这些结果存储在工作区中。在我们提供的示例中,如果您执行了异常检测并以CSV文件形式获得结果,可以使用add_artifact API将该文件保存为工件。这些工件会被存储在项目目录下的project/workspace/session_id/cwd文件夹中。

    self.ctx.add_artifact(
    name="anomaly_detection_results", # 工件的简要描述
    file_name="anomaly_detection_results.csv", # 工件文件名
    type="df", # 工件数据类型,支持chart/df/file/txt/svg
    val=df, # 待转储的变量
    )

插件架构

插件模式由以下几个部分组成:

  • *name: Python代码的主函数名称。
  • enabled: 决定该插件在对话期间是否可供选择。默认值为true。
  • plugin_only: 确定该插件是否在仅插件模式下启用。默认值为false。
  • code: 插件对应的代码文件名。默认值与插件名称相同。
  • *描述: 简要介绍插件功能的说明。
  • *parameters: 本节列出了所有输入参数信息。包括参数名称、类型、是否为必选参数,以及提供参数更多细节的描述。
  • *returns: 本节列出了所有返回值信息。包括返回值的名称、类型以及描述,这些描述提供了函数返回值的相关信息。
  • configurations: 插件的配置参数。默认值为空字典。
tip

添加任何额外字段或遗漏必填字段(上方标有*的)将导致插件架构内的验证失败。

插件模式需要使用YAML格式编写。以下是上述异常检测插件的模式示例:

name: anomaly_detection
enabled: true
plugin_only: false
required: false
description: >-
anomaly_detection function identifies anomalies from an input DataFrame of
time series. It will add a new column "Is_Anomaly", where each entry will be marked with "True" if the value is an anomaly or "False" otherwise.

parameters:
- name: df
type: DataFrame
required: true
description: >-
the input data from which we can identify the anomalies with the 3-sigma
algorithm.
- name: time_col_name
type: str
required: true
description: name of the column that contains the datetime
- name: value_col_name
type: str
required: true
description: name of the column that contains the numeric values.

returns:
- name: df
type: DataFrame
description: >-
This DataFrame extends the input DataFrame with a newly-added column
"Is_Anomaly" containing the anomaly detection result.
- name: description
type: str
description: This is a string describing the anomaly detection results.

info

如果不指定code字段,插件模式将使用插件名称作为代码文件名。 例如,插件名称为anomaly_detection,代码文件名为anomaly_detection.py。 当插件名称与代码文件名不一致时,可以在插件模式中指定代码名称(代码文件)以确保清晰准确。例如,插件名称为anomaly_detection而代码文件名为anomaly_detection_code.py。此时,您可以在插件模式中按如下方式指定代码名称:

code: anomaly_detection_code

请注意,代码文件名应与代码名称相同,但不带.py扩展名。 有关如何使用此功能的更多信息,请参阅Multiple YAML files to one Python implementation

info

在使用需要根据不同插件修改某些配置参数的通用代码时,务必在插件模式中明确指定这些配置参数。配置参数在插件模式中的指定方式如下:

 configurations:
key1: value1
key2: value2

这些配置参数可以在插件实现中按如下方式访问:

self.config.get("key1")
self.config.get("key2")
info

当此插件在仅插件模式下启用时,将plugin_only字段设为true。 默认值为false。请注意在非仅插件模式(默认模式)下会加载所有插件。 但在仅插件模式下,只会加载带有plugin_only: true设置的插件。