节点

在本节中,我们将介绍节点的概念,相关API文档请参阅node

节点是流水线的基本构建单元,代表具体任务。流水线用于组合节点以构建工作流,其范围涵盖从简单的机器学习工作流到端到端(E2E)生产工作流。

要运行下面的代码片段,您首先需要从Kedro和其他标准工具导入库。

from kedro.pipeline import *
from kedro.io import *
from kedro.runner import *

import pickle
import os

如何创建节点

节点通过指定函数、输入变量名和输出变量名来创建。让我们看一个简单的两数相加函数:

def add(x, y):
    return x + y

该函数有两个输入(xy)和一个输出(输入的和)。

以下是如何使用该函数创建一个节点:

adder_node = node(func=add, inputs=["a", "b"], outputs="sum")
adder_node

以下是输出内容:

Out[1]: Node(add, ['a', 'b'], 'sum', None)

您还可以为节点添加标签,这些标签将用于在日志中描述它们:

adder_node = node(func=add, inputs=["a", "b"], outputs="sum")
print(str(adder_node))

adder_node = node(func=add, inputs=["a", "b"], outputs="sum", name="adding_a_and_b")
print(str(adder_node))

这将产生以下输出:

add([a,b]) -> [sum]
adding_a_and_b: add([a,b]) -> [sum]

让我们分解一下节点定义:

  • add 是节点运行时将执行的Python函数

  • ['a', 'b'] 指定输入变量名称

  • sum 指定返回变量名。add 返回的值将被绑定到这个变量中

  • name 是节点的可选标签,可用于描述其提供的业务逻辑

节点定义语法

语法描述了函数的输入和输出。这种语法允许在节点中重用不同的Python函数,并支持在管道中进行依赖解析。

输入变量的语法

输入语法

含义

示例函数参数

节点运行时如何调用函数

None

无输入

def f()

f()

'a'

单一输入

def f(arg1)

f(a)

['a', 'b']

多输入

def f(arg1, arg2)

f(a, b)

['a', 'b', 'c']

变量输入

def f(arg1, *args).

f(arg1, arg2, arg3)

dict(arg1='x', arg2='y')

关键词输入

def f(arg1, arg2)

f(arg1=x, arg2=y)

输出变量的语法

输出语法

含义

示例返回语句

None

无输出

不返回

'a'

单一输出

return a

['a', 'b']

列表输出

return [a, b]

dict(key1='a', key2='b')

字典输出

return dict(key1=a, key2=b)

以上各种组合都是可行的,除了形如node(f, None, None)的节点(必须至少提供一个输入或输出)。

*args 节点函数

通常会有需要接收任意数量输入的函数,比如合并多个数据框的函数。你可以在节点函数中使用*args参数,同时在节点输入中简单声明数据集的名称。

仅支持**kwargs参数的节点函数

有时,例如在创建报告节点时,您需要知道节点接收的数据集名称,但可能事先没有这些信息。这可以通过定义一个仅包含**kwargs参数的函数来解决:

def reporting(**kwargs):
    result = []
    for name, data in kwargs.items():
        res = example_report(name, data)
        result.append(res)
    return combined_report(result)

然后,在构建Node时,只需向节点输入传递一个字典:

from kedro.pipeline import node


uk_reporting_node = node(
    reporting,
    inputs={"uk_input1": "uk_input1", "uk_input2": "uk_input2", ...},
    outputs="uk",
)

ge_reporting_node = node(
    reporting,
    inputs={"ge_input1": "ge_input1", "ge_input2": "ge_input2", ...},
    outputs="ge",
)

或者,您也可以利用一个辅助函数来为您创建映射,这样您就可以在代码库中重复使用它。

 from kedro.pipeline import node


+mapping = lambda x: {k: k for k in x}
+
 uk_reporting_node = node(
     reporting,
-    inputs={"uk_input1": "uk_input1", "uk_input2": "uk_input2", ...},
+    inputs=mapping(["uk_input1", "uk_input2", ...]),
     outputs="uk",
 )

 ge_reporting_node = node(
     reporting,
-    inputs={"ge_input1": "ge_input1", "ge_input2": "ge_input2", ...},
+    inputs=mapping(["ge_input1", "ge_input2", ...]),
     outputs="ge",
 )

如何标记节点

标签可能有助于在不更改代码的情况下运行部分流水线。例如,kedro run --tags=ds将仅运行附加了ds标签的节点。

要为节点添加标签,只需指定tags参数:

node(func=add, inputs=["a", "b"], outputs="sum", name="adding_a_and_b", tags="node_tag")

此外,您还可以Pipeline中的所有节点添加标签。如果流水线定义中包含tags=参数,Kedro会自动将该标签附加到该流水线内的每个节点上。

要使用标签运行管道:

kedro run --tags=pipeline_tag

这将仅运行标记为pipeline_tag的管道中找到的节点。

注意

节点或标签名称只能包含字母、数字、连字符、下划线和/或句点。不允许使用其他符号。

如何运行一个节点

要运行一个节点,你必须实例化它的输入。在本例中,该节点需要两个输入:

adder_node.run(dict(a=2, b=3))

输出如下:

Out[2]: {'sum': 5}

注意

你也可以像调用普通Python函数一样调用节点:adder_node(dict(a=2, b=3))。这实际上会在后台调用adder_node.run(dict(a=2, b=3))

如何在节点中使用生成器函数

警告

本文档部分使用了pandas-iris入门模板,该模板在Kedro 0.19.0及更高版本中不可用。支持pandas-iris的最新Kedro版本是0.18.14:请安装该版本或更早版本来运行此示例pip install kedro==0.18.14)。

要检查已安装的版本,请在终端窗口中输入kedro -V

Generator functionsPEP 255 引入,是Python中一种特殊的函数,它返回惰性迭代器。它们通常用于数据的惰性加载或惰性保存,这在处理无法完全放入内存的大型数据集时特别有用。在Kedro的上下文中,生成器函数可用于节点中,以高效处理和管理这类大型数据集。

设置项目

使用传统的pandas-iris启动器设置Kedro项目。假设Kedro版本为0.18.14,通过以下命令创建项目:

kedro new --starter=pandas-iris --checkout=0.18.14

使用生成器加载数据

要在Kedro节点中使用生成器函数,您需要更新catalog.yml文件,为将使用生成器处理的相关数据集添加chunksize参数。

你需要在catalog.yml中添加一个新的数据集,如下所示:

+ X_test:
+  type: pandas.CSVDataset
+  filepath: data/05_model_input/X_test.csv
+  load_args:
+    chunksize: 10

借助pandas内置支持,您可以使用chunksize参数通过生成器读取数据。

使用生成器保存数据

要使用生成器实现数据的惰性保存,你需要完成以下三件事:

  • make_prediction函数定义更新为使用yield而非return

  • 创建一个名为ChunkWiseCSVDataset自定义数据集

  • 更新 catalog.yml 以使用新创建的 ChunkWiseCSVDataset

将以下代码复制到nodes.py中。主要变化是使用新模型DecisionTreeClassifier,在make_predictions中按块进行预测。

Click to open
import logging
from typing import Any, Dict, Tuple, Iterator, Generator
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
import numpy as np
import pandas as pd


def split_data(
    data: pd.DataFrame, parameters: Dict[str, Any]
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    """Splits data into features and target training and test sets.

    Args:
        data: Data containing features and target.
        parameters: Parameters defined in parameters.yml.
    Returns:
        Split data.
    """

    data_train = data.sample(
        frac=parameters["train_fraction"], random_state=parameters["random_state"]
    )
    data_test = data.drop(data_train.index)

    X_train = data_train.drop(columns=parameters["target_column"])
    X_test = data_test.drop(columns=parameters["target_column"])
    y_train = data_train[parameters["target_column"]]
    y_test = data_test[parameters["target_column"]]

    label_encoder = LabelEncoder()
    label_encoder.fit(pd.concat([y_train, y_test]))
    y_train = label_encoder.transform(y_train)

    return X_train, X_test, y_train, y_test


def make_predictions(
    X_train: pd.DataFrame, X_test: pd.DataFrame, y_train: pd.Series
) -> Generator[pd.Series, None, None]:
    """Use a DecisionTreeClassifier model to make prediction."""
    model = DecisionTreeClassifier()
    model.fit(X_train, y_train)

    for chunk in X_test:
        y_pred = model.predict(chunk)
        y_pred = pd.DataFrame(y_pred)
        yield y_pred


def report_accuracy(y_pred: pd.Series, y_test: pd.Series):
    """Calculates and logs the accuracy.

    Args:
        y_pred: Predicted target.
        y_test: True target.
    """
    accuracy = accuracy_score(y_test, y_pred)
    logger = logging.getLogger(__name__)
    logger.info("Model has accuracy of %.3f on test data.", accuracy)

ChunkWiseCSVDatasetpandas.CSVDataset的一个变体,主要改动在于_save方法改为追加数据而非覆盖。您需要创建文件src//chunkwise.py并将这个类放入其中。以下是ChunkWiseCSVDataset的实现示例:

import pandas as pd

from kedro.io.core import (
    get_filepath_str,
)
from kedro_datasets.pandas import CSVDataset


class ChunkWiseCSVDataset(CSVDataset):
    """``ChunkWiseCSVDataset`` loads/saves data from/to a CSV file using an underlying
    filesystem. It uses pandas to handle the CSV file.
    """

    _overwrite = True

    def _save(self, data: pd.DataFrame) -> None:
        save_path = get_filepath_str(self._get_save_path(), self._protocol)
        # Save the header for the first batch
        if self._overwrite:
            data.to_csv(save_path, index=False, mode="w")
            self._overwrite = False
        else:
            data.to_csv(save_path, index=False, header=False, mode="a")

之后,您需要更新catalog.yml以使用这个新的数据集。

+ y_pred:
+  type: <package_name>.chunkwise.ChunkWiseCSVDataset
+  filepath: data/07_model_output/y_pred.csv

经过这些修改后,当你在终端运行kedro run命令时,应该能在日志中看到y_pred被多次保存,这是因为生成器会惰性地处理数据并以较小分块的形式保存数据。

...
                    INFO     Loading data from 'y_train' (MemoryDataset)...                                                                                         data_catalog.py:475
                    INFO     Running node: make_predictions: make_predictions([X_train,X_test,y_train]) -> [y_pred]                                                         node.py:331
                    INFO     Saving data to 'y_pred' (ChunkWiseCSVDataset)...                                                                                       data_catalog.py:514
                    INFO     Saving data to 'y_pred' (ChunkWiseCSVDataset)...                                                                                       data_catalog.py:514
                    INFO     Saving data to 'y_pred' (ChunkWiseCSVDataset)...                                                                                       data_catalog.py:514
                    INFO     Completed 2 out of 3 tasks                                                                                                         sequential_runner.py:85
                    INFO     Loading data from 'y_pred' (ChunkWiseCSVDataset)...                                                                                    data_catalog.py:475
...                                                                              runner.py:105