dgl.sampling.pack_traces
- dgl.sampling.pack_traces(traces, types)[source]
将
random_walk()
返回的填充轨迹打包成一个连接数组。 填充值(-1)被移除,并且每个轨迹的长度和偏移量 与连接的节点ID和节点类型数组一起返回。- Parameters:
traces (Tensor) – 一个二维的节点ID张量。必须在CPU上,并且是
int32
或int64
类型。types (Tensor) – 一个一维的节点类型ID张量。必须在CPU上,并且是
int32
或int64
类型。
- Returns:
concat_vids (Tensor) – 所有节点ID连接并去除填充值的数组。
concat_types (Tensor) – 对应于
concat_vids
中每个节点的节点类型数组。 与concat_vids
长度相同。lengths (Tensor) – 原始跟踪张量中每个跟踪的长度。
offsets (Tensor) – 原始跟踪张量中每个跟踪在新连接张量中的偏移量。
注释
返回的张量位于CPU上。
示例
>>> g2 = dgl.heterograph({ ... ('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]), ... ('user', 'view', 'item'): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]), ... ('item', 'viewed-by', 'user'): ([0, 1, 1, 2, 2, 1], [0, 0, 1, 2, 3, 3]) >>> traces, types = dgl.sampling.random_walk( ... g2, [0, 0], metapath=['follow', 'view', 'viewed-by'] * 2, ... restart_prob=torch.FloatTensor([0, 0.5, 0, 0, 0.5, 0])) >>> traces, types (tensor([[ 0, 1, -1, -1, -1, -1, -1], [ 0, 1, 1, 3, 0, 0, 0]]), tensor([0, 0, 1, 0, 0, 1, 0])) >>> concat_vids, concat_types, lengths, offsets = dgl.sampling.pack_traces(traces, types) >>> concat_vids tensor([0, 1, 0, 1, 1, 3, 0, 0, 0]) >>> concat_types tensor([0, 0, 0, 0, 1, 0, 0, 1, 0]) >>> lengths tensor([2, 7]) >>> offsets tensor([0, 2]))
第一个张量
concat_vids
是所有路径的连接,即traces
的扁平化数组,不包括所有填充值(-1)。第二个张量
concat_types
代表第一个张量中所有对应节点的节点类型ID。第三和第四张量表示每条路径的长度和偏移量。通过这些张量,可以轻松获取第i条随机游走路径:
>>> vids = concat_vids.split(lengths.tolist()) >>> vtypes = concat_vtypes.split(lengths.tolist()) >>> vids[1], vtypes[1] (tensor([0, 1, 1, 3, 0, 0, 0]), tensor([0, 0, 1, 0, 0, 1, 0]))