PythonSetFilterer

class PythonSetFilterer(mapped_triples: Tensor)[来源]

基础类:Filterer

使用Python集合进行过滤的过滤器。

由于需要将torch长张量转换为Python元组,预计此过滤器会相当慢。它仍然可以作为性能比较的基准。

初始化过滤器。

Parameters:

mapped_triples (Tensor) – 基于ID的三元组。

方法总结

contains(batch)

检查是否包含三元组。

方法文档

contains(batch: Tensor) Tensor[source]

检查是否包含三元组。

支持批处理。

Parameters:

batch (Tensor) – 形状 (batch_size, 3) 三元组的批次。

Returns:

形状: (batch_size,) 三元组是否包含在训练三元组中。

Return type:

Tensor