根据属性屏蔽节点

class dgl.data.utils.mask_nodes_by_property(property_values, part_ratios, random_seed=None)[source]

基础类:

根据给定的节点属性,为具有分布偏移的节点分割提供分割掩码,如评估图模型在结构分布偏移下的鲁棒性和不确定性中所提出的。

它考虑了节点的分布内(ID)和分布外(OOD)子集。 ID子集包括训练、验证和测试部分,而OOD子集 包括验证和测试部分。它按属性值的升序对节点进行排序, 将它们分成5个不相交的部分,并创建5个 相关的节点掩码数组:

  • 3 for the ID nodes: 'in_train_mask', 'in_valid_mask', 'in_test_mask',

  • and 2 for the OOD nodes: 'out_valid_mask', 'out_test_mask'.

Parameters:
  • property_values (numpy ndarray) – 节点属性(浮点数)值,数据集将根据这些值进行分割。 数组的长度必须等于图中节点的数量。

  • part_ratios (list) – 一个包含5个比率的列表,用于训练、ID验证、ID测试、OOD验证、OOD测试部分。列表中的值必须总和为一。

  • random_seed (int, optional) – Random seed to fix for the initial permutation of nodes. It is used to create a random order for the nodes that have the same property values or belong to the ID subset. (default: None)

Returns:

split_masks – 一个python字典,存储掩码名称作为键,对应的节点掩码数组作为值。

Return type:

dict

示例

>>> num_nodes = 1000
>>> property_values = np.random.uniform(size=num_nodes)
>>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
>>> split_masks = dgl.data.utils.mask_nodes_by_property(property_values, part_ratios)
>>> print('in_valid_mask' in split_masks)
True