创建您自己的数据集

本教程假设您已经了解训练GNN进行节点分类的基础知识以及如何创建、加载和存储DGL图

在本教程结束时,您将能够

  • 创建您自己的图数据集,用于节点分类、链接预测或图分类。

(预计时间:15分钟)

DGLDataset 对象概述

您的自定义图数据集应继承 dgl.data.DGLDataset 类并实现以下方法:

  • __getitem__(self, i): 检索数据集的第i个示例。一个示例通常包含一个DGL图,偶尔也包含其标签。

  • __len__(self): 数据集中示例的数量。

  • process(self): 从磁盘加载并处理原始数据。

从CSV创建用于图分类的数据集

创建图分类数据集涉及实现 __getitem__ 以返回图及其图级标签。

本教程演示了如何使用以下合成的CSV数据创建图分类数据集:

  • graph_edges.csv: 包含三列:

    • graph_id: 图的ID。

    • src: 给定图中边的源节点。

    • dst: 给定图中边的目标节点。

  • graph_properties.csv: 包含三列:

    • graph_id: 图的ID。

    • label: 图的标签。

    • num_nodes: 图中的节点数量。

urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/graph_edges.csv", "./graph_edges.csv"
)
urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/graph_properties.csv",
    "./graph_properties.csv",
)
edges = pd.read_csv("./graph_edges.csv")
properties = pd.read_csv("./graph_properties.csv")

edges.head()

properties.head()


class SyntheticDataset(DGLDataset):
    def __init__(self):
        super().__init__(name="synthetic")

    def process(self):
        edges = pd.read_csv("./graph_edges.csv")
        properties = pd.read_csv("./graph_properties.csv")
        self.graphs = []
        self.labels = []

        # Create a graph for each graph ID from the edges table.
        # First process the properties table into two dictionaries with graph IDs as keys.
        # The label and number of nodes are values.
        label_dict = {}
        num_nodes_dict = {}
        for _, row in properties.iterrows():
            label_dict[row["graph_id"]] = row["label"]
            num_nodes_dict[row["graph_id"]] = row["num_nodes"]

        # For the edges, first group the table by graph IDs.
        edges_group = edges.groupby("graph_id")

        # For each graph ID...
        for graph_id in edges_group.groups:
            # Find the edges as well as the number of nodes and its label.
            edges_of_id = edges_group.get_group(graph_id)
            src = edges_of_id["src"].to_numpy()
            dst = edges_of_id["dst"].to_numpy()
            num_nodes = num_nodes_dict[graph_id]
            label = label_dict[graph_id]

            # Create a graph and add it to the list of graphs and labels.
            g = dgl.graph((src, dst), num_nodes=num_nodes)
            self.graphs.append(g)
            self.labels.append(label)

        # Convert the label list to tensor for saving.
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)


dataset = SyntheticDataset()
graph, label = dataset[0]
print(graph, label)
Graph(num_nodes=15, num_edges=45,
      ndata_schemes={}
      edata_schemes={}) tensor(0)

通过CSVDataset从CSV创建数据集

前面的例子逐步描述了如何从CSV文件创建数据集。DGL还提供了一个实用类 CSVDataset 用于从CSV文件读取和解析数据。更多详情请参见 4.6 从CSV文件加载数据

# Thumbnail credits: (Un)common Use Cases for Graph Databases, Michal Bachman
# sphinx_gallery_thumbnail_path = '_static/blitz_6_load_data.png'

脚本的总运行时间: (0 分钟 0.322 秒)

Gallery generated by Sphinx-Gallery