torch_frame.data.MultiNestedTensor

class MultiNestedTensor(num_rows: int, num_cols: int, values: Tensor, offset: Tensor)[source]

基础类:_MultiTensor

一个只读的基于PyTorch张量的数据结构,存储 [num_rows, num_cols, *],其中最后一维的大小可以 因不同的行/列而异。在内部,我们以高效的扁平化格式存储对象: (values, offset),其中位于 (i, j) 的PyTorch张量通过 values[offset[i*num_cols+j]:offset[i*num_cols+j+1]] 访问。 它支持各种高级索引,包括沿行和列的切片和列表索引。

Parameters:

示例

>>> import torch
>>> tensor_mat = [
...    [torch.tensor([1, 2]), torch.tensor([3])],
...    [torch.tensor([4]), torch.tensor([5, 6, 7])],
...    [torch.tensor([8, 9]), torch.tensor([10])],
... ]
>>> mnt = MultiNestedTensor.from_tensor_mat(tensor_mat)
>>> mnt
MultiNestedTensor(num_rows=3, num_cols=2, device='cpu')
>>> mnt.values
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
>>> mnt.offset
tensor([ 0,  2,  3,  4,  7,  9, 10])
>>> mnt[0, 0]
torch.tensor([1, 2])
>>> mnt[1, 1]
tensor([5, 6, 7])
>>> mnt[0] # Row integer indexing
MultiNestedTensor(num_rows=1, num_cols=2, device='cpu')
>>> mnt[:, 0] # Column integer indexing
MultiNestedTensor(num_rows=3, num_cols=1, device='cpu')
>>> mnt[:2] # Row integer slicing
MultiNestedTensor(num_rows=2, num_cols=2, device='cpu')
>>> mnt[[2, 1, 2, 0]] # Row list indexing
MultiNestedTensor(num_rows=4, num_cols=2, device='cpu')
>>> mnt.to_dense(fill_value = -1) # Map to a dense matrix via padding
tensor([[[ 1,  2, -1],
        [ 3, -1, -1]],
        [[ 4, -1, -1],
        [ 5,  6,  7]],
        [[ 8,  9, -1],
        [10, -1, -1]]])
classmethod from_tensor_mat(tensor_mat: list[list[torch.Tensor]]) MultiNestedTensor[source]

tensor_mat构建MultiNestedTensor对象。

Parameters:

tensor_mat (List[List[Tensor]]) – 一个由 torch.Tensor 对象组成的矩阵。tensor_mat[i][j] 包含第 i 行和第 j 列的一维 torch.Tensor,大小可变。

Returns:

一个 MultiNestedTensor 实例。

Return type:

MultiNestedTensor

fillna_col(col_index: int, fill_value: int | float | Tensor) None[source]

fill_value原地填充MultiTensor中的第index列。

Parameters:
  • col_index (int) – 要选择的张量的列索引。

  • fill_value (Union[int, float, Tensor]) – 用于替换NaN的标量值。

to_dense(fill_value: int | float) Tensor[source]

将MultiNestedTensor映射为带有填充的密集Tensor表示。

Parameters:

fill_value (Union[int, float]) – 填充值。

Returns:

带有形状的填充PyTorch张量对象

(num_rows, num_cols, max_length)

Return type:

张量

static cat(xs: Sequence[MultiNestedTensor], dim: int = 0) MultiNestedTensor[来源]

沿着指定的维度连接一系列MultiNestedTensor

Parameters:
Returns:

连接的多层嵌套张量。

Return type:

MultiNestedTensor