torch_geometric.datasets.InfectionDataset

class InfectionDataset(graph_generator: Union[GraphGenerator, str], num_infected_nodes: Union[int, List[int]], max_path_length: Union[int, List[int]], num_graphs: Optional[int] = None, graph_generator_kwargs: Optional[Dict[str, Any]] = None, transform: Optional[Callable] = None)[source]

Bases: InMemoryDataset

生成一个用于评估可解释性算法的合成感染数据集,如“图卷积网络的可解释性技术”论文中所述。 InfectionDataset创建了来自 GraphGenerator的合成图, 其中包含随机分配的num_infected感染节点。 该数据集描述了一个节点分类任务,即预测到感染节点的最短路径长度,并带有相应的真实边缘级掩码。

例如,要生成一个具有500个节点和0.004边概率的随机Erdos-Renyi (ER)感染图,请编写:

from torch_geometric.datasets import InfectionDataset
from torch_geometric.datasets.graph_generator import ERGraph

dataset = InfectionDataset(
    graph_generator=ERGraph(num_nodes=500, edge_prob=0.004),
    num_infected_nodes=50,
    max_path_length=3,
)
Parameters:
  • graph_generator (GraphGenerator or str) – The graph generator to be used, e.g., torch.geometric.datasets.graph_generator.BAGraph (or any string that automatically resolves to it).

  • num_infected_nodes (intList[int]) – 图中随机选择的感染节点数量。 如果以列表形式给出,将为不同的图选择不同数量的感染节点。

  • max_path_length (int, List[int]) – 用于确定节点是否会被感染的最大最短路径长度。 如果以列表形式给出,将为不同的图应用不同的最短路径长度。(默认值:5

  • num_graphs (int, optional) – 要生成的图的数量。 图的数量将自动由 len(num_infected_nodes)len(max_path_length) 确定, 如果其中任何一个作为列表给出,并且只有在 num_infected_nodesmax_path_length 作为整数给出时, 才应设置此参数以创建多个图。(默认值:None

  • graph_generator_kwargs (Dict[str, Any], optional) – Arguments passed to the respective graph generator module in case it gets automatically resolved. (default: None)

  • transform (callable, optional) – A function/transform that takes in an torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)