torch_geometric.transforms.RemoveTrainingClasses

class RemoveTrainingClasses(classes: List[int])[source]

Bases: BaseTransform

从节点级训练集中移除由data.train_mask指定的类,例如,为了获得零样本标签场景(功能名称:remove_training_classes)。

Parameters:

classes (List[int]) – 要从训练集中移除的类别。