Shortcuts

torchcodec._frame 的源代码

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
from dataclasses import dataclass
from typing import Iterable, Iterator, Union

from torch import Tensor


def _frame_repr(self):
    # Utility to replace Frame and FrameBatch __repr__ method. This prints the
    # shape of the .data tensor rather than printing the (potentially very long)
    # data tensor itself.
    s = self.__class__.__name__ + ":\n"
    spaces = "  "
    for field in dataclasses.fields(self):
        field_name = field.name
        field_val = getattr(self, field_name)
        if field_name == "data":
            field_name = "data (shape)"
            field_val = field_val.shape
        s += f"{spaces}{field_name}: {field_val}\n"
    return s


[docs]@dataclass class Frame(Iterable): """A single video frame with associated metadata.""" data: Tensor """The frame data as (3-D ``torch.Tensor``).""" pts_seconds: float """The :term:`pts` of the frame, in seconds (float).""" duration_seconds: float """The duration of the frame, in seconds (float).""" def __post_init__(self): # This is called after __init__() when a Frame is created. We can run # input validation checks here. if not self.data.ndim == 3: raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }") self.pts_seconds = float(self.pts_seconds) self.duration_seconds = float(self.duration_seconds) def __iter__(self) -> Iterator[Union[Tensor, float]]: for field in dataclasses.fields(self): yield getattr(self, field.name) def __repr__(self): return _frame_repr(self)
[docs]@dataclass class FrameBatch(Iterable): """Multiple video frames with associated metadata. The ``data`` tensor is typically 4D for sequences of frames (NHWC or NCHW), or 5D for sequences of clips, as returned by the :ref:`samplers <sphx_glr_generated_examples_sampling.py>`. When ``data`` is 4D (resp. 5D) the ``pts_seconds`` and ``duration_seconds`` tensors are 1D (resp. 2D). .. note:: The ``pts_seconds`` and ``duration_seconds`` Tensors are always returned on CPU, even if ``data`` is on GPU. """ data: Tensor """The frames data (``torch.Tensor`` of uint8).""" pts_seconds: Tensor """The :term:`pts` of the frame, in seconds (``torch.Tensor`` of floats).""" duration_seconds: Tensor """The duration of the frame, in seconds (``torch.Tensor`` of floats).""" def __post_init__(self): # This is called after __init__() when a FrameBatch is created. We can # run input validation checks here. if self.data.ndim < 3: raise ValueError( f"data must be at least 3-dimensional, got {self.data.shape = }" ) leading_dims = self.data.shape[:-3] if not (leading_dims == self.pts_seconds.shape == self.duration_seconds.shape): raise ValueError( "Tried to create a FrameBatch but the leading dimensions of the inputs do not match. " f"Got {self.data.shape = } so we expected the shape of pts_seconds and " f"duration_seconds to be {leading_dims = }, but got " f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }." ) def __iter__(self) -> Iterator["FrameBatch"]: for data, pts_seconds, duration_seconds in zip( self.data, self.pts_seconds, self.duration_seconds ): yield FrameBatch( data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds, ) def __getitem__(self, key) -> "FrameBatch": return FrameBatch( data=self.data[key], pts_seconds=self.pts_seconds[key], duration_seconds=self.duration_seconds[key], ) def __len__(self): return len(self.data) def __repr__(self): return _frame_repr(self)