torch_frame.data.MultiEmbeddingTensor
- class MultiEmbeddingTensor(num_rows: int, num_cols: int, values: Tensor, offset: Tensor)[来源]
基础类:
_MultiTensor一个只读的基于PyTorch张量的数据结构,存储
[num_rows, num_cols, *],其中最后一维的大小可以 因列而异。请注意,最后一维在每列中跨行是相同的, 而在MultiNestedTensor中,最后一维可以在行和列之间不同。 它支持各种高级索引,包括沿行和列的切片和列表索引。- Parameters:
num_rows (int) – 行数。
num_cols (int) – 列数。
值 (torch.Tensor) – 大小为
[num_rows, dim1+dim2+...+dimN]的torch.Tensor的值。offset (torch.Tensor) – 偏移量
torch.Tensor的大小为[num_cols+1,]。
示例
>>> tensor_list = [ ... torch.tensor([[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]), # emb col 0 ... torch.tensor([[0.6, 0.7], [0.8, 0.9]]), # emb col 1 ... torch.tensor([[1.], [1.1]]), # emb col 2 ... ] >>> met = MultiEmbeddingTensor.from_tensor_list(tensor_list) >>> met MultiEmbeddingTensor(num_rows=2, num_cols=3, device='cpu') >>> met.values tensor([[0.0000, 0.1000, 0.2000, 0.6000, 0.7000, 1.0000], [0.3000, 0.4000, 0.5000, 0.8000, 0.9000, 1.1000]]) >>> met.offset tensor([0, 3, 5, 6]) >>> met[0, 0] tensor([0.0000, 0.1000, 0.2000]) >>> met[1, 1] tensor([0.8000, 0.9000]) >>> met[0] # Row integer indexing MultiEmbeddingTensor(num_rows=1, num_cols=3, device='cpu') >>> met[:, 0] # Column integer indexing MultiEmbeddingTensor(num_rows=2, num_cols=1, device='cpu') >>> met[:, 0].values # Embedding of column 0 tensor([[0.0000, 0.1000, 0.2000], [0.3000, 0.4000, 0.5000]]) >>> met[:1] # Row slicing MultiEmbeddingTensor(num_rows=1, num_cols=3, device='cpu') >>> met[[0, 1, 0, 0]] # Row list indexing MultiEmbeddingTensor(num_rows=4, num_cols=3, device='cpu')
- classmethod from_tensor_list(tensor_list: list[torch.Tensor]) MultiEmbeddingTensor[来源]
从一系列
torch.Tensor创建一个MultiEmbeddingTensor。- Parameters:
tensor_list (List[Tensor]) – 一个张量列表,其中每个张量具有相同的行数,但可以有不同的列数。
- Returns:
一个
MultiEmbeddingTensor实例。- Return type:
- fillna_col(col_index: int, fill_value: int | float | Tensor) None[来源]
用
fill_value原地填充MultiTensor中的第index列。
- static cat(xs: Sequence[MultiEmbeddingTensor], dim: int = 0) MultiEmbeddingTensor[来源]
沿着指定的维度连接一系列
MultiEmbeddingTensor。- Parameters:
xs (Sequence[MultiEmbeddingTensor]) – 一个由
MultiEmbeddingTensor组成的序列,用于连接。dim (int) – 要沿其连接的维度。
- Returns:
连接的多嵌入张量。
- Return type: