# Source code for torch_frame.data.multi_nested_tensor

from __future__ import annotations

from collections.abc import Sequence
from typing import cast

import torch
from torch import Tensor

from torch_frame.data.multi_tensor import _batched_arange, _MultiTensor


class MultiNestedTensor(_MultiTensor):
    r"""A read-only PyTorch tensor-based data structure that stores
    :obj:`[num_rows, num_cols, *]`, where the size of last dimension can be
    different for different row/column. Internally, we store the object in an
    efficient flattened format: :obj:`(values, offset)`, where the PyTorch
    Tensor at :obj:`(i, j)` is accessed by
    :obj:`values[offset[i*num_cols+j]:offset[i*num_cols+j+1]]`.
    It supports various advanced indexing, including slicing and list indexing
    along both row and column.

    Args:
        num_rows (int): Number of rows.
        num_cols (int): Number of columns.
        values (torch.Tensor): The values :class:`torch.Tensor` of size
            :obj:`[numel,]`.
        offset (torch.Tensor): The offset :class:`torch.Tensor` of size
            :obj:`[num_rows*num_cols+1,]`.

    Example:
        >>> import torch
        >>> tensor_mat = [
        ...    [torch.tensor([1, 2]), torch.tensor([3])],
        ...    [torch.tensor([4]), torch.tensor([5, 6, 7])],
        ...    [torch.tensor([8, 9]), torch.tensor([10])],
        ... ]
        >>> mnt = MultiNestedTensor.from_tensor_mat(tensor_mat)
        >>> mnt
        MultiNestedTensor(num_rows=3, num_cols=2, device='cpu')
        >>> mnt.values
        tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
        >>> mnt.offset
        tensor([ 0,  2,  3,  4,  7,  9, 10])
        >>> mnt[0, 0]
        tensor([1, 2])
        >>> mnt[1, 1]
        tensor([5, 6, 7])
        >>> mnt[0] # Row integer indexing
        MultiNestedTensor(num_rows=1, num_cols=2, device='cpu')
        >>> mnt[:, 0] # Column integer indexing
        MultiNestedTensor(num_rows=3, num_cols=1, device='cpu')
        >>> mnt[:2] # Row integer slicing
        MultiNestedTensor(num_rows=2, num_cols=2, device='cpu')
        >>> mnt[[2, 1, 2, 0]] # Row list indexing
        MultiNestedTensor(num_rows=4, num_cols=2, device='cpu')
        >>> mnt.to_dense(fill_value = -1) # Map to a dense matrix via padding
        tensor([[[ 1,  2, -1],
                 [ 3, -1, -1]],
                [[ 4, -1, -1],
                 [ 5,  6,  7]],
                [[ 8,  9, -1],
                 [10, -1, -1]]])
    """
    def validate(self) -> None:
        r"""Check that :obj:`offset` is consistent with :obj:`values` and
        the declared shape: it starts at zero, ends at
        :obj:`len(values)` and has one entry per cell plus one.
        """
        assert self.offset[0] == 0
        assert self.offset[-1] == len(self.values)
        assert len(self.offset) == self.num_rows * self.num_cols + 1

    @classmethod
    def from_tensor_mat(
        cls,
        tensor_mat: list[list[Tensor]],
    ) -> MultiNestedTensor:
        r"""Construct :class:`MultiNestedTensor` object from
        :obj:`tensor_mat`.

        Args:
            tensor_mat (List[List[Tensor]]): A matrix of
                :class:`torch.Tensor` objects. :obj:`tensor_mat[i][j]`
                contains 1-dim :class:`torch.Tensor` of :obj:`i`-th row
                and :obj:`j`-th column, varying in size.

        Returns:
            MultiNestedTensor: A :class:`MultiNestedTensor` instance.

        Raises:
            RuntimeError: If rows have different lengths, or an element is
                not a 1-dimensional :class:`torch.Tensor`.
        """
        num_rows = len(tensor_mat)
        num_cols = len(tensor_mat[0])

        # offset_list[k] is the start position of the k-th cell (row-major)
        # inside the concatenated values; the last entry is the total length.
        offset_list = []
        accum_idx = 0
        offset_list.append(accum_idx)
        values_list = []
        for i in range(num_rows):
            if len(tensor_mat[i]) != num_cols:
                raise RuntimeError(
                    f"The length of each row must be the same."
                    f" tensor_mat[0] has length {num_cols}, but"
                    f" tensor_mat[{i}] has length {len(tensor_mat[i])}")

            for j in range(num_cols):
                tensor = tensor_mat[i][j]
                if not isinstance(tensor, Tensor):
                    raise RuntimeError(
                        "The element of tensor_mat must be PyTorch Tensor")
                if tensor.ndim != 1:
                    raise RuntimeError(
                        "tensor in tensor_mat needs to be 1-dimensional.")
                values_list.append(tensor)
                accum_idx += len(tensor)
                offset_list.append(accum_idx)

        values = torch.cat(values_list)
        offset = torch.tensor(offset_list, device=values.device)
        return cls(num_rows, num_cols, values, offset)

    def _get_value(self, i: int, j: int) -> Tensor:
        r"""Get :obj:`(i, j)`-th :class:`Tensor` object.

        Args:
            i (int): The row integer index.
            j (int): The column integer index.

        Returns:
            Tensor: The slice of :obj:`values` belonging to cell
            :obj:`(i, j)`.
        """
        i = self._normalize_index(i, dim=0)
        j = self._normalize_index(j, dim=1)
        # Cells are laid out in row-major order inside `offset`.
        idx = i * self.num_cols + j
        start_idx = self.offset[idx]
        end_idx = self.offset[idx + 1]
        out = self.values[start_idx:end_idx]
        return out

    def _row_narrow(self, start: int, length: int) -> MultiNestedTensor:
        r"""Helper function called by :meth:`MultiNestedTensor.narrow`.

        Returns rows :obj:`[start, start + length)`. The asserts reject a
        non-positive length, a negative start and the full-range case.
        """
        assert start >= 0
        assert length > 0
        end = start + length
        assert not (start == 0 and end >= self.num_rows)
        # Rows are contiguous in the flattened layout, so plain slicing of
        # both `offset` and `values` is enough.
        offset = self.offset[start * self.num_cols:end * self.num_cols + 1]
        values = self.values[offset[0]:offset[-1]]
        # Re-base the offsets so that they start at zero again.
        offset = offset - offset[0]
        return MultiNestedTensor(
            num_rows=end - start,
            num_cols=self.num_cols,
            values=values,
            offset=offset,
        )

    def _col_narrow(self, start: int, length: int) -> MultiNestedTensor:
        r"""Helper function called by :meth:`MultiNestedTensor.narrow`.

        Returns columns :obj:`[start, start + length)` of every row.
        """
        assert start >= 0
        assert length > 0
        end = start + length
        # Build a [num_rows, length + 1] matrix holding, for each row, the
        # start offsets of the selected columns plus the end offset of the
        # last selected column.
        if start == 0:
            assert end < self.num_cols
            offset_mat = self.offset[:-1].reshape(self.num_rows,
                                                  self.num_cols)
            offset_mat = offset_mat[:, start:end + 1]
        else:
            # Shifted by one so that column `end - 1`'s end offset is
            # available even when `end == num_cols`.
            offset_mat = self.offset[1:].reshape(self.num_rows, self.num_cols)
            offset_mat = offset_mat[:, start - 1:end]
        offset_start = offset_mat[:, 0]
        count = offset_mat[:, -1] - offset_start
        # NOTE(review): `_batched_arange` is defined elsewhere; it is used
        # here as yielding, per selected value, the row it belongs to
        # (`batch`) and its position within that row's run (`arange`).
        batch, arange = _batched_arange(count)
        values = self.values[offset_start[batch] + arange]

        # Per-row offsets relative to the row start, then shifted by the
        # total number of values in all preceding rows.
        offset_mat_zero_start = offset_mat - offset_start.view(-1, 1)
        accum = torch.cumsum(offset_mat_zero_start[:, -1], dim=0)
        offset_mat_zero_start[1:] += accum[:-1].view(-1, 1)

        num_cols = end - start
        # Allocate on the same device/dtype as the source offsets; the
        # default device would produce a CPU offset for a non-CPU tensor.
        offset = torch.full(
            (self.num_rows * num_cols + 1, ),
            int(accum[-1]),
            dtype=self.offset.dtype,
            device=self.offset.device,
        )
        offset[:-1] = offset_mat_zero_start[:, :-1].flatten()
        return MultiNestedTensor(
            num_rows=self.num_rows,
            num_cols=num_cols,
            values=values,
            offset=offset,
        )

    def _row_index_select(self, index: Tensor) -> MultiNestedTensor:
        r"""Helper function called by :obj:`index_select`.

        Selects the rows listed in :obj:`index` (in that order).
        """
        # Calculate values
        if index.numel() == 0:
            return self._empty(dim=0)
        index_right = (index + 1) * self.num_cols
        index_left = index * self.num_cols
        # Number of stored values in each selected row.
        diff = self.offset[index_right] - self.offset[index_left]
        batch, arange = _batched_arange(diff)
        idx = self.offset[index_left][batch] + arange
        values = self.values[idx]

        # Calculate offset
        count = torch.full(
            size=(len(index), ),
            fill_value=self.num_cols,
            dtype=torch.long,
            device=self.device,
        )
        # The last selected row also contributes the final end offset.
        count[-1] += 1
        batch, arange = _batched_arange(count)
        idx = index_left[batch] + arange
        # Offsets relative to the start of their own row ...
        offset = self.offset[idx] - self.offset[index_left][batch]
        # ... shifted by the number of values in the preceding output rows
        # (exclusive cumulative sum of `diff`).
        diff_cumsum = torch.cumsum(diff, dim=0)
        diff_cumsum = torch.roll(diff_cumsum, 1)
        diff_cumsum[0] = 0
        offset = offset + diff_cumsum[batch]
        return MultiNestedTensor(
            num_rows=len(index),
            num_cols=self.num_cols,
            values=values,
            offset=offset,
        )

    def _col_index_select(self, index: Tensor) -> MultiNestedTensor:
        r"""Helper function called by :obj:`index_select`.

        Selects the columns listed in :obj:`index` (in that order) from
        every row.
        """
        if index.numel() == 0:
            return self._empty(dim=1)
        # Flat (row-major) cell index of every selected (row, column) pair.
        start_idx = (index + torch.arange(
            0,
            self.num_rows * self.num_cols,
            self.num_cols,
            device=self.device,
        ).view(-1, 1)).flatten()
        offset_start = self.offset[start_idx]
        count = self.offset[start_idx + 1] - self.offset[start_idx]
        # New offsets are the cumulative cell lengths with a leading zero.
        offset = count.new_zeros(count.numel() + 1)
        torch.cumsum(count, dim=0, out=offset[1:])
        batch, arange = _batched_arange(count)
        values = self.values[offset_start[batch] + arange]
        return MultiNestedTensor(
            num_rows=self.num_rows,
            num_cols=len(index),
            values=values,
            offset=offset,
        )

    def _single_index_select(self, index: int, dim: int) -> MultiNestedTensor:
        r"""Get :obj:`index`-th row (:obj:`dim=0`) or column (:obj:`dim=1`).

        Raises:
            RuntimeError: If the normalized :obj:`dim` is neither 0 nor 1.
        """
        dim = MultiNestedTensor._normalize_dim(dim)
        index = self._normalize_index(index, dim=dim)
        start_idx: int | Tensor
        if dim == 0:
            # A single row is contiguous in the flattened layout.
            start_idx = index * self.num_cols
            end_idx = (index + 1) * self.num_cols + 1
            offset = self.offset[start_idx:end_idx]
            values = self.values[offset[0]:offset[-1]]
            offset = offset - offset[0]
            return MultiNestedTensor(num_rows=1, num_cols=self.num_cols,
                                     values=values, offset=offset)
        elif dim == 1:
            # Flat cell index of column `index` in every row.
            start_idx = torch.arange(
                0,
                self.num_rows * self.num_cols,
                self.num_cols,
                device=self.device,
            ) + index
            diff = self.offset[start_idx + 1] - self.offset[start_idx]
            batch, arange = _batched_arange(diff)
            # Compute values
            values = self.values[self.offset[start_idx][batch] + arange]
            # Compute offset
            offset = diff.new_zeros(diff.numel() + 1)
            torch.cumsum(diff, dim=0, out=offset[1:])
            return MultiNestedTensor(num_rows=self.num_rows, num_cols=1,
                                     values=values, offset=offset)
        else:
            raise RuntimeError(f"Unsupported dim={dim} for index_select.")

    def fillna_col(
        self,
        col_index: int,
        fill_value: int | float | Tensor,
    ) -> None:
        r"""Replace missing entries of one column in place.

        For floating-point :obj:`values`, entries that are NaN are replaced;
        otherwise entries equal to :obj:`-1` are replaced.

        Args:
            col_index (int): The column to fill. Negative values are
                resolved through :meth:`_normalize_index`.
            fill_value (Union[int, float, Tensor]): The replacement value.
        """
        # Without normalisation a negative index would make the arange
        # below start at a negative flat position and address wrong cells.
        col_index = self._normalize_index(col_index, dim=1)
        start_idx = torch.arange(
            col_index,
            self.num_rows * self.num_cols,
            self.num_cols,
            device=self.device,
        )
        diff = self.offset[start_idx + 1] - self.offset[start_idx]
        batch, arange = _batched_arange(diff)
        # Compute values
        values_index = self.offset[start_idx][batch] + arange
        values_col = self.values[values_index]
        if self.values.is_floating_point():
            values_col[torch.isnan(values_col)] = fill_value
        else:
            values_col[values_col == -1] = fill_value
        # Advanced indexing above produced a copy, so write the result back.
        self.values[values_index] = values_col

    def to_dense(self, fill_value: int | float) -> Tensor:
        """Map MultiNestedTensor into dense Tensor representation with
        padding.

        Args:
            fill_value (Union[int, float]): Fill values.

        Returns:
            Tensor: Padded PyTorch Tensor object with shape
            :obj:`(num_rows, num_cols, max_length)`. When there are no
            cells at all, :obj:`max_length` is 0.
        """
        count = self.offset[1:] - self.offset[:-1]
        if count.numel() == 0:
            # No cells: `count.max()` would raise on an empty tensor.
            return self.values.new_full(
                (self.num_rows, self.num_cols, 0),
                fill_value=fill_value,
            )
        max_length = int(count.max())
        batch, arange = _batched_arange(count)
        dense = self.values.new_full(
            (self.num_rows, self.num_cols, max_length),
            fill_value=fill_value,
        )
        # `batch` is the flat row-major cell index of every stored value.
        row = batch // self.num_cols
        col = batch % self.num_cols
        dense[row, col, arange] = self.values
        return dense

    def _empty(self, dim: int) -> MultiNestedTensor:
        r"""Creates an empty :class:`MultiNestedTensor`.

        Args:
            dim (int): The dimension to empty.

        Returns:
            MultiNestedTensor: An empty :class:`MultiNestedTensor`.
        """
        values = torch.tensor([], device=self.device, dtype=self.dtype)
        # Zero cells -> a single offset entry (0).
        offset = torch.zeros(1, device=self.device, dtype=torch.long)
        return MultiNestedTensor(
            num_rows=0 if dim == 0 else self.num_rows,
            num_cols=0 if dim == 1 else self.num_cols,
            values=values,
            offset=offset,
        )

    # Static methods ##########################################################
    @staticmethod
    def cat(
        xs: Sequence[MultiNestedTensor],
        dim: int = 0,
    ) -> MultiNestedTensor:
        """Concatenates a sequence of :class:`MultiNestedTensor` along the
        specified dimension.

        Args:
            xs (Sequence[MultiNestedTensor]): A sequence of
                :class:`MultiNestedTensor` to be concatenated.
            dim (int): The dimension to concatenate along.

        Returns:
            MultiNestedTensor: Concatenated multi nested tensor.

        Raises:
            RuntimeError: If :obj:`xs` is empty, or the inputs disagree on
                the size of the dimension that is not concatenated.
        """
        if len(xs) == 0:
            raise RuntimeError("Cannot concatenate a sequence of length 0.")
        assert isinstance(xs[0], MultiNestedTensor)
        dim = MultiNestedTensor._normalize_dim(dim)
        device = xs[0].device

        if dim == 0:
            num_rows = sum(x.num_rows for x in xs)
            num_cols = xs[0].num_cols
            for x in xs[1:]:
                if x.num_cols != num_cols:
                    raise RuntimeError(
                        "num_cols must be the same across a list of input "
                        "multi nested tensors.")
            # Rows are contiguous, so values can simply be concatenated.
            values = torch.cat([x.values for x in xs], dim=0)
            offset = torch.empty(num_rows * num_cols + 1, dtype=torch.long,
                                 device=device)
            # Copy each input's offsets (without its trailing end offset),
            # shifted by the number of values that precede it.
            accum = 0
            idx = 0
            for x in xs[:-1]:
                offset[idx:idx + len(x.offset[:-1])] = x.offset[:-1]
                offset[idx:idx + len(x.offset[:-1])].add_(accum)
                accum += cast(int, x.offset[-1])
                idx += len(x.offset[:-1])
            # The last input keeps its end offset, closing the sequence.
            offset[idx:] = xs[-1].offset
            offset[idx:].add_(accum)
            return MultiNestedTensor(
                num_rows=num_rows,
                num_cols=num_cols,
                values=values,
                offset=offset,
            )
        else:
            num_rows = xs[0].num_rows
            num_cols = sum(x.num_cols for x in xs)
            for x in xs[1:]:
                if x.num_rows != num_rows:
                    raise RuntimeError(
                        "num_rows must be the same across a list of input "
                        "multi nested tensors.")

            # (i,j)-th element stores the length of its stored Tensor
            elem_length_mat = torch.empty(num_rows, num_cols,
                                          dtype=torch.long, device=device)
            col_start_idx = 0
            for x in xs:
                elem_count = x.offset[1:] - x.offset[:-1]
                elem_length_mat[:, col_start_idx:col_start_idx +
                                x.num_cols] = elem_count.reshape(
                                    x.num_rows, x.num_cols)
                col_start_idx += x.num_cols

            # Compute offset
            offset = torch.zeros(num_rows * num_cols + 1, dtype=torch.long,
                                 device=device)
            torch.cumsum(elem_length_mat.flatten(), dim=0, out=offset[1:])

            # Compute values
            values = torch.empty(
                sum([x.values.numel() for x in xs]),
                dtype=xs[0].values.dtype,
                device=device,
            )
            col_start_idx = 0
            for x in xs:
                # For every output row, the span of `values` that the
                # columns coming from `x` occupy.
                offset_start_idx = col_start_idx + torch.arange(
                    0, num_rows * num_cols, num_cols, device=device)
                offset_start = offset[offset_start_idx]
                offset_end_idx = offset_start_idx + x.num_cols
                offset_end = offset[offset_end_idx]
                count = offset_end - offset_start
                batch, arange = _batched_arange(count)
                # `x.values` is already ordered row by row, matching the
                # order in which the destination positions are enumerated.
                values[offset_start[batch] + arange] = x.values
                col_start_idx += x.num_cols

            return MultiNestedTensor(
                num_rows=num_rows,
                num_cols=num_cols,
                values=values,
                offset=offset,
            )