注意
Go to the end to download the full example code
创建您自己的数据集
本教程假设您已经了解训练GNN进行节点分类的基础知识以及如何创建、加载和存储DGL图。
在本教程结束时,您将能够
创建您自己的图数据集,用于节点分类、链接预测或图分类。
(预计时间:15分钟)
DGLDataset
对象概述
您的自定义图数据集应继承 dgl.data.DGLDataset
类并实现以下方法:
__getitem__(self, i)
: 检索数据集的第i
个示例。一个示例通常包含一个DGL图,偶尔也包含其标签。__len__(self)
: 数据集中示例的数量。process(self)
: 从磁盘加载并处理原始数据。
从CSV创建用于节点分类或链接预测的数据集
节点分类数据集通常由单个图及其节点和边特征组成。
本教程使用了一个基于Zachary的空手道俱乐部网络的小型数据集。它包含
一个包含所有成员属性及其属性的
members.csv
文件。一个包含两个俱乐部成员之间成对互动的
interactions.csv
文件。
import urllib.request
import pandas as pd
urllib.request.urlretrieve(
"https://data.dgl.ai/tutorial/dataset/members.csv", "./members.csv"
)
urllib.request.urlretrieve(
"https://data.dgl.ai/tutorial/dataset/interactions.csv",
"./interactions.csv",
)
members = pd.read_csv("./members.csv")
members.head()
interactions = pd.read_csv("./interactions.csv")
interactions.head()
本教程将成员视为节点,将交互视为边。它将年龄作为节点的数值特征,所属俱乐部作为节点的标签,边权重作为边的数值特征。
注意
原始的Zachary空手道俱乐部网络没有成员的年龄信息。本教程中的年龄是合成的,用于演示如何将节点特征添加到图中以创建数据集。
注意
在实践中,直接将年龄作为数值特征在机器学习中可能效果不佳;像分箱或特征归一化这样的策略会更好。本教程为了简单起见,直接使用原始值。
import os
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
from dgl.data import DGLDataset
class KarateClubDataset(DGLDataset):
def __init__(self):
super().__init__(name="karate_club")
def process(self):
nodes_data = pd.read_csv("./members.csv")
edges_data = pd.read_csv("./interactions.csv")
node_features = torch.from_numpy(nodes_data["Age"].to_numpy())
node_labels = torch.from_numpy(
nodes_data["Club"].astype("category").cat.codes.to_numpy()
)
edge_features = torch.from_numpy(edges_data["Weight"].to_numpy())
edges_src = torch.from_numpy(edges_data["Src"].to_numpy())
edges_dst = torch.from_numpy(edges_data["Dst"].to_numpy())
self.graph = dgl.graph(
(edges_src, edges_dst), num_nodes=nodes_data.shape[0]
)
self.graph.ndata["feat"] = node_features
self.graph.ndata["label"] = node_labels
self.graph.edata["weight"] = edge_features
# If your dataset is a node classification dataset, you will need to assign
# masks indicating whether a node belongs to training, validation, and test set.
n_nodes = nodes_data.shape[0]
n_train = int(n_nodes * 0.6)
n_val = int(n_nodes * 0.2)
train_mask = torch.zeros(n_nodes, dtype=torch.bool)
val_mask = torch.zeros(n_nodes, dtype=torch.bool)
test_mask = torch.zeros(n_nodes, dtype=torch.bool)
train_mask[:n_train] = True
val_mask[n_train : n_train + n_val] = True
test_mask[n_train + n_val :] = True
self.graph.ndata["train_mask"] = train_mask
self.graph.ndata["val_mask"] = val_mask
self.graph.ndata["test_mask"] = test_mask
def __getitem__(self, i):
return self.graph
def __len__(self):
return 1
dataset = KarateClubDataset()
graph = dataset[0]
print(graph)
/home/ubuntu/prod-doc/readthedocs.org/user_builds/dgl/checkouts/latest/tutorials/blitz/6_load_data.py:106: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
node_labels = torch.from_numpy(
Graph(num_nodes=34, num_edges=156,
ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int8), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
edata_schemes={'weight': Scheme(shape=(), dtype=torch.float64)})
由于链接预测数据集只涉及单个图,准备链接预测数据集的体验将与准备节点分类数据集的体验相同。
从CSV创建用于图分类的数据集
创建图分类数据集涉及实现
__getitem__
以返回图及其图级标签。
本教程演示了如何使用以下合成的CSV数据创建图分类数据集:
graph_edges.csv
: 包含三列:graph_id
: 图的ID。src
: 给定图中边的源节点。dst
: 给定图中边的目标节点。
graph_properties.csv
: 包含三列:graph_id
: 图的ID。label
: 图的标签。num_nodes
: 图中的节点数量。
urllib.request.urlretrieve(
"https://data.dgl.ai/tutorial/dataset/graph_edges.csv", "./graph_edges.csv"
)
urllib.request.urlretrieve(
"https://data.dgl.ai/tutorial/dataset/graph_properties.csv",
"./graph_properties.csv",
)
edges = pd.read_csv("./graph_edges.csv")
properties = pd.read_csv("./graph_properties.csv")
edges.head()
properties.head()
class SyntheticDataset(DGLDataset):
def __init__(self):
super().__init__(name="synthetic")
def process(self):
edges = pd.read_csv("./graph_edges.csv")
properties = pd.read_csv("./graph_properties.csv")
self.graphs = []
self.labels = []
# Create a graph for each graph ID from the edges table.
# First process the properties table into two dictionaries with graph IDs as keys.
# The label and number of nodes are values.
label_dict = {}
num_nodes_dict = {}
for _, row in properties.iterrows():
label_dict[row["graph_id"]] = row["label"]
num_nodes_dict[row["graph_id"]] = row["num_nodes"]
# For the edges, first group the table by graph IDs.
edges_group = edges.groupby("graph_id")
# For each graph ID...
for graph_id in edges_group.groups:
# Find the edges as well as the number of nodes and its label.
edges_of_id = edges_group.get_group(graph_id)
src = edges_of_id["src"].to_numpy()
dst = edges_of_id["dst"].to_numpy()
num_nodes = num_nodes_dict[graph_id]
label = label_dict[graph_id]
# Create a graph and add it to the list of graphs and labels.
g = dgl.graph((src, dst), num_nodes=num_nodes)
self.graphs.append(g)
self.labels.append(label)
# Convert the label list to tensor for saving.
self.labels = torch.LongTensor(self.labels)
def __getitem__(self, i):
return self.graphs[i], self.labels[i]
def __len__(self):
return len(self.graphs)
dataset = SyntheticDataset()
graph, label = dataset[0]
print(graph, label)
Graph(num_nodes=15, num_edges=45,
ndata_schemes={}
edata_schemes={}) tensor(0)
通过CSVDataset
从CSV创建数据集
前面的例子逐步描述了如何从CSV文件创建数据集。DGL还提供了一个实用类 CSVDataset
用于从CSV文件读取和解析数据。更多详情请参见 4.6 从CSV文件加载数据。
# Thumbnail credits: (Un)common Use Cases for Graph Databases, Michal Bachman
# sphinx_gallery_thumbnail_path = '_static/blitz_6_load_data.png'
脚本的总运行时间: (0 分钟 0.322 秒)