torch_geometric.datasets.InfectionDataset
- class InfectionDataset(graph_generator: Union[GraphGenerator, str], num_infected_nodes: Union[int, List[int]], max_path_length: Union[int, List[int]], num_graphs: Optional[int] = None, graph_generator_kwargs: Optional[Dict[str, Any]] = None, transform: Optional[Callable] = None)[source]
Bases:
InMemoryDataset生成一个用于评估可解释性算法的合成感染数据集,如“图卷积网络的可解释性技术”论文中所述。
InfectionDataset创建了来自GraphGenerator的合成图, 其中包含随机分配的num_infected感染节点。 该数据集描述了一个节点分类任务,即预测到感染节点的最短路径长度,并带有相应的真实边缘级掩码。例如,要生成一个具有
500个节点和0.004边概率的随机Erdos-Renyi (ER)感染图,请编写:from torch_geometric.datasets import InfectionDataset from torch_geometric.datasets.graph_generator import ERGraph dataset = InfectionDataset( graph_generator=ERGraph(num_nodes=500, edge_prob=0.004), num_infected_nodes=50, max_path_length=3, )
- Parameters:
graph_generator (GraphGenerator or str) – The graph generator to be used, e.g.,
torch.geometric.datasets.graph_generator.BAGraph(or any string that automatically resolves to it).num_infected_nodes (int 或 List[int]) – 图中随机选择的感染节点数量。 如果以列表形式给出,将为不同的图选择不同数量的感染节点。
max_path_length (int, List[int]) – 用于确定节点是否会被感染的最大最短路径长度。 如果以列表形式给出,将为不同的图应用不同的最短路径长度。(默认值:
5)num_graphs (int, optional) – 要生成的图的数量。 图的数量将自动由
len(num_infected_nodes)或len(max_path_length)确定, 如果其中任何一个作为列表给出,并且只有在num_infected_nodes和max_path_length作为整数给出时, 才应设置此参数以创建多个图。(默认值:None)graph_generator_kwargs (Dict[str, Any], optional) – Arguments passed to the respective graph generator module in case it gets automatically resolved. (default:
None)transform (callable, optional) – A function/transform that takes in an
torch_geometric.data.Dataobject and returns a transformed version. The data object will be transformed before every access. (default:None)