%load_ext autoreload
%autoreload 2
PyTorch Dataset/Loader
Torch Dataset for Time Series
from fastcore.test import test_eq
from nbdev.showdoc import show_doc
from neuralforecast.utils import generate_series
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import List, Optional, Sequence, Union
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import utilsforecast.processing as ufp
from torch.utils.data import Dataset, DataLoader
from utilsforecast.compat import DataFrame, pl_Series
class TimeSeriesLoader(DataLoader):
"""TimeSeriesLoader DataLoader.
[Source code](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/tsdataset.py).
A small modification of PyTorch's DataLoader.
Combines a dataset and a sampler, and provides an iterable over the given dataset.
The class `~torch.utils.data.DataLoader` supports both map-style and
iterable-style datasets with single- or multi-process loading, customizing
loading order and optional automatic batching (collation) and memory pinning.
**Parameters:**<br>
`batch_size`: (int, optional): how many samples per batch to load (default: 1).<br>
`shuffle`: (bool, optional): set to `True` to have the data reshuffled at every epoch (default: `False`).<br>
`sampler`: (Sampler or Iterable, optional): defines the strategy to draw samples from the dataset.<br>
Can be any `Iterable` with `__len__` implemented. If specified, `shuffle` must not be specified.<br>
"""
def __init__(self, dataset, **kwargs):
if 'collate_fn' in kwargs:
kwargs.pop('collate_fn')
kwargs_ = {**kwargs, **dict(collate_fn=self._collate_fn)}
DataLoader.__init__(self, dataset=dataset, **kwargs_)
def _collate_fn(self, batch):
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum(x.numel() for x in batch)
storage = elem.storage()._new_shared(numel, device=elem.device)
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
return torch.stack(batch, 0, out=out)
elif isinstance(elem, Mapping):
if elem['static'] is None:
return dict(temporal=self.collate_fn([d['temporal'] for d in batch]),
temporal_cols = elem['temporal_cols'],
y_idx=elem['y_idx'])
return dict(static=self.collate_fn([d['static'] for d in batch]),
static_cols = elem['static_cols'],
temporal=self.collate_fn([d['temporal'] for d in batch]),
temporal_cols = elem['temporal_cols'],
y_idx=elem['y_idx'])
raise TypeError(f'Unknown {elem_type}')
show_doc(TimeSeriesLoader)
class BaseTimeSeriesDataset(Dataset):
def __init__(self,
temporal_cols,
max_size: int,
min_size: int,
y_idx: int,
static=None,
static_cols=None,
sorted=False,
):
super().__init__()
self.temporal_cols = pd.Index(list(temporal_cols))
if static is not None:
self.static = self._as_torch_copy(static)
self.static_cols = static_cols
else:
self.static = static
self.static_cols = static_cols
self.max_size = max_size
self.min_size = min_size
self.y_idx = y_idx
# Updated flag. To ensure consistency, the dataset can only be updated once.
self.updated = False
self.sorted = sorted
def __len__(self):
return self.n_groups
def _as_torch_copy(
self,
x: Union[np.ndarray, torch.Tensor],
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)
return x.to(dtype, copy=False).clone()
@staticmethod
def _ensure_available_mask(data: np.ndarray, temporal_cols):
if 'available_mask' not in temporal_cols:
available_mask = np.ones((len(data),1), dtype=np.float32)
temporal_cols = temporal_cols.append(pd.Index(['available_mask']))
data = np.append(data, available_mask, axis=1)
return data, temporal_cols
@staticmethod
def _extract_static_features(static_df, sort_df, id_col):
if static_df is not None:
if isinstance(static_df, pd.DataFrame) and static_df.index.name == id_col:
warnings.warn(
"Passing the id as index is deprecated, please provide it as a column instead.",
FutureWarning,
)
if sort_df:
static_df = ufp.sort(static_df, by=id_col)
static_cols = [col for col in static_df.columns if col != id_col]
static = ufp.to_numpy(static_df[static_cols])
static_cols = pd.Index(static_cols)
else:
static = None
static_cols = None
return static, static_cols
class TimeSeriesDataset(BaseTimeSeriesDataset):
def __init__(self,
temporal,
temporal_cols,
indptr,
max_size: int,
min_size: int,
y_idx: int,
static=None,
static_cols=None,
sorted=False,
):
super().__init__(
temporal_cols=temporal_cols,
max_size=max_size,
min_size=min_size,
y_idx=y_idx,
static=static,
static_cols=static_cols,
sorted=sorted
)
self.temporal = self._as_torch_copy(temporal)
self.indptr = indptr
self.n_groups = self.indptr.size - 1
def __getitem__(self, idx):
if isinstance(idx, int):
# Parse temporal data and pad it on the left
temporal = torch.zeros(size=(len(self.temporal_cols), self.max_size),
dtype=torch.float32)
ts = self.temporal[self.indptr[idx] : self.indptr[idx + 1], :]
temporal[:len(self.temporal_cols), -len(ts):] = ts.permute(1, 0)
# Add static data if available
static = None if self.static is None else self.static[idx,:]
item = dict(temporal=temporal, temporal_cols=self.temporal_cols,
static=static, static_cols=self.static_cols,
y_idx=self.y_idx)
return item
raise ValueError(f'idx must be int, got {type(idx)}')
def __repr__(self):
return f'TimeSeriesDataset(n_data={self.temporal.shape[0]:,}, n_groups={self.n_groups:,})'
def __eq__(self, other):
if not hasattr(other, 'temporal') or not hasattr(other, 'indptr'):
return False
return np.allclose(self.temporal, other.temporal) and np.array_equal(self.indptr, other.indptr)
def align(self, df: DataFrame, id_col: str, time_col: str, target_col: str) -> 'TimeSeriesDataset':
# Protect consistency
df = ufp.copy_if_pandas(df, deep=False)
# Fill missing columns with NaN (available_mask is filled with 1.0 instead)
temporal_cols = self.temporal_cols.copy()
for col in temporal_cols:
if col not in df.columns:
df = ufp.assign_columns(df, col, np.nan)
if col == 'available_mask':
df = ufp.assign_columns(df, col, 1.0)
# Reorder columns to match self.temporal_cols
df = df[ [id_col, time_col] + temporal_cols.tolist() ]
# Process future_df
dataset, *_ = TimeSeriesDataset.from_df(
df=df,
sort_df=self.sorted,
id_col=id_col,
time_col=time_col,
target_col=target_col,
)
return dataset
def append(self, futr_dataset: 'TimeSeriesDataset') -> 'TimeSeriesDataset':
"""将未来的观测值添加到数据集中。返回一个副本"""
if self.indptr.size != futr_dataset.indptr.size:
raise ValueError('Cannot append `futr_dataset` with different number of groups.')
# Define and fill the new temporal tensor with the updated information
len_temporal, col_temporal = self.temporal.shape
len_futr = futr_dataset.temporal.shape[0]
new_temporal = torch.empty(size=(len_temporal + len_futr, col_temporal))
new_indptr = self.indptr + futr_dataset.indptr
new_sizes = np.diff(new_indptr)
new_min_size = np.min(new_sizes)
new_max_size = np.max(new_sizes)
for i in range(self.n_groups):
curr_slice = slice(self.indptr[i], self.indptr[i + 1])
curr_size = curr_slice.stop - curr_slice.start
futr_slice = slice(futr_dataset.indptr[i], futr_dataset.indptr[i + 1])
new_temporal[new_indptr[i] : new_indptr[i] + curr_size] = self.temporal[curr_slice]
new_temporal[new_indptr[i] + curr_size : new_indptr[i + 1]] = futr_dataset.temporal[futr_slice]
# Define new dataset
return TimeSeriesDataset(
temporal=new_temporal,
temporal_cols=self.temporal_cols.copy(),
indptr=new_indptr,
max_size=new_max_size,
min_size=new_min_size,
static=self.static,
y_idx=self.y_idx,
static_cols=self.static_cols,
sorted=self.sorted
)
@staticmethod
def update_dataset(dataset, futr_df, id_col='unique_id', time_col='ds', target_col='y'):
futr_dataset = dataset.align(
futr_df, id_col=id_col, time_col=time_col, target_col=target_col
)
return dataset.append(futr_dataset)
@staticmethod
def trim_dataset(dataset, left_trim: int = 0, right_trim: int = 0):
"""
从数据集中去除时间信息。
返回所有序列的时间索引 [t+left:t-right]。
"""
if dataset.min_size <= left_trim + right_trim:
raise Exception(f'left_trim + right_trim ({left_trim} + {right_trim}) \
must be lower than the shorter time series ({dataset.min_size})')
# Define and fill the new temporal tensor with the trimmed information
len_temporal, col_temporal = dataset.temporal.shape
total_trim = (left_trim + right_trim) * dataset.n_groups
new_temporal = torch.zeros(size=(len_temporal-total_trim, col_temporal))
new_indptr = [0]
acum = 0
for i in range(dataset.n_groups):
series_length = dataset.indptr[i + 1] - dataset.indptr[i]
new_length = series_length - left_trim - right_trim
new_temporal[acum:(acum+new_length), :] = dataset.temporal[dataset.indptr[i]+left_trim : \
dataset.indptr[i + 1]-right_trim, :]
acum += new_length
new_indptr.append(acum)
new_max_size = dataset.max_size-left_trim-right_trim
new_min_size = dataset.min_size-left_trim-right_trim
# Define new dataset
updated_dataset = TimeSeriesDataset(temporal=new_temporal,
temporal_cols= dataset.temporal_cols.copy(),
indptr=np.array(new_indptr, dtype=np.int32),
max_size=new_max_size,
min_size=new_min_size,
y_idx=dataset.y_idx,
static=dataset.static,
static_cols=dataset.static_cols,
sorted=dataset.sorted)
return updated_dataset
@staticmethod
def from_df(df, static_df=None, sort_df=False, id_col='unique_id', time_col='ds', target_col='y'):
# TODO: protect on equality of static_df and df indexes
if isinstance(df, pd.DataFrame) and df.index.name == id_col:
warnings.warn(
"Passing the id as index is deprecated, please provide it as a column instead.",
FutureWarning,
)
df = df.reset_index(id_col)
# Define indices if not given, then extract static features
static, static_cols = TimeSeriesDataset._extract_static_features(static_df, sort_df, id_col)
ids, times, data, indptr, sort_idxs = ufp.process_df(df, id_col, time_col, target_col)
# The processor sets y as the first column
temporal_cols = pd.Index(
[target_col] + [c for c in df.columns if c not in (id_col, time_col, target_col)]
)
temporal = data.astype(np.float32, copy=False)
indices = ids
if isinstance(df, pd.DataFrame):
dates = pd.Index(times, name=time_col)
else:
dates = pl_Series(time_col, times)
sizes = np.diff(indptr)
max_size = max(sizes)
min_size = min(sizes)
# Add available mask efficiently (without adding a column to the dataframe)
temporal, temporal_cols = TimeSeriesDataset._ensure_available_mask(temporal, temporal_cols)
dataset = TimeSeriesDataset(
temporal=temporal,
temporal_cols=temporal_cols,
static=static,
static_cols=static_cols,
indptr=indptr,
max_size=max_size,
min_size=min_size,
sorted=sort_df,
y_idx=0,
)
ds = df[time_col].to_numpy()
if sort_idxs is not None:
ds = ds[sort_idxs]
return dataset, indices, dates, ds
class _FilesDataset:
def __init__(
self,
files: Sequence[str],
temporal_cols,
id_col: str,
time_col: str,
target_col: str,
min_size: int,
static_cols: Optional[List[str]] = None,
):
self.files = files
self.temporal_cols = pd.Index(temporal_cols)
self.static_cols = pd.Index(static_cols) if static_cols is not None else None
self.id_col = id_col
self.time_col = time_col
self.target_col = target_col
self.min_size = min_size
class LocalFilesTimeSeriesDataset(BaseTimeSeriesDataset):
def __init__(self,
files_ds: List[str],
temporal_cols,
id_col: str,
time_col: str,
target_col: str,
last_times,
indices,
max_size: int,
min_size: int,
y_idx: int,
static=None,
static_cols=None,
sorted=False,
):
super().__init__(
temporal_cols=temporal_cols,
max_size=max_size,
min_size=min_size,
y_idx=y_idx,
static=static,
static_cols=static_cols,
sorted=sorted
)
self.files_ds = files_ds
self.id_col = id_col
self.time_col = time_col
self.target_col = target_col
# Array with the last time of each time series
self.last_times = last_times
self.indices = indices
self.n_groups = len(files_ds)
def __getitem__(self, idx):
if not isinstance(idx, int):
raise ValueError(f'idx must be int, got {type(idx)}')
temporal_cols = self.temporal_cols.copy()
data = pd.read_parquet(self.files_ds[idx], columns=temporal_cols.tolist()).to_numpy()
data, temporal_cols = TimeSeriesDataset._ensure_available_mask(data, temporal_cols)
data = self._as_torch_copy(data)
# Pad the temporal data on the left
temporal = torch.zeros(size=(len(temporal_cols), self.max_size),
dtype=torch.float32)
temporal[:len(temporal_cols), -len(data):] = data.permute(1,0)
# Add static data if available
static = None if self.static is None else self.static[idx,:]
item = dict(temporal=temporal, temporal_cols=temporal_cols,
static=static, static_cols=self.static_cols,
y_idx=self.y_idx)
return item
@staticmethod
def from_data_directories(directories, static_df=None, sort_df=False, exogs=[], id_col='unique_id', time_col='ds', target_col='y'):
"""We expect directories to be a list of directories of the form [unique_id=id_0, unique_id=id_1, ...]. Each directory should contain the timeseries corresponding to that unqiue_id,
represented as a pandas or polars DataFrame. The timeseries can be entirely contained in one parquet file or split between multiple, but within each parquet files the timeseries should be sorted by time.
Static df should also be a pandas or polars DataFrame"""
import pyarrow as pa
# Define indices if not given, then extract static features
static, static_cols = TimeSeriesDataset._extract_static_features(static_df, sort_df, id_col)
max_size = 0
min_size = float('inf')
last_times = []
ids = []
expected_temporal = {target_col, *exogs}
available_mask_seen = True
for dir in directories:
dir_path = Path(dir)
if not dir_path.is_dir():
raise ValueError(f'paths must be directories, {dir} is not.')
uid = dir_path.name.split('=')[-1]
total_rows = 0
last_time = None
for file in dir_path.glob('*.parquet'):
meta = pa.parquet.read_metadata(file)
rg = meta.row_group(0)
col2pos = {rg.column(i).path_in_schema: i for i in range(rg.num_columns)}
last_time_file = meta.row_group(meta.num_row_groups -1).column(col2pos[time_col]).statistics.max
last_time = max(last_time, last_time_file) if last_time is not None else last_time_file
total_rows += sum(meta.row_group(i).num_rows for i in range(meta.num_row_groups))
# Check that all temporal columns are present
missing_cols = expected_temporal - col2pos.keys()
if missing_cols:
raise ValueError(f"Temporal columns: {missing_cols} not found in the file: {file}.")
if 'available_mask' not in col2pos.keys():
available_mask_seen = False
elif not available_mask_seen:
# If this branch is hit, the available_mask column is present in this file but was missing from a previous one.
raise ValueError("The available_mask column is present in some files but is missing in others.")
else:
expected_temporal.add("available_mask")
max_size = max(total_rows, max_size)
min_size = min(total_rows, min_size)
ids.append(uid)
last_times.append(last_time)
last_times = pd.Index(last_times, name=time_col)
ids = pd.Series(ids, name=id_col)
if "available_mask" in expected_temporal:
exogs = ["available_mask", *exogs]
temporal_cols = pd.Index([target_col, *exogs])
dataset = LocalFilesTimeSeriesDataset(
files_ds=directories,
temporal_cols=temporal_cols,
id_col=id_col,
time_col=time_col,
target_col=target_col,
last_times=last_times,
indices=ids,
min_size=min_size,
max_size=max_size,
y_idx=0,
static=static,
static_cols=static_cols,
sorted=sort_df
)
return dataset
show_doc(TimeSeriesDataset)
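The following sketch is not part of the original notebook; it illustrates the directory layout expected by `LocalFilesTimeSeriesDataset.from_data_directories`: one `unique_id=<id>` directory per series, each holding parquet files sorted by time. The temporary directory, the file name `part-0.parquet`, and the toy data from `generate_series` are assumptions for illustration only; pyarrow must be installed.
import tempfile
tmp_dir = Path(tempfile.mkdtemp())
example_df = generate_series(n_series=2, n_temporal_features=0)
series_dirs = []
for uid, uid_df in example_df.groupby('unique_id', observed=True):
    # one directory per series, named unique_id=<id>
    uid_dir = tmp_dir / f'unique_id={uid}'
    uid_dir.mkdir(parents=True, exist_ok=True)
    # within each parquet file the series must already be sorted by `ds`
    uid_df.drop(columns='unique_id').to_parquet(uid_dir / 'part-0.parquet', index=False)
    series_dirs.append(str(uid_dir))
files_dataset = LocalFilesTimeSeriesDataset.from_data_directories(directories=series_dirs)
assert files_dataset.n_groups == 2
assert files_dataset.temporal_cols.tolist() == ['y']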
# Test sort_df=True functionality
temporal_df = generate_series(n_series=1000,
n_temporal_features=0, equal_ends=False)
sorted_temporal_df = temporal_df.sort_values(['unique_id', 'ds'])
unsorted_temporal_df = sorted_temporal_df.sample(frac=1.0)
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=unsorted_temporal_df,
sort_df=True)
np.testing.assert_allclose(dataset.temporal[:,:-1],
sorted_temporal_df.drop(columns=['unique_id', 'ds']).values)
test_eq(indices, pd.Series(sorted_temporal_df['unique_id'].unique()))
test_eq(dates, temporal_df.groupby('unique_id')['ds'].max().values)
class TimeSeriesDataModule(pl.LightningDataModule):
def __init__(
self,
dataset: BaseTimeSeriesDataset,
batch_size=32,
valid_batch_size=1024,
num_workers=0,
drop_last=False,
shuffle_train=True,
):
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.valid_batch_size = valid_batch_size
self.num_workers = num_workers
self.drop_last = drop_last
self.shuffle_train = shuffle_train
def train_dataloader(self):
loader = TimeSeriesLoader(
self.dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=self.shuffle_train,
drop_last=self.drop_last
)
return loader
def val_dataloader(self):
loader = TimeSeriesLoader(
self.dataset,
batch_size=self.valid_batch_size,
num_workers=self.num_workers,
shuffle=False,
drop_last=self.drop_last
)
return loader
def predict_dataloader(self):
loader = TimeSeriesLoader(
self.dataset,
batch_size=self.valid_batch_size,
num_workers=self.num_workers,
shuffle=False
)
return loader
show_doc(TimeSeriesDataModule)
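As a brief sketch (added for illustration, reusing the `dataset` built above), `TimeSeriesLoader` can also be used directly: its custom collate stacks the per-series tensors into a batch and forwards `temporal_cols` and `y_idx` from the first element.
example_loader = TimeSeriesLoader(dataset, batch_size=4, shuffle=False)
example_batch = next(iter(example_loader))
# each item is [n_temporal_cols, max_size]; collation adds a leading batch dimension
assert example_batch['temporal'].shape == (4, len(dataset.temporal_cols), dataset.max_size)
assert example_batch['y_idx'] == 0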
batch_size = 128
data = TimeSeriesDataModule(dataset=dataset,
batch_size=batch_size, drop_last=True)
for batch in data.train_dataloader():
test_eq(batch['temporal'].shape, (batch_size, 2, 500))
test_eq(batch['temporal_cols'], ['y', 'available_mask'])
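The left padding performed in `TimeSeriesDataset.__getitem__` can be checked directly (a small sketch added for illustration, again reusing the `dataset` built above): every item is a [n_temporal_cols, max_size] tensor, and series shorter than `max_size` are aligned to the right, with zeros on the left.
item = dataset[0]
series_len = int(dataset.indptr[1] - dataset.indptr[0])
pad = dataset.max_size - series_len
assert item['temporal'].shape == (len(dataset.temporal_cols), dataset.max_size)
# the padded region (left of the series) is all zeros
assert torch.all(item['temporal'][:, :pad] == 0)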
batch_size = 128
n_static_features = 2
n_temporal_features = 4
temporal_df, static_df = generate_series(n_series=1000,
n_static_features=n_static_features,
n_temporal_features=n_temporal_features,
equal_ends=False)
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df,
static_df=static_df,
sort_df=True)
data = TimeSeriesDataModule(dataset=dataset,
batch_size=batch_size, drop_last=True)
for batch in data.train_dataloader():
test_eq(batch['temporal'].shape, (batch_size, n_temporal_features + 2, 500))
test_eq(batch['temporal_cols'],
['y'] + [f'temporal_{i}' for i in range(n_temporal_features)] + ['available_mask'])
test_eq(batch['static'].shape, (batch_size, n_static_features))
test_eq(batch['static_cols'], [f'static_{i}' for i in range(n_static_features)])
# hide
# Generate a small dataset and split it for the available_mask and update_dataset tests below
temporal_df = generate_series(n_series=2,
n_temporal_features=2, equal_ends=True)
temporal_df = temporal_df.groupby('unique_id').tail(10)
temporal_df = temporal_df.reset_index()
temporal_full_df = temporal_df.sort_values(['unique_id', 'ds']).reset_index(drop=True)
temporal_full_df.loc[temporal_full_df.ds > '2001-05-11', ['y', 'temporal_0']] = None
split1_df = temporal_full_df.loc[temporal_full_df.ds <= '2001-05-11']
split2_df = temporal_full_df.loc[temporal_full_df.ds > '2001-05-11']
# Test available_mask
temporal_df_w_mask = temporal_df.copy()
temporal_df_w_mask['available_mask'] = 1
# Mask with all 1's
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,
sort_df=True)
mask_average = dataset.temporal[:, -1].mean()
np.testing.assert_almost_equal(mask_average, 1.0000)
# Add 0's to available mask
temporal_df_w_mask.loc[temporal_df_w_mask.ds > '2001-05-11', 'available_mask'] = 0
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,
sort_df=True)
mask_average = dataset.temporal[:, -1].mean()
np.testing.assert_almost_equal(mask_average, 0.7000)
# available_mask is not the last column
temporal_df_w_mask = temporal_df_w_mask[['unique_id','ds','y','available_mask', 'temporal_0','temporal_1']]
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df_w_mask,
sort_df=True)
mask_average = dataset.temporal[:, 1].mean()
np.testing.assert_almost_equal(mask_average, 0.7000)
# To test that the `update_dataset` method handles `future_df` correctly,
# we check that the same dataset is recovered whether it is built from
# the full dataframe or assembled from its splits.
# Full dataset
dataset_full, indices_full, dates_full, ds_full = TimeSeriesDataset.from_df(df=temporal_full_df,
sort_df=False)
# SPLIT_1 dataset
dataset_1, indices_1, dates_1, ds_1 = TimeSeriesDataset.from_df(df=split1_df,
sort_df=False)
dataset_1 = dataset_1.update_dataset(dataset_1, split2_df)
np.testing.assert_almost_equal(dataset_full.temporal.numpy(), dataset_1.temporal.numpy())
test_eq(dataset_full.max_size, dataset_1.max_size)
test_eq(dataset_full.indptr, dataset_1.indptr)
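A small sketch (added for illustration) of what `align` does on its own: temporal columns missing from the future frame, such as the target here, are filled with NaN, while `available_mask` is set to 1.
futr_only_dates = split2_df[['unique_id', 'ds']].copy()
aligned = dataset_full.align(futr_only_dates, id_col='unique_id', time_col='ds', target_col='y')
# the target column of the aligned future dataset is all NaN
assert torch.isnan(aligned.temporal[:, aligned.y_idx]).all()
# the available_mask column was filled with ones
assert torch.all(aligned.temporal[:, aligned.temporal_cols.get_loc('available_mask')] == 1)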
# Test trim_dataset functionality
n_static_features = 0
n_temporal_features = 2
temporal_df = generate_series(n_series=100,
min_length=50,
max_length=100,
n_static_features=n_static_features,
n_temporal_features=n_temporal_features,
equal_ends=False)
dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=temporal_df,
static_df=static_df,
sort_df=True)
left_trim = 10
right_trim = 20
dataset_trimmed = dataset.trim_dataset(dataset, left_trim=left_trim, right_trim=right_trim)
np.testing.assert_almost_equal(dataset.temporal[dataset.indptr[50]+left_trim:dataset.indptr[51]-right_trim].numpy(),
dataset_trimmed.temporal[dataset_trimmed.indptr[50]:dataset_trimmed.indptr[51]].numpy())
#| polars
import polars
#| polars
temporal_df2 = temporal_df.copy()
for col in ('unique_id', 'temporal_0', 'temporal_1'):
temporal_df2[col] = temporal_df2[col].cat.codes
temporal_pl = polars.from_pandas(temporal_df2).sample(fraction=1.0)
static_pl = polars.from_pandas(static_df.assign(unique_id=lambda df: df['unique_id'].astype('int64')))
dataset_pl, indices_pl, dates_pl, ds_pl = TimeSeriesDataset.from_df(df=temporal_pl, static_df=static_pl, sort_df=True)
for attr in ('static_cols', 'temporal_cols', 'min_size', 'max_size', 'n_groups'):
test_eq(getattr(dataset, attr), getattr(dataset_pl, attr))
torch.testing.assert_close(dataset.temporal, dataset_pl.temporal)
torch.testing.assert_close(dataset.static, dataset_pl.static)
pd.testing.assert_series_equal(indices.astype('int64'), indices_pl.to_pandas().astype('int64'))
pd.testing.assert_index_equal(dates, pd.Index(dates_pl, name='ds'))
np.testing.assert_array_equal(ds, ds_pl)
np.testing.assert_array_equal(dataset.indptr, dataset_pl.indptr)
class _DistributedTimeSeriesDataModule(TimeSeriesDataModule):
def __init__(
self,
dataset: _FilesDataset,
batch_size=32,
valid_batch_size=1024,
num_workers=0,
drop_last=False,
shuffle_train=True,
):
super(TimeSeriesDataModule, self).__init__()
self.files_ds = dataset
self.batch_size = batch_size
self.valid_batch_size = valid_batch_size
self.num_workers = num_workers
self.drop_last = drop_last
self.shuffle_train = shuffle_train
def setup(self, stage):
import torch.distributed as dist
df = pd.read_parquet(self.files_ds.files[dist.get_rank()])
if self.files_ds.static_cols is not None:
static_df = (
df[[self.files_ds.id_col] + self.files_ds.static_cols.tolist()]
.groupby(self.files_ds.id_col, observed=True)
.head(1)
)
df = df.drop(columns=self.files_ds.static_cols)
else:
static_df = None
self.dataset, *_ = TimeSeriesDataset.from_df(
df=df,
static_df=static_df,
sort_df=True,
id_col=self.files_ds.id_col,
time_col=self.files_ds.time_col,
target_col=self.files_ds.target_col,
)
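To close, a minimal sketch (added for illustration; the file paths are placeholders) of the `_FilesDataset` container consumed by `_DistributedTimeSeriesDataModule`: it only stores file paths and schema information, and each rank later reads `files[dist.get_rank()]` in `setup`.
files_spec = _FilesDataset(
    files=['part-0.parquet', 'part-1.parquet'],  # hypothetical per-rank parquet files
    temporal_cols=['y'],
    id_col='unique_id',
    time_col='ds',
    target_col='y',
    min_size=1,
)
assert files_spec.temporal_cols.tolist() == ['y']
assert files_spec.static_cols is None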