进阶教程:创建自定义数据集¶
Kedro supports many datasets 开箱即用地支持多种数据集,但您可能会发现需要创建自定义数据集。例如,您可能需要在流水线中处理专有数据格式或文件系统,或者您可能发现了Kedro尚未支持的特定数据集用例。本教程将解释如何创建自定义数据集来读取和保存图像数据。
AbstractDataset¶
如果您是贡献者并希望提交新的数据集,必须扩展AbstractDataset接口;若计划支持版本控制,则需扩展AbstractVersionedDataset接口。该接口要求子类实现load和save方法,同时提供封装器来增强对应方法的统一错误处理能力。此外还要求子类重写_describe方法,该方法用于记录自定义AbstractDataset实现实例的内部信息。
场景¶
在本示例中,我们使用Kaggle上的宝可梦图像与属性数据集来训练一个模型,使其能够根据外观对给定宝可梦的属性进行分类(例如水系、火系、虫系等)。为训练模型,我们先将PNG格式的宝可梦图像读取为numpy数组,以便在Kedro流水线中进行后续处理。为开箱即用地处理PNG图像,本示例创建了ImageDataset来读取和保存图像数据。
项目设置¶
我们假设您已经安装了Kedro。现在创建一个项目(您可以随意命名项目,但这里我们假设项目仓库名称为kedro-pokemon)。
登录您的Kaggle账户以下载宝可梦数据集,并将其解压到data/01_raw目录下,放在名为pokemon-images-and-types的子文件夹中。该数据集包含一个pokemon.csv文件和一个存放图片的子文件夹。
该数据集将使用Pillow进行通用图像处理功能,以确保它能处理多种不同的图像格式,而不仅仅是PNG。
安装Pillow:
pip install Pillow
如果在安装过程中遇到问题,请查阅Pillow文档。
数据集的结构剖析¶
至少,一个有效的Kedro数据集需要继承基础类AbstractDataset,并为以下抽象方法提供实现:
loadsave_describe
AbstractDataset 泛型类定义了保存数据的输入数据类型和加载数据的输出数据类型。不过这种类型标注是可选的,默认会使用 Any 类型。
AbstractDataset中的_EPHEMERAL布尔属性表示数据集是否为持久化的。例如,对于非持久化的MemoryDataset,该属性被设置为True。默认情况下,_EPHEMERAL被设置为False。
注意
为了遵循Kedro规范,在自定义数据集类的构造函数中,用于指定数据文件/文件夹位置的参数必须命名为filename、filepath或path。
以下是ImageDataset的示例框架:
Click to expand
from typing import Any, Dict
import numpy as np
from kedro.io import AbstractDataset
class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
"""``ImageDataset`` loads / save image data from a given filepath as `numpy` array using Pillow.
Example:
::
>>> ImageDataset(filepath='/img/file/path.png')
"""
def __init__(self, filepath: str):
"""Creates a new instance of ImageDataset to load / save image data at the given filepath.
Args:
filepath: The location of the image file to load / save data.
"""
self._filepath = filepath
def load(self) -> np.ndarray:
"""Loads data from the image file.
Returns:
Data from the image file as a numpy array.
"""
...
def save(self, data: np.ndarray) -> None:
"""Saves image data to the specified filepath"""
...
def _describe(self) -> Dict[str, Any]:
"""Returns a dict that describes the attributes of the dataset"""
...
在src/kedro_pokemon/目录下创建一个名为datasets的子文件夹,用于存储数据集定义文件image_dataset.py,同时添加__init__.py文件使Python将该目录视为可导入的包:
src/kedro_pokemon/datasets
├── __init__.py
└── image_dataset.py
使用fsspec实现load方法¶
许多内置的Kedro数据集依赖fsspec作为统一接口来访问不同数据源,正如前文在数据目录章节所述。在本示例中,结合使用fsspec和Pillow读取图像数据特别方便,因为这使得数据集能够灵活处理不同位置和格式的图像。
以下是使用fsspec和Pillow将单张图像数据读取到numpy数组中的load方法实现:
Click to expand
from pathlib import PurePosixPath
from typing import Any, Dict
import fsspec
import numpy as np
from PIL import Image
from kedro.io import AbstractDataset
from kedro.io.core import get_filepath_str, get_protocol_and_path
class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
def __init__(self, filepath: str):
"""Creates a new instance of ImageDataset to load / save image data for given filepath.
Args:
filepath: The location of the image file to load / save data.
"""
# parse the path and protocol (e.g. file, http, s3, etc.)
protocol, path = get_protocol_and_path(filepath)
self._protocol = protocol
self._filepath = PurePosixPath(path)
self._fs = fsspec.filesystem(self._protocol)
def load(self) -> np.ndarray:
"""Loads data from the image file.
Returns:
Data from the image file as a numpy array
"""
# using get_filepath_str ensures that the protocol and path are appended correctly for different filesystems
load_path = get_filepath_str(self._filepath, self._protocol)
with self._fs.open(load_path) as f:
image = Image.open(f).convert("RGBA")
return np.asarray(image)
...
为了测试这一点,让我们在数据目录中添加一个数据集来加载皮卡丘的图像。
# in conf/base/catalog.yml
pikachu:
type: kedro_pokemon.datasets.image_dataset.ImageDataset
filepath: data/01_raw/pokemon-images-and-types/images/images/pikachu.png
# Note: the duplicated `images` path is part of the original Kaggle dataset
然后通过kedro ipython启动一个IPython会话来预览数据:
# read data image into a numpy array
In [1]: image = context.catalog.load('pikachu')
# then re-show the image using Pillow's Image API.
In [2]: from PIL import Image
In [3]: Image.fromarray(image).show()
使用fsspec实现save方法¶
同样地,我们可以按如下方式实现_save方法:
class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
def save(self, data: np.ndarray) -> None:
"""Saves image data to the specified filepath."""
# using get_filepath_str ensures that the protocol and path are appended correctly for different filesystems
save_path = get_filepath_str(self._filepath, self._protocol)
with self._fs.open(save_path, "wb") as f:
image = Image.fromarray(data)
image.save(f)
让我们在IPython中试试看:
In [1]: image = context.catalog.load('pikachu')
In [2]: context.catalog.save('pikachu', data=image)
您可以打开文件以验证数据是否正确回写。
实现 _describe 方法¶
_describe 方法用于打印输出。在Kedro中的惯例是让该方法返回一个描述数据集属性的字典。
class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
def _describe(self) -> Dict[str, Any]:
"""Returns a dict that describes the attributes of the dataset."""
return dict(filepath=self._filepath, protocol=self._protocol)
完整示例¶
这是我们基础ImageDataset的完整实现:
Click to expand
from pathlib import PurePosixPath
from typing import Any, Dict
import fsspec
import numpy as np
from PIL import Image
from kedro.io import AbstractDataset
from kedro.io.core import get_filepath_str, get_protocol_and_path
class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
"""``ImageDataset`` loads / save image data from a given filepath as `numpy` array using Pillow.
Example:
::
>>> ImageDataset(filepath='/img/file/path.png')
"""
def __init__(self, filepath: str):
"""Creates a new instance of ImageDataset to load / save image data for given filepath.
Args:
filepath: The location of the image file to load / save data.
"""
protocol, path = get_protocol_and_path(filepath)
self._protocol = protocol
self._filepath = PurePosixPath(path)
self._fs = fsspec.filesystem(self._protocol)
def load(self) -> np.ndarray:
"""Loads data from the image file.
Returns:
Data from the image file as a numpy array
"""
load_path = get_filepath_str(self._filepath, self._protocol)
with self._fs.open(load_path, mode="r") as f:
image = Image.open(f).convert("RGBA")
return np.asarray(image)
def save(self, data: np.ndarray) -> None:
"""Saves image data to the specified filepath."""
save_path = get_filepath_str(self._filepath, self._protocol)
with self._fs.open(save_path, mode="wb") as f:
image = Image.fromarray(data)
image.save(f)
def _describe(self) -> Dict[str, Any]:
"""Returns a dict that describes the attributes of the dataset."""
return dict(filepath=self._filepath, protocol=self._protocol)
与 PartitionedDataset¶ 的集成
目前,ImageDataset仅支持处理单张图像,但本示例需要从原始数据目录加载所有宝可梦图像以进行后续处理。
Kedro的PartitionedDataset提供了一种便捷的方式,可以将同一基础数据集类型的多个独立数据文件加载到目录中。
要使用PartitionedDataset与ImageDataset加载所有宝可梦PNG图片,请将此添加到数据目录YAML中,以便PartitionedDataset使用ImageDataset从数据目录加载所有PNG文件:
# in conf/base/catalog.yml
pokemon:
type: partitions.PartitionedDataset
dataset: kedro_pokemon.datasets.image_dataset.ImageDataset
path: data/01_raw/pokemon-images-and-types/images/images
filename_suffix: ".png"
让我们在IPython控制台中尝试一下:
In [1]: images = context.catalog.load('pokemon')
In [2]: len(images)
Out[2]: 721
验证数据目录中.png文件的数量(应为721):
$ ls -la data/01_raw/pokemon-images-and-types/images/images/*.png | wc -l
721
版本控制¶
如何在你的数据集中实现版本控制¶
注意
版本控制不适用于PartitionedDataset。您不能同时使用两者。
要为新数据集添加版本控制支持,我们需要扩展AbstractVersionedDataset来实现:
在构造函数中接受一个
version关键字参数调整
load和save方法,以使用分别从_get_load_path和_get_save_path获取的版本化数据路径
以下是对我们基础ImageDataset完整实现的修改。现在它会在版本化的子文件夹中加载和保存数据(默认情况下,version是一个日期时间格式的字符串YYYY-MM-DDThh.mm.ss.sssZ,路径格式为data/01_raw/pokemon-images-and-types/images/images/pikachu.png/):
Click to expand
from pathlib import PurePosixPath
from typing import Any, Dict
import fsspec
import numpy as np
from PIL import Image
from kedro.io import AbstractVersionedDataset
from kedro.io.core import get_filepath_str, get_protocol_and_path, Version
class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]):
"""``ImageDataset`` loads / save image data from a given filepath as `numpy` array using Pillow.
Example:
::
>>> ImageDataset(filepath='/img/file/path.png')
"""
def __init__(self, filepath: str, version: Version = None):
"""Creates a new instance of ImageDataset to load / save image data for given filepath.
Args:
filepath: The location of the image file to load / save data.
version: The version of the dataset being saved and loaded.
"""
protocol, path = get_protocol_and_path(filepath)
self._protocol = protocol
self._fs = fsspec.filesystem(self._protocol)
super().__init__(
filepath=PurePosixPath(path),
version=version,
exists_function=self._fs.exists,
glob_function=self._fs.glob,
)
def load(self) -> np.ndarray:
"""Loads data from the image file.
Returns:
Data from the image file as a numpy array
"""
load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, mode="r") as f:
image = Image.open(f).convert("RGBA")
return np.asarray(image)
def save(self, data: np.ndarray) -> None:
"""Saves image data to the specified filepath."""
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, mode="wb") as f:
image = Image.fromarray(data)
image.save(f)
def _describe(self) -> Dict[str, Any]:
"""Returns a dict that describes the attributes of the dataset."""
return dict(
filepath=self._filepath, version=self._version, protocol=self._protocol
)
原始ImageDataset与版本化ImageDataset的区别如下:
Click to expand
from pathlib import PurePosixPath
from typing import Any, Dict
import fsspec
import numpy as np
from PIL import Image
-from kedro.io import AbstractDataset
-from kedro.io.core import get_filepath_str, get_protocol_and_path
+from kedro.io import AbstractVersionedDataset
+from kedro.io.core import get_filepath_str, get_protocol_and_path, Version
-class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]):
+class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]):
"""``ImageDataset`` loads / save image data from a given filepath as `numpy` array using Pillow.
Example:
::
>>> ImageDataset(filepath='/img/file/path.png')
"""
- def __init__(self, filepath: str):
+ def __init__(self, filepath: str, version: Version = None):
"""Creates a new instance of ImageDataset to load / save image data for given filepath.
Args:
filepath: The location of the image file to load / save data.
+ version: The version of the dataset being saved and loaded.
"""
protocol, path = get_protocol_and_path(filepath)
self._protocol = protocol
- self._filepath = PurePosixPath(path)
self._fs = fsspec.filesystem(self._protocol)
+ super().__init__(
+ filepath=PurePosixPath(path),
+ version=version,
+ exists_function=self._fs.exists,
+ glob_function=self._fs.glob,
+ )
+
def load(self) -> np.ndarray:
"""Loads data from the image file.
Returns:
Data from the image file as a numpy array
"""
- load_path = get_filepath_str(self._filepath, self._protocol)
+ load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, mode="r") as f:
image = Image.open(f).convert("RGBA")
return np.asarray(image)
def save(self, data: np.ndarray) -> None:
"""Saves image data to the specified filepath."""
- save_path = get_filepath_str(self._filepath, self._protocol)
+ save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, mode="wb") as f:
image = Image.fromarray(data)
image.save(f)
def _describe(self) -> Dict[str, Any]:
"""Returns a dict that describes the attributes of the dataset."""
- return dict(filepath=self._filepath, protocol=self._protocol)
+ return dict(
+ filepath=self._filepath, version=self._version, protocol=self._protocol
+ )
要测试代码,您需要在数据目录中启用版本控制支持:
# in conf/base/catalog.yml
pikachu:
type: kedro_pokemon.datasets.image_dataset.ImageDataset
filepath: data/01_raw/pokemon-images-and-types/images/images/pikachu.png
versioned: true
注意
不支持使用基于HTTP(S)的filepath并设置versioned: true。
通过创建示例的第一个版本(例如2020-02-22T00.00.00.000Z)来生成数据的初始版本:
$ mv data/01_raw/pokemon-images-and-types/images/images/pikachu.png data/01_raw/pokemon-images-and-types/images/images/pikachu.png.backup
$ mkdir -p data/01_raw/pokemon-images-and-types/images/images/pikachu.png/2020-02-22T00.00.00.000Z/
$ mv data/01_raw/pokemon-images-and-types/images/images/pikachu.png.backup data/01_raw/pokemon-images-and-types/images/images/pikachu.png/2020-02-22T00.00.00.000Z/pikachu.png
目录结构应如下所示:
data/01_raw/pokemon-images-and-types/images/images/pikachu.png
└── 2020-02-22T00.00.00.000Z/
└── pikachu.png
启动一个IPython shell来测试版本化数据的加载/保存:
# loading works as Kedro automatically find the latest available version inside `pikachu.png` directory
In [1]: img = context.catalog.load('pikachu')
# then saving it should work as well
In [2]: context.catalog.save('pikachu', data=img)
检查数据目录的内容,找到由save写入的新版本数据。
线程安全性¶
Kedro数据集应能与SequentialRunner和ParallelRunner协同工作,因此它们必须能被Python multiprocessing包完全序列化。这意味着您的数据集不应使用lambda函数、嵌套函数、闭包等。如果使用自定义装饰器,需要确保它们使用了functools.wraps()。
有一个数据集是例外:SparkDataset。这个例外的原因是Apache Spark使用自己的并行机制,因此无法与Kedro的ParallelRunner配合使用。在使用Spark的Kedro项目中实现并行,请改用ThreadRunner。
要验证你的数据集是否可以被multiprocessing序列化,可以使用控制台或IPython会话尝试使用multiprocessing.reduction.ForkingPickler进行转储:
dataset = context.catalog._datasets["pokemon"]
from multiprocessing.reduction import ForkingPickler
# the following call shouldn't throw any errors
ForkingPickler.dumps(dataset)
如何处理凭证和不同的文件系统¶
如果您的使用场景需要,Kedro允许您向数据集传递credentials和文件系统特定的fs_args参数。例如,如果宝可梦数据存放在S3存储桶中,我们可以按如下方式将credentials和fs_args添加到数据目录中:
# in conf/base/catalog.yml
pikachu:
type: kedro_pokemon.datasets.image_dataset.ImageDataset
filepath: s3://data/01_raw/pokemon-images-and-types/images/images/pikachu.png
credentials: <your_credentials>
fs_args:
arg_1: <value>
这些参数随后会传递给数据集构造函数,因此您可以将它们与fsspec一起使用:
import fsspec
class ImageDataset(AbstractVersionedDataset):
def __init__(
self,
filepath: str,
version: Version = None,
credentials: Dict[str, Any] = None,
fs_args: Dict[str, Any] = None,
):
"""Creates a new instance of ImageDataset to load / save image data for given filepath.
Args:
filepath: The location of the image file to load / save data.
version: The version of the dataset being saved and loaded.
credentials: Credentials required to get access to the underlying filesystem.
E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
fs_args: Extra arguments to pass into underlying filesystem class.
E.g. for ``GCSFileSystem`` class: `{"project": "my-project", ...}`.
"""
protocol, path = get_protocol_and_path(filepath)
self._protocol = protocol
self._fs = fsspec.filesystem(self._protocol, **credentials, **fs_args)
...
我们提供了如何通过数据目录的YAML API使用参数的更多示例。如需了解如何在数据集构造函数中使用这些参数,请参阅SparkDataset的实现。
如何贡献自定义数据集实现¶
为Kedro做贡献最简单的方式之一就是分享自定义数据集。Kedro在kedro-plugins代码库中提供了kedro-datasets包,您可以在其中添加新的自定义数据集实现来与他人分享。更多详情请参阅GitHub上的Kedro贡献指南。
贡献您的自定义数据集:
将您的数据集包添加到
kedro-plugins/kedro-datasets/kedro_datasets/。
例如,在我们的ImageDataset示例中,目录结构应为:
kedro-plugins/kedro-datasets/kedro_datasets/image
├── __init__.py
└── image_dataset.py
如果数据集较为复杂,创建一个
README.md文件来解释其工作原理并记录其API。数据集应附带完整的测试覆盖,位于
kedro-plugins/kedro-datasets/tests/目录下。向Kedro的插件仓库的
main分支提交一个拉取请求。
注意
在贡献数据集时有两个特殊注意事项:
将该数据集添加到
kedro_datasets.rst中,使其显示在API文档中。将数据集添加到
kedro-plugins/kedro-datasets/static/jsonschema/kedro-catalog-X.json以实现IDE验证。