PyTorch 数据集/加载器

%load_ext autoreload
%autoreload 2

时间序列的 Torch 数据集

from fastcore.test import test_eq
from nbdev.showdoc import show_doc
from neuralforecast.utils import generate_series
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import List, Optional, Sequence, Union

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import utilsforecast.processing as ufp
from torch.utils.data import Dataset, DataLoader
from utilsforecast.compat import DataFrame, pl_Series
class TimeSeriesLoader(DataLoader):
    """TimeSeriesLoader DataLoader.
    [Source code](https://github.com/Nixtla/neuralforecast1/blob/main/neuralforecast/tsdataset.py).

    Small change to PyTorch's Data loader.
    Combines a dataset and a sampler, and provides an iterable over the given dataset.

    The class `~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    Any `collate_fn` passed by the caller is discarded: batches are always built
    by `_collate_fn`, which understands the dict items produced by the
    time-series datasets.

    **Parameters:**<br>
    `batch_size`: (int, optional): how many samples per batch to load (default: 1).<br>
    `shuffle`: (bool, optional): set to `True` to have the data reshuffled at every epoch (default: `False`).<br>
    `sampler`: (Sampler or Iterable, optional): defines the strategy to draw samples from the dataset.<br>
                Can be any `Iterable` with `__len__` implemented. If specified, `shuffle` must not be specified.<br>
    """
    def __init__(self, dataset, **kwargs):
        # The custom collate function is mandatory, so drop any caller override.
        kwargs.pop('collate_fn', None)
        kwargs_ = {**kwargs, **dict(collate_fn=self._collate_fn)}
        DataLoader.__init__(self, dataset=dataset, **kwargs_)

    def _collate_fn(self, batch):
        """Collate a list of samples into one batch.

        Tensors are stacked along a new leading dimension. Mapping samples have
        their 'temporal' (and, when not None, 'static') entries collated
        recursively; the column metadata and 'y_idx' are taken from the first
        sample.

        Raises:
            TypeError: if the samples are neither tensors nor mappings.
        """
        elem = batch[0]
        elem_type = type(elem)

        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy.
                numel = sum(x.numel() for x in batch)
                # `Tensor.storage()` is deprecated (TypedStorage) and warns on every
                # call; use the accessor PyTorch's own default_collate relies on,
                # falling back to `storage()` on releases that lack it.
                get_storage = getattr(elem, '_typed_storage', elem.storage)
                storage = get_storage()._new_shared(numel, device=elem.device)
                out = elem.new(storage).resize_(len(batch), *list(elem.size()))
            return torch.stack(batch, 0, out=out)

        elif isinstance(elem, Mapping):
            if elem['static'] is None:
                return dict(temporal=self.collate_fn([d['temporal'] for d in batch]),
                            temporal_cols = elem['temporal_cols'],
                            y_idx=elem['y_idx'])

            return dict(static=self.collate_fn([d['static'] for d in batch]),
                        static_cols = elem['static_cols'],
                        temporal=self.collate_fn([d['temporal'] for d in batch]),
                        temporal_cols = elem['temporal_cols'],
                        y_idx=elem['y_idx'])

        raise TypeError(f'Unknown {elem_type}')
# Render the API documentation for TimeSeriesLoader (nbdev `show_doc`).
show_doc(TimeSeriesLoader)
class BaseTimeSeriesDataset(Dataset):
    """Common state and helpers shared by the time-series datasets.

    Subclasses must set `n_groups` (read by `__len__`) and implement
    `__getitem__`.
    """

    def __init__(self,
                 temporal_cols,
                 max_size: int,
                 min_size: int,
                 y_idx: int,
                 static=None,
                 static_cols=None,
                 sorted=False,
                ):
        super().__init__()
        self.temporal_cols = pd.Index(list(temporal_cols))
        # Optional static features are stored as a float32 tensor copy.
        self.static = static if static is None else self._as_torch_copy(static)
        self.static_cols = static_cols
        self.max_size = max_size
        self.min_size = min_size
        self.y_idx = y_idx
        # Update flag: for consistency a dataset may only be updated once.
        self.updated = False
        self.sorted = sorted

    def __len__(self):
        """Number of series (groups) in the dataset."""
        return self.n_groups

    def _as_torch_copy(
        self,
        x: Union[np.ndarray, torch.Tensor],
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Return an independent tensor copy of `x` converted to `dtype`."""
        tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
        return tensor.to(dtype, copy=False).clone()

    @staticmethod
    def _ensure_available_mask(data: np.ndarray, temporal_cols):
        """Append an all-ones 'available_mask' column when it is missing.

        Returns the (possibly extended) data array and column index.
        """
        if 'available_mask' in temporal_cols:
            return data, temporal_cols
        ones = np.ones((len(data), 1), dtype=np.float32)
        return (np.append(data, ones, axis=1),
                temporal_cols.append(pd.Index(['available_mask'])))

    @staticmethod
    def _extract_static_features(static_df, sort_df, id_col):
        """Split `static_df` into a feature array and its column names.

        Returns `(None, None)` when `static_df` is None. Every column except
        `id_col` is treated as a static feature; rows are sorted by `id_col`
        first when `sort_df` is True.
        """
        if static_df is None:
            return None, None
        if isinstance(static_df, pd.DataFrame) and static_df.index.name == id_col:
            warnings.warn(
                "Passing the id as index is deprecated, please provide it as a column instead.",
                FutureWarning,
            )
        if sort_df:
            static_df = ufp.sort(static_df, by=id_col)
        feature_names = [name for name in static_df.columns if name != id_col]
        return ufp.to_numpy(static_df[feature_names]), pd.Index(feature_names)
class TimeSeriesDataset(BaseTimeSeriesDataset):
    """In-memory time series stored as one concatenated 2-D tensor.

    Rows of series `i` are `temporal[indptr[i]:indptr[i + 1]]`; columns follow
    `temporal_cols`.
    """

    def __init__(self,
                 temporal,
                 temporal_cols,
                 indptr,
                 max_size: int,
                 min_size: int,
                 y_idx: int,
                 static=None,
                 static_cols=None,
                 sorted=False,
                ):
        super().__init__(
                temporal_cols=temporal_cols,
                max_size=max_size,
                min_size=min_size,
                y_idx=y_idx,
                static=static,
                static_cols=static_cols,
                sorted=sorted
            )
        self.temporal = self._as_torch_copy(temporal)
        self.indptr = indptr
        self.n_groups = self.indptr.size - 1

    def __getitem__(self, idx):
        """Return series `idx` left-padded with zeros to `max_size` steps.

        The item is a dict with `temporal` of shape
        (len(temporal_cols), max_size), the column metadata, the series' static
        features (or None) and `y_idx`.

        Raises:
            ValueError: if `idx` is not an integer (Python or NumPy).
        """
        if not isinstance(idx, (int, np.integer)):
            raise ValueError(f'idx must be int, got {type(idx)}')

        ts = self.temporal[self.indptr[idx] : self.indptr[idx + 1], :]
        # Left-pad: the series occupies the last len(ts) columns, earlier ones stay 0.
        temporal = torch.zeros(size=(len(self.temporal_cols), self.max_size),
                               dtype=torch.float32)
        # Start at max_size - len(ts) rather than -len(ts) so an empty slice stays empty.
        temporal[:, int(self.max_size) - len(ts):] = ts.permute(1, 0)

        # Add static data when available.
        static = None if self.static is None else self.static[idx, :]

        return dict(temporal=temporal, temporal_cols=self.temporal_cols,
                    static=static, static_cols=self.static_cols,
                    y_idx=self.y_idx)

    def __repr__(self):
        return f'TimeSeriesDataset(n_data={self.temporal.shape[0]:,}, n_groups={self.n_groups:,})'

    def __eq__(self, other):
        """Equal when the temporal values (NaN-aware) and group offsets match."""
        # The previous version compared a nonexistent `data` attribute and was always False.
        if not hasattr(other, 'temporal') or not hasattr(other, 'indptr'):
            return False
        if tuple(self.temporal.shape) != tuple(other.temporal.shape):
            return False
        return (np.allclose(np.asarray(self.temporal), np.asarray(other.temporal), equal_nan=True)
                and np.array_equal(self.indptr, other.indptr))

    def align(self, df: DataFrame, id_col: str, time_col: str, target_col: str) -> 'TimeSeriesDataset':
        """Build a dataset from `df` with exactly this dataset's temporal columns.

        Missing temporal columns are filled with NaN, and 'available_mask'
        (when it is one of the temporal columns) is set to 1.0.
        """
        # Shallow copy so the caller's frame is not modified.
        df = ufp.copy_if_pandas(df, deep=False)

        # Add NaN columns for missing temporal columns; available_mask is always 1.0.
        temporal_cols = self.temporal_cols.copy()
        for col in temporal_cols:
            if col not in df.columns:
                df = ufp.assign_columns(df, col, np.nan)
            if col == 'available_mask':
                df = ufp.assign_columns(df, col, 1.0)

        # Order the columns to match self.temporal_cols.
        df = df[ [id_col, time_col] + temporal_cols.tolist() ]

        # Process the future frame.
        dataset, *_ = TimeSeriesDataset.from_df(
            df=df,
            sort_df=self.sorted,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
        )
        return dataset

    def append(self, futr_dataset: 'TimeSeriesDataset') -> 'TimeSeriesDataset':
        """Append future observations to every series and return a new dataset.

        Raises:
            ValueError: if `futr_dataset` has a different number of groups.
        """
        if self.indptr.size != futr_dataset.indptr.size:
            raise ValueError('Cannot append `futr_dataset` with different number of groups.')
        # Allocate the merged temporal tensor.
        len_temporal, col_temporal = self.temporal.shape
        len_futr = futr_dataset.temporal.shape[0]
        new_temporal = torch.empty(size=(len_temporal + len_futr, col_temporal))
        # Offsets are cumulative, so summing both offset arrays gives the merged offsets.
        new_indptr = self.indptr + futr_dataset.indptr
        new_sizes = np.diff(new_indptr)
        new_min_size = np.min(new_sizes)
        new_max_size = np.max(new_sizes)

        for i in range(self.n_groups):
            curr_slice = slice(self.indptr[i], self.indptr[i + 1])
            curr_size = curr_slice.stop - curr_slice.start
            futr_slice = slice(futr_dataset.indptr[i], futr_dataset.indptr[i + 1])
            # Existing rows first, then the future rows of the same series.
            new_temporal[new_indptr[i] : new_indptr[i] + curr_size] = self.temporal[curr_slice]
            new_temporal[new_indptr[i] + curr_size : new_indptr[i + 1]] = futr_dataset.temporal[futr_slice]

        return TimeSeriesDataset(
            temporal=new_temporal,
            temporal_cols=self.temporal_cols.copy(),
            indptr=new_indptr,
            max_size=new_max_size,
            min_size=new_min_size,
            static=self.static,
            y_idx=self.y_idx,
            static_cols=self.static_cols,
            sorted=self.sorted
        )

    @staticmethod
    def update_dataset(dataset, futr_df, id_col='unique_id', time_col='ds', target_col='y'):
        """Align `futr_df` to `dataset`'s columns and append it; returns a new dataset."""
        futr_dataset = dataset.align(
            futr_df, id_col=id_col, time_col=time_col, target_col=target_col
        )
        return dataset.append(futr_dataset)

    @staticmethod
    def trim_dataset(dataset, left_trim: int = 0, right_trim: int = 0):
        """Remove time steps from both ends of every series.

        Keeps rows [left_trim : len - right_trim] of each series and returns a
        new dataset.

        Raises:
            Exception: if left_trim + right_trim is not smaller than the shortest series.
        """
        if dataset.min_size <= left_trim + right_trim:
            raise Exception(
                f'left_trim + right_trim ({left_trim} + {right_trim}) '
                f'must be lower than the shorter time series ({dataset.min_size})'
            )

        # Allocate and fill the trimmed temporal tensor.
        len_temporal, col_temporal = dataset.temporal.shape
        total_trim = (left_trim + right_trim) * dataset.n_groups
        new_temporal = torch.zeros(size=(len_temporal-total_trim, col_temporal))
        new_indptr = [0]

        acum = 0
        for i in range(dataset.n_groups):
            series_length = dataset.indptr[i + 1] - dataset.indptr[i]
            new_length = series_length - left_trim - right_trim
            new_temporal[acum:(acum+new_length), :] = dataset.temporal[dataset.indptr[i]+left_trim : \
                                                                       dataset.indptr[i + 1]-right_trim, :]
            acum += new_length
            new_indptr.append(acum)

        new_max_size = dataset.max_size-left_trim-right_trim
        new_min_size = dataset.min_size-left_trim-right_trim

        updated_dataset = TimeSeriesDataset(temporal=new_temporal,
                                            temporal_cols= dataset.temporal_cols.copy(),
                                            indptr=np.array(new_indptr, dtype=np.int32),
                                            max_size=new_max_size,
                                            min_size=new_min_size,
                                            y_idx=dataset.y_idx,
                                            static=dataset.static,
                                            static_cols=dataset.static_cols,
                                            sorted=dataset.sorted)

        return updated_dataset

    @staticmethod
    def from_df(df, static_df=None, sort_df=False, id_col='unique_id', time_col='ds', target_col='y'):
        """Build a dataset from a long-format pandas or polars frame.

        Returns:
            (dataset, indices, dates, ds): the dataset; the series ids; the
            per-series times returned by `ufp.process_df` (as a pandas Index or
            polars Series, matching `df`); and the full time column reordered
            to match the dataset rows.
        """
        # TODO: protect equality of the static_df and df indices.
        if isinstance(df, pd.DataFrame) and df.index.name == id_col:
            warnings.warn(
                "Passing the id as index is deprecated, please provide it as a column instead.",
                FutureWarning,
            )
            df = df.reset_index(id_col)

        static, static_cols = TimeSeriesDataset._extract_static_features(static_df, sort_df, id_col)

        ids, times, data, indptr, sort_idxs = ufp.process_df(df, id_col, time_col, target_col)
        # The processor puts y as the first column.
        temporal_cols = pd.Index(
            [target_col] + [c for c in df.columns if c not in (id_col, time_col, target_col)]
        )
        temporal = data.astype(np.float32, copy=False)
        indices = ids
        if isinstance(df, pd.DataFrame):
            dates = pd.Index(times, name=time_col)
        else:
            dates = pl_Series(time_col, times)
        sizes = np.diff(indptr)
        max_size = sizes.max()
        min_size = sizes.min()

        # Add the available mask without adding a column to df. Pass the float32
        # array (previously the cast was computed and then discarded).
        temporal, temporal_cols = TimeSeriesDataset._ensure_available_mask(temporal, temporal_cols)

        dataset = TimeSeriesDataset(
            temporal=temporal,
            temporal_cols=temporal_cols,
            static=static,
            static_cols=static_cols,
            indptr=indptr,
            max_size=max_size,
            min_size=min_size,
            sorted=sort_df,
            y_idx=0,
        )
        ds = df[time_col].to_numpy()
        if sort_idxs is not None:
            ds = ds[sort_idxs]
        return dataset, indices, dates, ds
class _FilesDataset:
    def __init__(
        self,
        files: Sequence[str],
        temporal_cols,
        id_col: str,
        time_col: str,
        target_col: str,
        min_size: int,
        static_cols: Optional[List[str]] = None,
    ):
        self.files = files
        self.temporal_cols = pd.Index(temporal_cols)
        self.static_cols = pd.Index(static_cols) if static_cols is not None else None
        self.id_col = id_col
        self.time_col = time_col
        self.target_col = target_col
        self.min_size = min_size
class LocalFilesTimeSeriesDataset(BaseTimeSeriesDataset):
    """Dataset whose series live in per-series directories of parquet files.

    Only parquet metadata is read when the dataset is built
    (`from_data_directories`); each series is loaded from disk in `__getitem__`.
    """

    def __init__(self,
                 files_ds: List[str],
                 temporal_cols,
                 id_col: str,
                 time_col: str,
                 target_col: str,
                 last_times,
                 indices,
                 max_size: int,
                 min_size: int,
                 y_idx: int,
                 static=None,
                 static_cols=None,
                 sorted=False,
                ):
        super().__init__(
                temporal_cols=temporal_cols,
                max_size=max_size,
                min_size=min_size,
                y_idx=y_idx,
                static=static,
                static_cols=static_cols,
                sorted=sorted
            )
        self.files_ds = files_ds
        self.id_col = id_col
        self.time_col = time_col
        self.target_col = target_col
        # Last timestamp of each series.
        self.last_times = last_times
        self.indices = indices
        self.n_groups = len(files_ds)

    def __getitem__(self, idx):
        """Load series `idx` from disk and return it left-padded to `max_size`.

        Raises:
            ValueError: if `idx` is not an integer (Python or NumPy).
        """
        if not isinstance(idx, (int, np.integer)):
            raise ValueError(f'idx must be int, got {type(idx)}')

        temporal_cols = self.temporal_cols.copy()
        data = pd.read_parquet(self.files_ds[idx], columns=temporal_cols.tolist()).to_numpy()
        # Files without an available_mask column get an all-ones mask appended.
        data, temporal_cols = self._ensure_available_mask(data, temporal_cols)
        data = self._as_torch_copy(data)

        # Left-pad the temporal data; start at max_size - len(data) so an empty slice stays empty.
        temporal = torch.zeros(size=(len(temporal_cols), self.max_size),
                               dtype=torch.float32)
        temporal[:, int(self.max_size) - len(data):] = data.permute(1, 0)

        # Add static data when available.
        static = None if self.static is None else self.static[idx, :]

        return dict(temporal=temporal, temporal_cols=temporal_cols,
                    static=static, static_cols=self.static_cols,
                    y_idx=self.y_idx)

    @staticmethod
    def from_data_directories(directories, static_df=None, sort_df=False, exogs=None, id_col='unique_id', time_col='ds', target_col='y'):
        """Build a dataset from directories of parquet files using only their metadata.

        We expect `directories` to be a list of directories of the form
        [unique_id=id_0, unique_id=id_1, ...]. Each directory should contain the
        time series for that unique_id, stored as parquet (e.g. written from a
        pandas or polars DataFrame). A series can be entirely contained in one
        parquet file or split between several, but within each file the rows
        should be sorted by time. `static_df` should be a pandas or polars DataFrame.

        Parameters:
            exogs: names of exogenous temporal columns expected in every file
                (defaults to none).

        Raises:
            ValueError: if a path is not a directory, a directory contains no
                parquet files, an expected temporal column is missing from a
                file, or `available_mask` is present in only some files.
        """
        # Import the submodule explicitly: `import pyarrow` alone does not
        # guarantee that `pyarrow.parquet` is loaded, so `pa.parquet` could fail.
        import pyarrow.parquet as pq

        # Avoid a shared mutable default; materialize so it can be iterated twice.
        exogs = [] if exogs is None else list(exogs)

        # NOTE(review): with sort_df=True the static rows are sorted by id, but
        # `directories` keep the caller's order — confirm both are aligned.
        static, static_cols = BaseTimeSeriesDataset._extract_static_features(static_df, sort_df, id_col)

        max_size = 0
        min_size = float('inf')
        last_times = []
        ids = []
        expected_temporal = {target_col, *exogs}
        # True while every file seen so far carries an available_mask column.
        available_mask_seen = True

        for directory in directories:
            dir_path = Path(directory)
            if not dir_path.is_dir():
                raise ValueError(f'paths must be directories, {directory} is not.')
            # Directory names look like "unique_id=<id>"; keep the part after '='.
            uid = dir_path.name.split('=')[-1]
            parquet_files = list(dir_path.glob('*.parquet'))
            if not parquet_files:
                raise ValueError(f'No parquet files found in directory: {directory}.')
            total_rows = 0
            last_time = None
            for file in parquet_files:
                meta = pq.read_metadata(file)
                rg = meta.row_group(0)
                col2pos = {rg.column(i).path_in_schema: i for i in range(rg.num_columns)}

                # Latest time: max statistic of the time column in the last row group.
                last_time_file = meta.row_group(meta.num_row_groups -1).column(col2pos[time_col]).statistics.max
                last_time = max(last_time, last_time_file) if last_time is not None else last_time_file
                total_rows += sum(meta.row_group(i).num_rows for i in range(meta.num_row_groups))

                # Check that all expected temporal columns are present.
                missing_cols = expected_temporal - col2pos.keys()
                if missing_cols:
                    raise ValueError(f"Temporal columns: {missing_cols} not found in the file: {file}.")

                if 'available_mask' not in col2pos:
                    available_mask_seen = False
                elif not available_mask_seen:
                    # available_mask is in this file but was missing from an earlier one.
                    raise ValueError("The available_mask column is present in some files but is missing in others.")
                else:
                    # From now on every file must also provide available_mask.
                    expected_temporal.add("available_mask")

            max_size = max(total_rows, max_size)
            min_size = min(total_rows, min_size)
            ids.append(uid)
            last_times.append(last_time)

        last_times = pd.Index(last_times, name=time_col)
        ids = pd.Series(ids, name=id_col)

        if "available_mask" in expected_temporal:
            exogs = ["available_mask", *exogs]
        temporal_cols = pd.Index([target_col, *exogs])

        dataset = LocalFilesTimeSeriesDataset(
            files_ds=directories,
            temporal_cols=temporal_cols,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            last_times=last_times,
            indices=ids,
            min_size=min_size,
            max_size=max_size,
            y_idx=0,
            static=static,
            static_cols=static_cols,
            sorted=sort_df
        )
        return dataset
# Render the API documentation for TimeSeriesDataset (nbdev `show_doc`).
show_doc(TimeSeriesDataset)

# Test sort_df=True: building from a shuffled frame must match the sorted frame.
temporal_df = generate_series(n_series=1000, 
                         n_temporal_features=0, equal_ends=False)
sorted_temporal_df = temporal_df.sort_values(['unique_id', 'ds'])
unsorted_temporal_df = sorted_temporal_df.sample(frac=1.0)
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=unsorted_temporal_df,
                                                        sort_df=True)

# Skip the last temporal column: it is the available_mask appended by from_df.
np.testing.assert_allclose(dataset.temporal[:,:-1], 
                           sorted_temporal_df.drop(columns=['unique_id', 'ds']).values)
test_eq(indices, pd.Series(sorted_temporal_df['unique_id'].unique()))
# `dates` is expected to hold each series' last timestamp.
test_eq(dates, temporal_df.groupby('unique_id')['ds'].max().values)
class TimeSeriesDataModule(pl.LightningDataModule):
    """Lightning data module serving one time-series dataset.

    The train, validation and predict loaders all read `self.dataset`; they
    differ only in batch size, shuffling and drop_last.
    """

    def __init__(
            self, 
            dataset: BaseTimeSeriesDataset,
            batch_size=32, 
            valid_batch_size=1024,
            num_workers=0,
            drop_last=False,
            shuffle_train=True,
        ):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.valid_batch_size = valid_batch_size
        self.num_workers = num_workers
        self.drop_last = drop_last
        self.shuffle_train = shuffle_train

    def _make_loader(self, batch_size, shuffle, **extra):
        # Shared factory: a TimeSeriesLoader over self.dataset with the configured workers.
        return TimeSeriesLoader(
            self.dataset,
            batch_size=batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            **extra,
        )

    def train_dataloader(self):
        return self._make_loader(self.batch_size, self.shuffle_train,
                                 drop_last=self.drop_last)

    def val_dataloader(self):
        # NOTE(review): validation also honours drop_last, so the final partial
        # batch of series is skipped when drop_last=True — confirm intended.
        return self._make_loader(self.valid_batch_size, False,
                                 drop_last=self.drop_last)

    def predict_dataloader(self):
        # No drop_last here: prediction keeps every series.
        return self._make_loader(self.valid_batch_size, False)
# Render the API documentation for TimeSeriesDataModule (nbdev `show_doc`).
show_doc(TimeSeriesDataModule)

# Train batches are (batch_size, n_temporal_cols, 500) tensors.
# NOTE(review): 500 is presumably the longest generated series (max_size) —
# confirm against generate_series' defaults.
batch_size = 128
data = TimeSeriesDataModule(dataset=dataset, 
                            batch_size=batch_size, drop_last=True)
for batch in data.train_dataloader():
    test_eq(batch['temporal'].shape, (batch_size, 2, 500))
    test_eq(batch['temporal_cols'], ['y', 'available_mask'])

# Same check with static and temporal exogenous features.
batch_size = 128
n_static_features = 2
n_temporal_features = 4
temporal_df, static_df = generate_series(n_series=1000,
                                         n_static_features=n_static_features,
                                         n_temporal_features=n_temporal_features, 
                                         equal_ends=False)

dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df,
                                                        static_df=static_df,
                                                        sort_df=True)
data = TimeSeriesDataModule(dataset=dataset,
                            batch_size=batch_size, drop_last=True)

for batch in data.train_dataloader():
    # +2 temporal columns: the target 'y' and the appended 'available_mask'.
    test_eq(batch['temporal'].shape, (batch_size, n_temporal_features + 2, 500))
    test_eq(batch['temporal_cols'],
            ['y'] + [f'temporal_{i}' for i in range(n_temporal_features)] + ['available_mask'])
    
    test_eq(batch['static'].shape, (batch_size, n_static_features))
    test_eq(batch['static_cols'], [f'static_{i}' for i in range(n_static_features)])
# hide

# Fixtures: keep the last 10 rows of two series, blank 'y' and 'temporal_0'
# after 2001-05-11, and split the frame at that date (split1_df / split2_df).
temporal_df = generate_series(n_series=2,
                              n_temporal_features=2, equal_ends=True)
temporal_df = temporal_df.groupby('unique_id').tail(10)
temporal_df = temporal_df.reset_index()
temporal_full_df = temporal_df.sort_values(['unique_id', 'ds']).reset_index(drop=True)
temporal_full_df.loc[temporal_full_df.ds > '2001-05-11', ['y', 'temporal_0']] = None

split1_df = temporal_full_df.loc[temporal_full_df.ds <= '2001-05-11']
split2_df = temporal_full_df.loc[temporal_full_df.ds > '2001-05-11']

# Test a user-provided available_mask column
temporal_df_w_mask = temporal_df.copy()
temporal_df_w_mask['available_mask'] = 1

# Mask with all 1's
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,
                                                        sort_df=True)
mask_average = dataset.temporal[:, -1].mean()
np.testing.assert_almost_equal(mask_average, 1.0000)

# Add 0's to available mask (the assertion expects 30% of rows to be masked)
temporal_df_w_mask.loc[temporal_df_w_mask.ds > '2001-05-11', 'available_mask'] = 0
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,
                                                        sort_df=True)
mask_average = dataset.temporal[:, -1].mean()
np.testing.assert_almost_equal(mask_average, 0.7000)

# available_mask not in the last column: it must be kept at its position (index 1)
temporal_df_w_mask = temporal_df_w_mask[['unique_id','ds','y','available_mask', 'temporal_0','temporal_1']]
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,
                                                        sort_df=True)
mask_average = dataset.temporal[:, 1].mean()
np.testing.assert_almost_equal(mask_average, 0.7000)
# Test that `update_dataset` handles a future frame correctly: a dataset built
# from the full frame must match one built from split1_df and then extended
# with split2_df.

# Full dataset
dataset_full, indices_full, dates_full, ds_full = TimeSeriesDataset.from_df(df=temporal_full_df,
                                                                            sort_df=False)

# SPLIT_1 dataset, then extended with the rows of split2_df
dataset_1, indices_1, dates_1, ds_1 = TimeSeriesDataset.from_df(df=split1_df,
                                                                sort_df=False)
dataset_1 = dataset_1.update_dataset(dataset_1, split2_df)

np.testing.assert_almost_equal(dataset_full.temporal.numpy(), dataset_1.temporal.numpy())
test_eq(dataset_full.max_size, dataset_1.max_size)
test_eq(dataset_full.indptr, dataset_1.indptr)

# Test trim_dataset
n_static_features = 0
n_temporal_features = 2
temporal_df = generate_series(n_series=100,
                              min_length=50,
                              max_length=100,
                              n_static_features=n_static_features,
                              n_temporal_features=n_temporal_features, 
                              equal_ends=False)
# NOTE(review): static_df is the 1000-series frame generated earlier, not one
# matching these 100 series; this trim check does not inspect it.
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df,
                                                        static_df=static_df,
                                                        sort_df=True)
left_trim = 10
right_trim = 20
dataset_trimmed = dataset.trim_dataset(dataset, left_trim=left_trim, right_trim=right_trim)

# Series 50 trimmed by hand must equal series 50 of the trimmed dataset.
np.testing.assert_almost_equal(dataset.temporal[dataset.indptr[50]+left_trim:dataset.indptr[51]-right_trim].numpy(),
                               dataset_trimmed.temporal[dataset_trimmed.indptr[50]:dataset_trimmed.indptr[51]].numpy())
#| polars
import polars
#| polars
# Polars parity test: the same data as a shuffled polars frame (categoricals
# replaced by their integer codes) must produce an identical dataset.
temporal_df2 = temporal_df.copy()
for col in ('unique_id', 'temporal_0', 'temporal_1'):
    temporal_df2[col] = temporal_df2[col].cat.codes
temporal_pl = polars.from_pandas(temporal_df2).sample(fraction=1.0)
# NOTE(review): static_pl is built but the call below still passes the pandas
# static_df — confirm whether the polars static path should be exercised here.
static_pl = polars.from_pandas(static_df.assign(unique_id=lambda df: df['unique_id'].astype('int64')))
dataset_pl, indices_pl, dates_pl, ds_pl = TimeSeriesDataset.from_df(df=temporal_pl, static_df=static_df, sort_df=True)
for attr in ('static_cols', 'temporal_cols', 'min_size', 'max_size', 'n_groups'):
    test_eq(getattr(dataset, attr), getattr(dataset_pl, attr))
# torch.testing.assert_allclose is deprecated; assert_close with equal_nan=True
# keeps its NaN-tolerant comparison.
torch.testing.assert_close(dataset.temporal, dataset_pl.temporal, equal_nan=True)
torch.testing.assert_close(dataset.static, dataset_pl.static, equal_nan=True)
pd.testing.assert_series_equal(indices.astype('int64'), indices_pl.to_pandas().astype('int64'))
pd.testing.assert_index_equal(dates, pd.Index(dates_pl, name='ds'))
np.testing.assert_array_equal(ds, ds_pl)
np.testing.assert_array_equal(dataset.indptr, dataset_pl.indptr)
class _DistributedTimeSeriesDataModule(TimeSeriesDataModule):
    """Data module where each distributed rank loads its own parquet file.

    `dataset` is a `_FilesDataset`; the in-memory `TimeSeriesDataset` is only
    built in `setup`, from the file at index `dist.get_rank()`.
    """
    def __init__(
        self,
        dataset: _FilesDataset,
        batch_size=32,
        valid_batch_size=1024,
        num_workers=0,
        drop_last=False,
        shuffle_train=True,
    ):
        # Bypass TimeSeriesDataModule.__init__ (which stores an in-memory dataset);
        # self.dataset is assigned later in `setup`.
        super(TimeSeriesDataModule, self).__init__()
        self.files_ds = dataset
        self.batch_size = batch_size
        self.valid_batch_size = valid_batch_size
        self.num_workers = num_workers
        self.drop_last = drop_last
        self.shuffle_train = shuffle_train

    def setup(self, stage):
        import torch.distributed as dist

        spec = self.files_ds
        frame = pd.read_parquet(spec.files[dist.get_rank()])
        static_frame = None
        if spec.static_cols is not None:
            static_names = spec.static_cols.tolist()
            # Static values are per series, so one row per id is enough.
            static_frame = (
                frame[[spec.id_col] + static_names]
                .groupby(spec.id_col, observed=True)
                .head(1)
            )
            frame = frame.drop(columns=spec.static_cols)
        self.dataset, *_ = TimeSeriesDataset.from_df(
            df=frame,
            static_df=static_frame,
            sort_df=True,
            id_col=spec.id_col,
            time_col=spec.time_col,
            target_col=spec.target_col,
        )

Give us a ⭐ on Github