torch_geometric.transforms.RandomNodeSplit
- class RandomNodeSplit(split: str = 'train_rest', num_splits: int = 1, num_train_per_class: int = 20, num_val: Union[int, float] = 500, num_test: Union[int, float] = 1000, key: Optional[str] = 'y')[source]
Bases:
BaseTransform通过向
Data或HeteroData对象添加train_mask、val_mask和test_mask属性来执行节点级别的随机分割(功能名称:random_node_split)。- Parameters:
split (str, optional) – 数据集分割的类型 (
"train_rest","test_rest","random"). 如果设置为"train_rest", 除了验证集和测试集中的节点外,所有节点将用于训练(如 “FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling” 论文中所述)。 如果设置为"test_rest", 除了训练集和验证集中的节点外,所有节点将用于测试(如 “Pitfalls of Graph Neural Network Evaluation” 论文中所述)。 如果设置为"random", 训练集、验证集和测试集将根据num_train_per_class,num_val和num_test随机生成(如 “Semi-supervised Classification with Graph Convolutional Networks” 论文中所述)。 (默认:"train_rest")num_splits (int, optional) – 要添加的分割数。如果大于
1,掩码的形状将为[num_nodes, num_splits],否则为[num_nodes]。(默认值:1)num_train_per_class (int, optional) – 在
"test_rest"和"random"分割情况下,每个类别的训练样本数量。 (默认:20)num_val (int 或 float, 可选) – 验证样本的数量。 如果是浮点数,它表示包含在验证集中的样本比例。(默认值:
500)num_test (int 或 float, 可选) – 在
"train_rest"和"random"分割情况下,测试样本的数量。如果是浮点数,它表示要包含在测试集中的样本比例。 (默认:1000)key (str, optional) – 保存真实标签的属性的名称。默认情况下,仅会为包含
key的节点级存储添加节点级分割。(默认值:"y").