# Source code for torch_frame.utils.split

import numpy as np

# Mapping split name to integer.
SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2}


def generate_random_split(
    length: int,
    seed: int,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    include_test: bool = True,
) -> np.ndarray:
    r"""Generate a list of random split assignments of the specified length.

    The elements are either :obj:`0`, :obj:`1`, or :obj:`2`, representing
    train, val, test, respectively (see :obj:`SPLIT_TO_NUM`). When
    :obj:`include_test` is :obj:`False`, only :obj:`0` and :obj:`1` appear.

    Note that this function relies on the fact that numpy's shuffle is
    consistent across versions, which has been historically the case.

    The shuffle uses a private :class:`numpy.random.RandomState` seeded with
    :obj:`seed`, so NumPy's global random state is left untouched.

    Args:
        length (int): Total number of split assignments to generate.
        seed (int): Seed for the shuffle; the same seed and arguments give
            the same assignment.
        train_ratio (float, optional): Fraction of entries assigned to
            train. Must be positive. (default: :obj:`0.8`)
        val_ratio (float, optional): Fraction of entries assigned to val.
            Must be positive. (default: :obj:`0.1`)
        include_test (bool, optional): If :obj:`True`, the entries left over
            after train and val are assigned to test, and
            :obj:`train_ratio + val_ratio` must be smaller than 1. If
            :obj:`False`, no test entries are produced and
            :obj:`train_ratio + val_ratio` must equal 1.
            (default: :obj:`True`)

    Returns:
        np.ndarray: A shuffled 1-dimensional array with :obj:`length`
        elements holding the split index of each position.

    Raises:
        AssertionError: If a ratio is not positive, or if the ratios do not
            satisfy the sum constraint of the selected mode.
    """
    assert train_ratio > 0
    assert val_ratio > 0

    if include_test:
        assert train_ratio + val_ratio < 1
        # Train and val counts are truncated; test absorbs the remainder so
        # that the three counts always add up to `length`.
        train_num = int(length * train_ratio)
        val_num = int(length * val_ratio)
        test_num = length - train_num - val_num
        arr = np.concatenate([
            np.full(train_num, SPLIT_TO_NUM['train']),
            np.full(val_num, SPLIT_TO_NUM['val']),
            np.full(test_num, SPLIT_TO_NUM['test']),
        ])
    else:
        assert train_ratio + val_ratio == 1
        # Val absorbs the truncation remainder in the two-way split.
        train_num = int(length * train_ratio)
        val_num = length - train_num
        arr = np.concatenate([
            np.full(train_num, SPLIT_TO_NUM['train']),
            np.full(val_num, SPLIT_TO_NUM['val']),
        ])

    # A dedicated legacy RandomState yields the same permutation as
    # `np.random.seed(seed); np.random.shuffle(arr)` without re-seeding the
    # process-wide global generator behind the caller's back.
    np.random.RandomState(seed).shuffle(arr)
    return arr