进阶教程:创建自定义数据集

Kedro supports many datasets 开箱即用地支持多种数据集,但您可能会发现需要创建自定义数据集。例如,您可能需要在流水线中处理专有数据格式或文件系统,或者您可能发现了Kedro尚未支持的特定数据集用例。本教程将解释如何创建自定义数据集来读取和保存图像数据。

AbstractDataset

如果您是贡献者并希望提交新的数据集,必须扩展AbstractDataset接口;若计划支持版本控制,则需扩展AbstractVersionedDataset接口。该接口要求子类实现loadsave方法,同时提供封装器来增强对应方法的统一错误处理能力。此外还要求子类重写_describe方法,该方法用于记录自定义AbstractDataset实现实例的内部信息。

场景

在本示例中,我们使用Kaggle上的宝可梦图像与属性数据集来训练一个模型,使其能够根据外观对给定宝可梦的属性进行分类(例如水系、火系、虫系等)。为训练模型,我们先将PNG格式的宝可梦图像读取为numpy数组,以便在Kedro流水线中进行后续处理。为开箱即用地处理PNG图像,本示例创建了ImageDataset来读取和保存图像数据。

项目设置

我们假设您已经安装了Kedro。现在创建一个项目(您可以随意命名项目,但这里我们假设项目仓库名称为kedro-pokemon)。

登录您的Kaggle账户以下载宝可梦数据集,并将其解压到data/01_raw目录下,放在名为pokemon-images-and-types的子文件夹中。该数据集包含一个pokemon.csv文件和一个存放图片的子文件夹。

该数据集将使用Pillow进行通用图像处理功能,以确保它能处理多种不同的图像格式,而不仅仅是PNG。

安装Pillow:

pip install Pillow

如果在安装过程中遇到问题,请查阅Pillow文档

数据集的结构剖析

至少,一个有效的Kedro数据集需要继承基础类AbstractDataset,并为以下抽象方法提供实现:

  • load

  • save

  • _describe

AbstractDataset 泛型类定义了保存数据的输入数据类型和加载数据的输出数据类型。不过这种类型标注是可选的,默认会使用 Any 类型。

AbstractDataset中的_EPHEMERAL布尔属性表示数据集是否为持久化的。例如,对于非持久化的MemoryDataset,该属性被设置为True。默认情况下,_EPHEMERAL被设置为False。

注意

为了遵循Kedro规范,在自定义数据集类的构造函数中,用于指定数据文件/文件夹位置的参数必须命名为filenamefilepathpath

以下是ImageDataset的示例框架:

Click to expand
from typing import Any, Dict

import numpy as np

from kedro.io import AbstractDataset


class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
    """``ImageDataset`` loads / save image data from a given filepath as `numpy` array using Pillow.

    Example:
    ::

        >>> ImageDataset(filepath='/img/file/path.png')
    """

    def __init__(self, filepath: str):
        """Creates a new instance of ImageDataset to load / save image data at the given filepath.

        Args:
            filepath: The location of the image file to load / save data.
        """
        self._filepath = filepath

    def load(self) -> np.ndarray:
        """Loads data from the image file.

        Returns:
            Data from the image file as a numpy array.
        """
        ...

    def save(self, data: np.ndarray) -> None:
        """Saves image data to the specified filepath"""
        ...

    def _describe(self) -> Dict[str, Any]:
        """Returns a dict that describes the attributes of the dataset"""
        ...

src/kedro_pokemon/目录下创建一个名为datasets的子文件夹,用于存储数据集定义文件image_dataset.py,同时添加__init__.py文件使Python将该目录视为可导入的包:

src/kedro_pokemon/datasets
├── __init__.py
└── image_dataset.py

使用fsspec实现load方法

许多内置的Kedro数据集依赖fsspec作为统一接口来访问不同数据源,正如前文在数据目录章节所述。在本示例中,结合使用fsspecPillow读取图像数据特别方便,因为这使得数据集能够灵活处理不同位置和格式的图像。

以下是使用fsspecPillow将单张图像数据读取到numpy数组中的load方法实现:

Click to expand
from pathlib import PurePosixPath
from typing import Any, Dict

import fsspec
import numpy as np
from PIL import Image

from kedro.io import AbstractDataset
from kedro.io.core import get_filepath_str, get_protocol_and_path


class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
    def __init__(self, filepath: str):
        """Creates a new instance of ImageDataset to load / save image data for given filepath.

        Args:
            filepath: The location of the image file to load / save data.
        """
        # parse the path and protocol (e.g. file, http, s3, etc.)
        protocol, path = get_protocol_and_path(filepath)
        self._protocol = protocol
        self._filepath = PurePosixPath(path)
        self._fs = fsspec.filesystem(self._protocol)

    def load(self) -> np.ndarray:
        """Loads data from the image file.

        Returns:
            Data from the image file as a numpy array
        """
        # using get_filepath_str ensures that the protocol and path are appended correctly for different filesystems
        load_path = get_filepath_str(self._filepath, self._protocol)
        with self._fs.open(load_path) as f:
            image = Image.open(f).convert("RGBA")
            return np.asarray(image)

    ...

为了测试这一点,让我们在数据目录中添加一个数据集来加载皮卡丘的图像。

# in conf/base/catalog.yml

pikachu:
  type: kedro_pokemon.datasets.image_dataset.ImageDataset
  filepath: data/01_raw/pokemon-images-and-types/images/images/pikachu.png
  # Note: the duplicated `images` path is part of the original Kaggle dataset

然后通过kedro ipython启动一个IPython会话来预览数据:

# read data image into a numpy array
In [1]: image = context.catalog.load('pikachu')

# then re-show the image using Pillow's Image API.
In [2]: from PIL import Image
In [3]: Image.fromarray(image).show()

使用fsspec实现save方法

同样地,我们可以按如下方式实现_save方法:

class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
    def save(self, data: np.ndarray) -> None:
        """Saves image data to the specified filepath."""
        # using get_filepath_str ensures that the protocol and path are appended correctly for different filesystems
        save_path = get_filepath_str(self._filepath, self._protocol)
        with self._fs.open(save_path, "wb") as f:
            image = Image.fromarray(data)
            image.save(f)

让我们在IPython中试试看:

In [1]: image = context.catalog.load('pikachu')
In [2]: context.catalog.save('pikachu', data=image)

您可以打开文件以验证数据是否正确回写。

实现 _describe 方法

_describe 方法用于打印输出。在Kedro中的惯例是让该方法返回一个描述数据集属性的字典。

class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
    def _describe(self) -> Dict[str, Any]:
        """Returns a dict that describes the attributes of the dataset."""
        return dict(filepath=self._filepath, protocol=self._protocol)

完整示例

这是我们基础ImageDataset的完整实现:

Click to expand
from pathlib import PurePosixPath
from typing import Any, Dict

import fsspec
import numpy as np
from PIL import Image

from kedro.io import AbstractDataset
from kedro.io.core import get_filepath_str, get_protocol_and_path


class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
    """``ImageDataset`` loads / save image data from a given filepath as `numpy` array using Pillow.

    Example:
    ::

        >>> ImageDataset(filepath='/img/file/path.png')
    """

    def __init__(self, filepath: str):
        """Creates a new instance of ImageDataset to load / save image data for given filepath.

        Args:
            filepath: The location of the image file to load / save data.
        """
        protocol, path = get_protocol_and_path(filepath)
        self._protocol = protocol
        self._filepath = PurePosixPath(path)
        self._fs = fsspec.filesystem(self._protocol)

    def load(self) -> np.ndarray:
        """Loads data from the image file.

        Returns:
            Data from the image file as a numpy array
        """
        load_path = get_filepath_str(self._filepath, self._protocol)
        with self._fs.open(load_path, mode="r") as f:
            image = Image.open(f).convert("RGBA")
            return np.asarray(image)

    def save(self, data: np.ndarray) -> None:
        """Saves image data to the specified filepath."""
        save_path = get_filepath_str(self._filepath, self._protocol)
        with self._fs.open(save_path, mode="wb") as f:
            image = Image.fromarray(data)
            image.save(f)

    def _describe(self) -> Dict[str, Any]:
        """Returns a dict that describes the attributes of the dataset."""
        return dict(filepath=self._filepath, protocol=self._protocol)

PartitionedDataset 的集成

目前,ImageDataset仅支持处理单张图像,但本示例需要从原始数据目录加载所有宝可梦图像以进行后续处理。

Kedro的PartitionedDataset提供了一种便捷的方式,可以将同一基础数据集类型的多个独立数据文件加载到目录中。

要使用PartitionedDatasetImageDataset加载所有宝可梦PNG图片,请将此添加到数据目录YAML中,以便PartitionedDataset使用ImageDataset从数据目录加载所有PNG文件:

# in conf/base/catalog.yml

pokemon:
  type: partitions.PartitionedDataset
  dataset: kedro_pokemon.datasets.image_dataset.ImageDataset
  path: data/01_raw/pokemon-images-and-types/images/images
  filename_suffix: ".png"

让我们在IPython控制台中尝试一下:

In [1]: images = context.catalog.load('pokemon')
In [2]: len(images)
Out[2]: 721

验证数据目录中.png文件的数量(应为721):

$ ls -la data/01_raw/pokemon-images-and-types/images/images/*.png | wc -l
    721

版本控制

如何在你的数据集中实现版本控制

注意

版本控制不适用于PartitionedDataset。您不能同时使用两者。

要为新数据集添加版本控制支持,我们需要扩展AbstractVersionedDataset来实现:

  • 在构造函数中接受一个version关键字参数

  • 调整loadsave方法,以使用分别从_get_load_path_get_save_path获取的版本化数据路径

以下是对我们基础ImageDataset完整实现的修改。现在它会在版本化的子文件夹中加载和保存数据(默认情况下,version是一个日期时间格式的字符串YYYY-MM-DDThh.mm.ss.sssZ,路径格式为data/01_raw/pokemon-images-and-types/images/images/pikachu.png//pikachu.png):

Click to expand
from pathlib import PurePosixPath
from typing import Any, Dict

import fsspec
import numpy as np
from PIL import Image

from kedro.io import AbstractVersionedDataset
from kedro.io.core import get_filepath_str, get_protocol_and_path, Version


class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]):
    """``ImageDataset`` loads / save image data from a given filepath as `numpy` array using Pillow.

    Example:
    ::

        >>> ImageDataset(filepath='/img/file/path.png')
    """

    def __init__(self, filepath: str, version: Version = None):
        """Creates a new instance of ImageDataset to load / save image data for given filepath.

        Args:
            filepath: The location of the image file to load / save data.
            version: The version of the dataset being saved and loaded.
        """
        protocol, path = get_protocol_and_path(filepath)
        self._protocol = protocol
        self._fs = fsspec.filesystem(self._protocol)

        super().__init__(
            filepath=PurePosixPath(path),
            version=version,
            exists_function=self._fs.exists,
            glob_function=self._fs.glob,
        )

    def load(self) -> np.ndarray:
        """Loads data from the image file.

        Returns:
            Data from the image file as a numpy array
        """
        load_path = get_filepath_str(self._get_load_path(), self._protocol)
        with self._fs.open(load_path, mode="r") as f:
            image = Image.open(f).convert("RGBA")
            return np.asarray(image)

    def save(self, data: np.ndarray) -> None:
        """Saves image data to the specified filepath."""
        save_path = get_filepath_str(self._get_save_path(), self._protocol)
        with self._fs.open(save_path, mode="wb") as f:
            image = Image.fromarray(data)
            image.save(f)

    def _describe(self) -> Dict[str, Any]:
        """Returns a dict that describes the attributes of the dataset."""
        return dict(
            filepath=self._filepath, version=self._version, protocol=self._protocol
        )

原始ImageDataset与版本化ImageDataset的区别如下:

Click to expand
 from pathlib import PurePosixPath
 from typing import Any, Dict

 import fsspec
 import numpy as np
 from PIL import Image

-from kedro.io import AbstractDataset
-from kedro.io.core import get_filepath_str, get_protocol_and_path
+from kedro.io import AbstractVersionedDataset
+from kedro.io.core import get_filepath_str, get_protocol_and_path, Version


-class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
+class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]):
     """``ImageDataset`` loads / save image data from a given filepath as `numpy` array using Pillow.

     Example:
     ::

         >>> ImageDataset(filepath='/img/file/path.png')
     """

-    def __init__(self, filepath: str):
+    def __init__(self, filepath: str, version: Version = None):
         """Creates a new instance of ImageDataset to load / save image data for given filepath.

         Args:
             filepath: The location of the image file to load / save data.
+            version: The version of the dataset being saved and loaded.
         """
         protocol, path = get_protocol_and_path(filepath)
         self._protocol = protocol
-        self._filepath = PurePosixPath(path)
         self._fs = fsspec.filesystem(self._protocol)

+        super().__init__(
+            filepath=PurePosixPath(path),
+            version=version,
+            exists_function=self._fs.exists,
+            glob_function=self._fs.glob,
+        )
+
     def load(self) -> np.ndarray:
         """Loads data from the image file.

         Returns:
             Data from the image file as a numpy array
         """
-        load_path = get_filepath_str(self._filepath, self._protocol)
+        load_path = get_filepath_str(self._get_load_path(), self._protocol)
         with self._fs.open(load_path, mode="r") as f:
             image = Image.open(f).convert("RGBA")
             return np.asarray(image)

     def save(self, data: np.ndarray) -> None:
         """Saves image data to the specified filepath."""
-        save_path = get_filepath_str(self._filepath, self._protocol)
+        save_path = get_filepath_str(self._get_save_path(), self._protocol)
         with self._fs.open(save_path, mode="wb") as f:
             image = Image.fromarray(data)
             image.save(f)

     def _describe(self) -> Dict[str, Any]:
         """Returns a dict that describes the attributes of the dataset."""
-        return dict(filepath=self._filepath, protocol=self._protocol)
+        return dict(
+            filepath=self._filepath, version=self._version, protocol=self._protocol
+        )

要测试代码,您需要在数据目录中启用版本控制支持:

# in conf/base/catalog.yml

pikachu:
  type: kedro_pokemon.datasets.image_dataset.ImageDataset
  filepath: data/01_raw/pokemon-images-and-types/images/images/pikachu.png
  versioned: true

注意

不支持使用基于HTTP(S)的filepath并设置versioned: true

通过创建示例的第一个版本(例如2020-02-22T00.00.00.000Z)来生成数据的初始版本:

$ mv data/01_raw/pokemon-images-and-types/images/images/pikachu.png data/01_raw/pokemon-images-and-types/images/images/pikachu.png.backup
$ mkdir -p data/01_raw/pokemon-images-and-types/images/images/pikachu.png/2020-02-22T00.00.00.000Z/
$ mv data/01_raw/pokemon-images-and-types/images/images/pikachu.png.backup data/01_raw/pokemon-images-and-types/images/images/pikachu.png/2020-02-22T00.00.00.000Z/pikachu.png

目录结构应如下所示:

data/01_raw/pokemon-images-and-types/images/images/pikachu.png
└── 2020-02-22T00.00.00.000Z/
    └── pikachu.png

启动一个IPython shell来测试版本化数据的加载/保存:

# loading works as Kedro automatically find the latest available version inside `pikachu.png` directory
In [1]: img = context.catalog.load('pikachu')
# then saving it should work as well
In [2]: context.catalog.save('pikachu', data=img)

检查数据目录的内容,找到由save写入的新版本数据。

线程安全性

Kedro数据集应能与SequentialRunnerParallelRunner协同工作,因此它们必须能被Python multiprocessing包完全序列化。这意味着您的数据集不应使用lambda函数、嵌套函数、闭包等。如果使用自定义装饰器,需要确保它们使用了functools.wraps()

有一个数据集是例外:SparkDataset。这个例外的原因是Apache Spark使用自己的并行机制,因此无法与Kedro的ParallelRunner配合使用。在使用Spark的Kedro项目中实现并行,请改用ThreadRunner

要验证你的数据集是否可以被multiprocessing序列化,可以使用控制台或IPython会话尝试使用multiprocessing.reduction.ForkingPickler进行转储:

dataset = context.catalog._datasets["pokemon"]
from multiprocessing.reduction import ForkingPickler

# the following call shouldn't throw any errors
ForkingPickler.dumps(dataset)

如何处理凭证和不同的文件系统

如果您的使用场景需要,Kedro允许您向数据集传递credentials和文件系统特定的fs_args参数。例如,如果宝可梦数据存放在S3存储桶中,我们可以按如下方式将credentialsfs_args添加到数据目录中:

# in conf/base/catalog.yml

pikachu:
  type: kedro_pokemon.datasets.image_dataset.ImageDataset
  filepath: s3://data/01_raw/pokemon-images-and-types/images/images/pikachu.png
  credentials: <your_credentials>
  fs_args:
    arg_1: <value>

这些参数随后会传递给数据集构造函数,因此您可以将它们与fsspec一起使用:

import fsspec


class ImageDataset(AbstractVersionedDataset):
    def __init__(
        self,
        filepath: str,
        version: Version = None,
        credentials: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
    ):
        """Creates a new instance of ImageDataset to load / save image data for given filepath.

        Args:
            filepath: The location of the image file to load / save data.
            version: The version of the dataset being saved and loaded.
            credentials: Credentials required to get access to the underlying filesystem.
                E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
            fs_args: Extra arguments to pass into underlying filesystem class.
                E.g. for ``GCSFileSystem`` class: `{"project": "my-project", ...}`.
        """
        protocol, path = get_protocol_and_path(filepath)
        self._protocol = protocol
        self._fs = fsspec.filesystem(self._protocol, **credentials, **fs_args)

    ...

我们提供了如何通过数据目录的YAML API使用参数的更多示例。如需了解如何在数据集构造函数中使用这些参数,请参阅SparkDataset的实现。

如何贡献自定义数据集实现

为Kedro做贡献最简单的方式之一就是分享自定义数据集。Kedro在kedro-plugins代码库中提供了kedro-datasets包,您可以在其中添加新的自定义数据集实现来与他人分享。更多详情请参阅GitHub上的Kedro贡献指南

贡献您的自定义数据集:

  1. 将您的数据集包添加到 kedro-plugins/kedro-datasets/kedro_datasets/

例如,在我们的ImageDataset示例中,目录结构应为:

kedro-plugins/kedro-datasets/kedro_datasets/image
├── __init__.py
└── image_dataset.py
  1. 如果数据集较为复杂,创建一个README.md文件来解释其工作原理并记录其API。

  2. 数据集应附带完整的测试覆盖,位于kedro-plugins/kedro-datasets/tests/目录下。

  3. Kedro的插件仓库main分支提交一个拉取请求。

注意

在贡献数据集时有两个特殊注意事项:

  1. 将该数据集添加到kedro_datasets.rst中,使其显示在API文档中。

  2. 将数据集添加到 kedro-plugins/kedro-datasets/static/jsonschema/kedro-catalog-X.json 以实现IDE验证。