根据属性屏蔽节点
- class dgl.data.utils.mask_nodes_by_property(property_values, part_ratios, random_seed=None)[source]
基础类:
根据给定的节点属性,为具有分布偏移的节点分割提供分割掩码,如评估图模型在结构分布偏移下的鲁棒性和不确定性中所提出的。
它考虑了节点的分布内(ID)和分布外(OOD)子集。 ID子集包括训练、验证和测试部分,而OOD子集 包括验证和测试部分。它按属性值的升序对节点进行排序, 将它们分成5个不相交的部分,并创建5个 相关的节点掩码数组:
3 for the ID nodes:
'in_train_mask'
,'in_valid_mask'
,'in_test_mask'
,and 2 for the OOD nodes:
'out_valid_mask'
,'out_test_mask'
.
- Parameters:
property_values (numpy ndarray) – 节点属性(浮点数)值,数据集将根据这些值进行分割。 数组的长度必须等于图中节点的数量。
part_ratios (list) – 一个包含5个比率的列表,用于训练、ID验证、ID测试、OOD验证、OOD测试部分。列表中的值必须总和为一。
random_seed (int, optional) – Random seed to fix for the initial permutation of nodes. It is used to create a random order for the nodes that have the same property values or belong to the ID subset. (default: None)
- Returns:
split_masks – 一个python字典,存储掩码名称作为键,对应的节点掩码数组作为值。
- Return type:
示例
>>> num_nodes = 1000 >>> property_values = np.random.uniform(size=num_nodes) >>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2] >>> split_masks = dgl.data.utils.mask_nodes_by_property(property_values, part_ratios) >>> print('in_valid_mask' in split_masks) True