torch_frame.data.MultiNestedTensor
- class MultiNestedTensor(num_rows: int, num_cols: int, values: Tensor, offset: Tensor)[source]
基础类:
_MultiTensor一个只读的基于PyTorch张量的数据结构,存储
[num_rows, num_cols, *],其中最后一维的大小可以 因不同的行/列而异。在内部,我们以高效的扁平化格式存储对象:(values, offset),其中位于(i, j)的PyTorch张量通过values[offset[i*num_cols+j]:offset[i*num_cols+j+1]]访问。 它支持各种高级索引,包括沿行和列的切片和列表索引。- Parameters:
num_rows (int) – 行数。
num_cols (int) – 列数。
值 (torch.Tensor) – 大小为
[numel,]的torch.Tensor的值。offset (torch.Tensor) – 偏移量
torch.Tensor的大小为[num_rows*num_cols+1,]。
示例
>>> import torch >>> tensor_mat = [ ... [torch.tensor([1, 2]), torch.tensor([3])], ... [torch.tensor([4]), torch.tensor([5, 6, 7])], ... [torch.tensor([8, 9]), torch.tensor([10])], ... ] >>> mnt = MultiNestedTensor.from_tensor_mat(tensor_mat) >>> mnt MultiNestedTensor(num_rows=3, num_cols=2, device='cpu') >>> mnt.values tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) >>> mnt.offset tensor([ 0, 2, 3, 4, 7, 9, 10]) >>> mnt[0, 0] torch.tensor([1, 2]) >>> mnt[1, 1] tensor([5, 6, 7]) >>> mnt[0] # Row integer indexing MultiNestedTensor(num_rows=1, num_cols=2, device='cpu') >>> mnt[:, 0] # Column integer indexing MultiNestedTensor(num_rows=3, num_cols=1, device='cpu') >>> mnt[:2] # Row integer slicing MultiNestedTensor(num_rows=2, num_cols=2, device='cpu') >>> mnt[[2, 1, 2, 0]] # Row list indexing MultiNestedTensor(num_rows=4, num_cols=2, device='cpu') >>> mnt.to_dense(fill_value = -1) # Map to a dense matrix via padding tensor([[[ 1, 2, -1], [ 3, -1, -1]], [[ 4, -1, -1], [ 5, 6, 7]], [[ 8, 9, -1], [10, -1, -1]]])
- classmethod from_tensor_mat(tensor_mat: list[list[torch.Tensor]]) MultiNestedTensor[source]
从
tensor_mat构建MultiNestedTensor对象。- Parameters:
tensor_mat (List[List[Tensor]]) – 一个由
torch.Tensor对象组成的矩阵。tensor_mat[i][j]包含第i行和第j列的一维torch.Tensor,大小可变。- Returns:
一个
MultiNestedTensor实例。- Return type:
- fillna_col(col_index: int, fill_value: int | float | Tensor) None[source]
用
fill_value原地填充MultiTensor中的第index列。
- static cat(xs: Sequence[MultiNestedTensor], dim: int = 0) MultiNestedTensor[来源]
沿着指定的维度连接一系列
MultiNestedTensor。- Parameters:
xs (Sequence[MultiNestedTensor]) – 一个由
MultiNestedTensor组成的序列,用于连接。dim (int) – 要沿其连接的维度。
- Returns:
连接的多层嵌套张量。
- Return type: