torch_frame.data.MultiNestedTensor
- class MultiNestedTensor(num_rows: int, num_cols: int, values: Tensor, offset: Tensor)[source]
基础类:
_MultiTensor
一个只读的基于PyTorch张量的数据结构,存储
[num_rows, num_cols, *]
,其中最后一维的大小可以 因不同的行/列而异。在内部,我们以高效的扁平化格式存储对象:(values, offset)
,其中位于(i, j)
的PyTorch张量通过values[offset[i*num_cols+j]:offset[i*num_cols+j+1]]
访问。 它支持各种高级索引,包括沿行和列的切片和列表索引。- Parameters:
num_rows (int) – 行数。
num_cols (int) – 列数。
值 (torch.Tensor) – 大小为
[numel,]
的torch.Tensor
的值。offset (torch.Tensor) – 偏移量
torch.Tensor
的大小为[num_rows*num_cols+1,]
。
示例
>>> import torch >>> tensor_mat = [ ... [torch.tensor([1, 2]), torch.tensor([3])], ... [torch.tensor([4]), torch.tensor([5, 6, 7])], ... [torch.tensor([8, 9]), torch.tensor([10])], ... ] >>> mnt = MultiNestedTensor.from_tensor_mat(tensor_mat) >>> mnt MultiNestedTensor(num_rows=3, num_cols=2, device='cpu') >>> mnt.values tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) >>> mnt.offset tensor([ 0, 2, 3, 4, 7, 9, 10]) >>> mnt[0, 0] torch.tensor([1, 2]) >>> mnt[1, 1] tensor([5, 6, 7]) >>> mnt[0] # Row integer indexing MultiNestedTensor(num_rows=1, num_cols=2, device='cpu') >>> mnt[:, 0] # Column integer indexing MultiNestedTensor(num_rows=3, num_cols=1, device='cpu') >>> mnt[:2] # Row integer slicing MultiNestedTensor(num_rows=2, num_cols=2, device='cpu') >>> mnt[[2, 1, 2, 0]] # Row list indexing MultiNestedTensor(num_rows=4, num_cols=2, device='cpu') >>> mnt.to_dense(fill_value = -1) # Map to a dense matrix via padding tensor([[[ 1, 2, -1], [ 3, -1, -1]], [[ 4, -1, -1], [ 5, 6, 7]], [[ 8, 9, -1], [10, -1, -1]]])
- classmethod from_tensor_mat(tensor_mat: list[list[torch.Tensor]]) MultiNestedTensor [source]
从
tensor_mat
构建MultiNestedTensor
对象。- Parameters:
tensor_mat (List[List[Tensor]]) – 一个由
torch.Tensor
对象组成的矩阵。tensor_mat[i][j]
包含第i
行和第j
列的一维torch.Tensor
,大小可变。- Returns:
一个
MultiNestedTensor
实例。- Return type:
- fillna_col(col_index: int, fill_value: int | float | Tensor) None [source]
用
fill_value
原地填充MultiTensor
中的第index
列。
- static cat(xs: Sequence[MultiNestedTensor], dim: int = 0) MultiNestedTensor [来源]
沿着指定的维度连接一系列
MultiNestedTensor
。- Parameters:
xs (Sequence[MultiNestedTensor]) – 一个由
MultiNestedTensor
组成的序列,用于连接。dim (int) – 要沿其连接的维度。
- Returns:
连接的多层嵌套张量。
- Return type: