torch_frame.data.MultiEmbeddingTensor

class MultiEmbeddingTensor(num_rows: int, num_cols: int, values: Tensor, offset: Tensor)[来源]

基础类:_MultiTensor

一个只读的基于PyTorch张量的数据结构,存储 [num_rows, num_cols, *],其中最后一维的大小可以 因列而异。请注意,最后一维在每列中跨行是相同的, 而在MultiNestedTensor中,最后一维可以在行和列之间不同。 它支持各种高级索引,包括沿行和列的切片和列表索引。

Parameters:

示例

>>> tensor_list = [
...    torch.tensor([[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]),  # emb col 0
...    torch.tensor([[0.6, 0.7], [0.8, 0.9]]),            # emb col 1
...    torch.tensor([[1.], [1.1]]),                       # emb col 2
... ]
>>> met = MultiEmbeddingTensor.from_tensor_list(tensor_list)
>>> met
MultiEmbeddingTensor(num_rows=2, num_cols=3, device='cpu')
>>> met.values
tensor([[0.0000, 0.1000, 0.2000, 0.6000, 0.7000, 1.0000],
        [0.3000, 0.4000, 0.5000, 0.8000, 0.9000, 1.1000]])
>>> met.offset
tensor([0, 3, 5, 6])
>>> met[0, 0]
tensor([0.0000, 0.1000, 0.2000])
>>> met[1, 1]
tensor([0.8000, 0.9000])
>>> met[0] # Row integer indexing
MultiEmbeddingTensor(num_rows=1, num_cols=3, device='cpu')
>>> met[:, 0] # Column integer indexing
MultiEmbeddingTensor(num_rows=2, num_cols=1, device='cpu')
>>> met[:, 0].values # Embedding of column 0
tensor([[0.0000, 0.1000, 0.2000],
        [0.3000, 0.4000, 0.5000]])
>>> met[:1] # Row slicing
MultiEmbeddingTensor(num_rows=1, num_cols=3, device='cpu')
>>> met[[0, 1, 0, 0]] # Row list indexing
MultiEmbeddingTensor(num_rows=4, num_cols=3, device='cpu')
classmethod from_tensor_list(tensor_list: list[torch.Tensor]) MultiEmbeddingTensor[来源]

从一系列torch.Tensor创建一个MultiEmbeddingTensor

Parameters:

tensor_list (List[Tensor]) – 一个张量列表,其中每个张量具有相同的行数,但可以有不同的列数。

Returns:

一个 MultiEmbeddingTensor 实例。

Return type:

MultiEmbeddingTensor

fillna_col(col_index: int, fill_value: int | float | Tensor) None[来源]

fill_value原地填充MultiTensor中的第index列。

Parameters:
  • col_index (int) – 要选择的张量的列索引。

  • fill_value (Union[int, float, Tensor]) – 用于替换NaN的标量值。

static cat(xs: Sequence[MultiEmbeddingTensor], dim: int = 0) MultiEmbeddingTensor[来源]

沿着指定的维度连接一系列MultiEmbeddingTensor

Parameters:
Returns:

连接的多嵌入张量。

Return type:

MultiEmbeddingTensor