Dask¶
This page explains how to distribute execution of the nodes composing your Kedro pipeline using Dask, a flexible, open-source library for parallel computing in Python.
Dask offers both a default, single-machine scheduler and a more sophisticated distributed scheduler. The newer dask.distributed scheduler is often preferable, even on single workstations, and is the focus of our deployment guide. For the various ways to set up Dask on different hardware, see the official Dask how-to guides.
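As a rough illustration of the difference (a minimal sketch, separate from the Kedro setup below; the scheduler address is hypothetical), creating a distributed.Client with no arguments starts a LocalCluster on your workstation, while passing an address connects to an existing cluster:
from distributed import Client

# Single workstation: start an in-process LocalCluster and connect to it.
local_client = Client()
print(local_client.dashboard_link)  # diagnostic dashboard for the local cluster
local_client.close()

# Existing cluster: connect to a running scheduler instead (hypothetical address).
# remote_client = Client("tcp://192.0.2.10:8786")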
Why would you use Dask?¶
Dask.distributed is a lightweight library for distributed computing in Python. It complements the existing PyData analysis stack, which forms the basis of many Kedro pipelines. It is also written in pure Python, which eases installation and simplifies debugging. For further motivation on why people choose to adopt Dask, and in particular dask.distributed, see Why Dask? and the dask.distributed documentation, respectively.
Prerequisites¶
The only additional requirement, beyond what is already needed for your Kedro pipeline, is to install dask.distributed. To review the full installation instructions, including how to set up Python virtual environments, see our Get Started guide.
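If you want to confirm that dask.distributed is available in your environment (for example after installing it with pip install "dask[distributed]"), a quick check is:
# Sanity check that the distributed scheduler package is importable.
import distributed

print(distributed.__version__)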
How to distribute your Kedro pipeline with Dask¶
Create a custom runner¶
Create a new Python package runner in your src folder, i.e. kedro_tutorial/src/kedro_tutorial/runner/. Make sure there is an __init__.py file at this location, and add another file named dask_runner.py, which will contain the implementation of your custom runner, DaskRunner. The DaskRunner will submit and monitor tasks asynchronously, surfacing any errors that occur during execution.
Make sure the __init__.py file in the runner folder contains the following import and declaration:
from .dask_runner import DaskRunner
__all__ = ["DaskRunner"]
Copy the contents of the following script into the dask_runner.py file:
"""``DaskRunner`` is an ``AbstractRunner`` implementation. It can be
used to distribute execution of ``Node``s in the ``Pipeline`` across
a Dask cluster, taking into account the inter-``Node`` dependencies.
"""
from collections import Counter
from itertools import chain
from typing import Any
from distributed import Client, as_completed, worker_client
from kedro.framework.hooks.manager import (
_create_hook_manager,
_register_hooks,
_register_hooks_entry_points,
)
from kedro.framework.project import settings
from kedro.io import AbstractDataset, CatalogProtocol
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node
from kedro.runner import AbstractRunner, run_node
from pluggy import PluginManager
class _DaskDataset(AbstractDataset):
"""``_DaskDataset`` publishes/gets named datasets to/from the Dask
scheduler."""
def __init__(self, name: str):
self._name = name
def _load(self) -> Any:
try:
with worker_client() as client:
return client.get_dataset(self._name)
except ValueError:
# Upon successfully executing the pipeline, the runner loads
# free outputs on the scheduler (as opposed to on a worker).
Client.current().get_dataset(self._name)
def _save(self, data: Any) -> None:
with worker_client() as client:
client.publish_dataset(data, name=self._name, override=True)
def _exists(self) -> bool:
return self._name in Client.current().list_datasets()
def _release(self) -> None:
Client.current().unpublish_dataset(self._name)
def _describe(self) -> dict[str, Any]:
return dict(name=self._name)
class DaskRunner(AbstractRunner):
"""``DaskRunner`` is an ``AbstractRunner`` implementation. It can be
used to distribute execution of ``Node``s in the ``Pipeline`` across
a Dask cluster, taking into account the inter-``Node`` dependencies.
"""
def __init__(self, client_args: dict[str, Any] = {}, is_async: bool = False):
"""Instantiates the runner by creating a ``distributed.Client``.
Args:
client_args: Arguments to pass to the ``distributed.Client``
constructor.
is_async: If True, the node inputs and outputs are loaded and saved
asynchronously with threads. Defaults to False.
"""
super().__init__(is_async=is_async)
Client(**client_args)
def __del__(self):
Client.current().close()
def create_default_dataset(self, ds_name: str) -> _DaskDataset:
"""Factory method for creating the default dataset for the runner.
Args:
ds_name: Name of the missing dataset.
Returns:
An instance of ``_DaskDataset`` to be used for all
unregistered datasets.
"""
return _DaskDataset(ds_name)
@staticmethod
def _run_node(
node: Node,
catalog: CatalogProtocol,
is_async: bool = False,
session_id: str | None = None,
*dependencies: Node,
) -> Node:
"""Run a single `Node` with inputs from and outputs to the `catalog`.
Wraps ``run_node`` to accept the set of ``Node``s that this node
depends on. When ``dependencies`` are futures, Dask ensures that
the upstream node futures are completed before running ``node``.
A ``PluginManager`` instance is created on each worker because the
``PluginManager`` can't be serialised.
Args:
node: The ``Node`` to run.
catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data.
is_async: If True, the node inputs and outputs are loaded and saved
asynchronously with threads. Defaults to False.
session_id: The session id of the pipeline run.
dependencies: The upstream ``Node``s to allow Dask to handle
dependency tracking. Their values are not actually used.
Returns:
The node argument.
"""
hook_manager = _create_hook_manager()
_register_hooks(hook_manager, settings.HOOKS)
_register_hooks_entry_points(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS)
return run_node(node, catalog, hook_manager, is_async, session_id)
def _run(
self,
pipeline: Pipeline,
catalog: CatalogProtocol,
hook_manager: PluginManager | None = None,
session_id: str | None = None,
) -> None:
"""Implementation of the abstract interface for running the pipelines.
Args:
pipeline: The ``Pipeline`` to run.
catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data.
hook_manager: The ``PluginManager`` to activate hooks.
session_id: The id of the session.
"""
nodes = pipeline.nodes
load_counts = Counter(chain.from_iterable(n.inputs for n in nodes))
node_dependencies = pipeline.node_dependencies
node_futures = {}
client = Client.current()
for node in nodes:
dependencies = (
node_futures[dependency] for dependency in node_dependencies[node]
)
node_futures[node] = client.submit(
DaskRunner._run_node,
node,
catalog,
self._is_async,
session_id,
*dependencies,
)
for i, (_, node) in enumerate(
as_completed(node_futures.values(), with_results=True)
):
self._logger.info("Completed node: %s", node.name)
self._logger.info("Completed %d out of %d tasks", i + 1, len(nodes))
# Decrement load counts, and release any datasets we
# have finished with. This is particularly important
# for the shared, default datasets we created above.
for dataset in node.inputs:
load_counts[dataset] -= 1
if load_counts[dataset] < 1 and dataset not in pipeline.inputs():
catalog.release(dataset)
for dataset in node.outputs:
if load_counts[dataset] < 1 and dataset not in pipeline.outputs():
catalog.release(dataset)
def run_only_missing(
self, pipeline: Pipeline, catalog: CatalogProtocol, hook_manager: PluginManager
) -> dict[str, Any]:
"""Run only the missing outputs from the ``Pipeline`` using the
datasets provided by ``catalog``, and save results back to the
same objects.
Args:
pipeline: The ``Pipeline`` to run.
catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data.
hook_manager: The ``PluginManager`` to activate hooks.
Raises:
ValueError: Raised when ``Pipeline`` inputs cannot be
satisfied.
Returns:
Any node outputs that cannot be processed by the
catalog. These are returned in a dictionary, where
the keys are defined by the node outputs.
"""
free_outputs = pipeline.outputs() - set(catalog.list())
missing = {ds for ds in catalog.list() if not catalog.exists(ds)}
to_build = free_outputs | missing
to_rerun = pipeline.only_nodes_with_outputs(*to_build) + pipeline.from_inputs(
*to_build
)
# We also need any missing datasets that are required to run the
# `to_rerun` pipeline, including any chains of missing datasets.
unregistered_ds = pipeline.datasets() - set(catalog.list())
# Some of the unregistered datasets could have been published to
# the scheduler in a previous run, so we need not recreate them.
missing_unregistered_ds = {
ds_name
for ds_name in unregistered_ds
if not self.create_default_dataset(ds_name).exists()
}
output_to_unregistered = pipeline.only_nodes_with_outputs(
*missing_unregistered_ds
)
input_from_unregistered = to_rerun.inputs() & missing_unregistered_ds
to_rerun += output_to_unregistered.to_outputs(*input_from_unregistered)
# We need to add any previously-published, unregistered datasets
# to the catalog passed to the `run` method, so that it does not
# think that the `to_rerun` pipeline's inputs are not satisfied.
catalog = catalog.shallow_copy()
for ds_name in unregistered_ds - missing_unregistered_ds:
catalog.add(ds_name, self.create_default_dataset(ds_name))
return self.run(to_rerun, catalog)
def _get_executor(self, max_workers):
# Run sequentially
return None
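The _DaskDataset above is built on Dask's published datasets, which let named values be shared between workers via the scheduler. The following standalone sketch (independent of Kedro, using a throwaway LocalCluster; the dataset name and value are purely illustrative) shows the distributed.Client calls that _save, _exists, _load and _release map onto:
from distributed import Client

client = Client()  # throwaway LocalCluster for demonstration

# _save: publish a named value to the scheduler, replacing any previous one.
client.publish_dataset({"rows": 3}, name="example", override=True)

# _exists: the name is now visible to every client and worker.
assert "example" in client.list_datasets()

# _load (scheduler side): retrieve the published value by name.
print(client.get_dataset("example"))

# _release: remove the published value from the scheduler.
client.unpublish_dataset("example")
client.close()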
Update the CLI implementation¶
You're nearly there! Before you can use the new runner, you need to add a cli.py file at the same level as settings.py, using the template we provide. Update the run() function in the newly-created cli.py file to make sure the runner class is instantiated correctly:
def run(tag, env, ...):
    """Run the pipeline."""
    runner = runner or "SequentialRunner"

    tags = tuple(tags)
    node_names = tuple(node_names)

    with KedroSession.create(env=env, extra_params=params) as session:
        context = session.load_context()
        runner_instance = _instantiate_runner(runner, is_async, context)
        session.run(
            tags=tags,
            runner=runner_instance,
            node_names=node_names,
            from_nodes=from_nodes,
            to_nodes=to_nodes,
            from_inputs=from_inputs,
            to_outputs=to_outputs,
            load_versions=load_versions,
            pipeline_name=pipeline,
        )
where the helper function _instantiate_runner() is defined as follows:
def _instantiate_runner(runner, is_async, project_context):
    runner_class = load_obj(runner, "kedro.runner")
    runner_kwargs = dict(is_async=is_async)
    if runner.endswith("DaskRunner"):
        client_args = project_context.params.get("dask_client") or {}
        runner_kwargs.update(client_args=client_args)
    return runner_class(**runner_kwargs)
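As an illustration of what this helper ends up constructing when the DaskRunner is selected (a sketch only; the n_workers value stands in for whatever dask_client parameters you choose to define in your project configuration), the keyword arguments are simply forwarded to distributed.Client:
from kedro_tutorial.runner import DaskRunner

# Roughly equivalent to _instantiate_runner(...) when the project
# parameters contain dask_client: {"n_workers": 2}; client_args is
# passed straight through to distributed.Client.
runner = DaskRunner(client_args={"n_workers": 2}, is_async=False)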
Deployment¶
You're now ready to trigger the run. Without any further configuration, the underlying Dask Client creates a LocalCluster in the background and connects to it:
kedro run --runner=kedro_tutorial.runner.DaskRunner
