torch_geometric.transforms.RandomNodeSplit

class RandomNodeSplit(split: str = 'train_rest', num_splits: int = 1, num_train_per_class: int = 20, num_val: Union[int, float] = 500, num_test: Union[int, float] = 1000, key: Optional[str] = 'y')[source]

Bases: BaseTransform

通过向DataHeteroData对象添加train_maskval_masktest_mask属性来执行节点级别的随机分割(功能名称:random_node_split)。

Parameters:
  • split (str, optional) – 数据集分割的类型 ("train_rest", "test_rest", "random"). 如果设置为 "train_rest", 除了验证集和测试集中的节点外,所有节点将用于训练(如 “FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling” 论文中所述)。 如果设置为 "test_rest", 除了训练集和验证集中的节点外,所有节点将用于测试(如 “Pitfalls of Graph Neural Network Evaluation” 论文中所述)。 如果设置为 "random", 训练集、验证集和测试集将根据 num_train_per_class, num_valnum_test 随机生成(如 “Semi-supervised Classification with Graph Convolutional Networks” 论文中所述)。 (默认: "train_rest")

  • num_splits (int, optional) – 要添加的分割数。如果大于 1,掩码的形状将为 [num_nodes, num_splits],否则为 [num_nodes]。(默认值:1

  • num_train_per_class (int, optional) – 在"test_rest""random"分割情况下,每个类别的训练样本数量。 (默认: 20)

  • num_val (intfloat, 可选) – 验证样本的数量。 如果是浮点数,它表示包含在验证集中的样本比例。(默认值:500

  • num_test (intfloat, 可选) – 在 "train_rest""random" 分割情况下,测试样本的数量。如果是浮点数,它表示要包含在测试集中的样本比例。 (默认: 1000)

  • key (str, optional) – 保存真实标签的属性的名称。默认情况下,仅会为包含 key 的节点级存储添加节点级分割。(默认值: "y").