Shortcuts

如何采样视频片段

在这个例子中,我们将学习如何从视频中采样片段。片段通常表示一系列或一批帧,通常作为视频模型的输入。

首先,一些样板代码:我们将从网上下载一个视频,并定义一个绘图工具。你可以忽略这部分,直接跳到创建解码器

from typing import Optional
import torch
import requests


# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
    raise RuntimeError(f"Failed to download video. {response.status_code = }.")

raw_video_bytes = response.content


def plot(frames: torch.Tensor, title : Optional[str] = None):
    try:
        from torchvision.utils import make_grid
        from torchvision.transforms.v2.functional import to_pil_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return

    plt.rcParams["savefig.bbox"] = 'tight'
    fig, ax = plt.subplots()
    ax.imshow(to_pil_image(make_grid(frames)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()

创建解码器

从视频中采样片段总是从创建一个VideoDecoder对象开始。如果你还不熟悉VideoDecoder,可以快速浏览一下:使用VideoDecoder解码视频

from torchcodec.decoders import VideoDecoder

# You can also pass a path to a local file!
decoder = VideoDecoder(raw_video_bytes)

采样基础

我们现在可以使用我们的解码器来采样片段。让我们首先看一个简单的例子:所有其他采样器都遵循类似的API和原则。我们将使用clips_at_random_indices()来采样从随机索引开始的片段。

from torchcodec.samplers import clips_at_random_indices

# The samplers RNG is controlled by pytorch's RNG. We set a seed for this
# tutorial to be reproducible across runs, but note that hard-coding a seed for
# a training run is generally not recommended.
torch.manual_seed(0)

clips = clips_at_random_indices(
    decoder,
    num_clips=5,
    num_frames_per_clip=4,
    num_indices_between_frames=3,
)
clips
FrameBatch:
  data (shape): torch.Size([5, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [10.2000, 10.3200, 10.4400, 10.5600],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 9.6000,  9.7200,  9.8400,  9.9600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

采样器的输出是一系列片段,表示为 FrameBatch 对象。在这个对象中,我们有不同的 字段:

  • data: 一个5D uint8张量,表示帧数据。其形状为 (num_clips, num_frames_per_clip, …) 其中 … 是 (C, H, W) 或 (H, W, C),取决于 dimension_order 参数 VideoDecoder。这通常是传递给模型的内容。

  • pts_seconds: 一个形状为 (num_clips, num_frames_per_clip) 的二维浮点张量,给出每个剪辑中每帧的起始时间戳,单位为秒。

  • duration_seconds: 一个形状为 (num_clips, num_frames_per_clip) 的二维浮点数张量,表示每个剪辑中每帧的持续时间,单位为秒。

plot(clips[0].data)
sampling

索引和操作剪辑

Clips 是 FrameBatch 对象,它们支持原生的 pytorch 索引语义(包括花式索引)。这使得根据给定条件轻松过滤 clips 变得容易。例如,从上面的 clips 中,我们可以轻松过滤出那些在特定时间戳之后开始的 clips:

tensor([11.3600, 10.2000,  9.8000,  9.6000,  8.4400], dtype=torch.float64)
clips_starting_after_five_seconds = clips[clip_starts > 5]
clips_starting_after_five_seconds
FrameBatch:
  data (shape): torch.Size([5, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [10.2000, 10.3200, 10.4400, 10.5600],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 9.6000,  9.7200,  9.8400,  9.9600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)
every_other_clip = clips[::2]
every_other_clip
FrameBatch:
  data (shape): torch.Size([3, 4, 3, 360, 640])
  pts_seconds: tensor([[11.3600, 11.4800, 11.6000, 11.7200],
        [ 9.8000,  9.9200, 10.0400, 10.1600],
        [ 8.4400,  8.5600,  8.6800,  8.8000]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

注意

在给定时间戳后获取剪辑的更自然和高效的方法是依赖于采样范围参数,我们将在稍后的高级参数:采样范围中介绍。

基于索引和基于时间的采样器

到目前为止,我们已经使用了 clips_at_random_indices()。Torchcodec 支持额外的采样器,这些采样器主要分为两类:

基于索引的采样器:

基于时间的采样器:

所有这些采样器都遵循类似的API,基于时间的采样器具有与基于索引的采样器相似的参数。两种采样器类型在速度方面通常提供可比的性能。

注意

使用基于时间的采样器还是基于索引的采样器更好?基于索引的采样器的API可能稍微简单一些,并且由于其索引的离散性质,其行为可能更容易理解和控制。对于恒定帧率的视频,基于索引的采样器的行为与基于时间的采样器完全相同。然而,对于可变帧率的视频(这种情况很常见),依赖索引可能会导致视频中某些区域的采样不足或过度,这可能会在训练模型时产生不良的副作用。使用基于时间的采样器可以确保在时间维度上具有均匀的采样特性。

高级参数:采样范围

有时,我们可能不想从整个视频中采样片段。我们可能只对在较小间隔内开始的片段感兴趣。在采样器中,sampling_range_startsampling_range_end 参数控制采样范围:它们定义了允许片段开始的位置。有两个重要的事情需要记住:

  • sampling_range_end 是一个开放的上限:剪辑只能在 [sampling_range_start, sampling_range_end) 范围内开始。

  • 因为这些参数定义了剪辑可以开始的位置,剪辑可能包含sampling_range_end之后的帧!

from torchcodec.samplers import clips_at_regular_timestamps

clips = clips_at_regular_timestamps(
    decoder,
    seconds_between_clip_starts=1,
    num_frames_per_clip=4,
    seconds_between_frames=0.5,
    sampling_range_start=2,
    sampling_range_end=5
)
clips
FrameBatch:
  data (shape): torch.Size([3, 4, 3, 360, 640])
  pts_seconds: tensor([[2.0000, 2.4800, 3.0000, 3.4800],
        [3.0000, 3.4800, 4.0000, 4.4800],
        [4.0000, 4.4800, 5.0000, 5.4800]], dtype=torch.float64)
  duration_seconds: tensor([[0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400],
        [0.0400, 0.0400, 0.0400, 0.0400]], dtype=torch.float64)

高级参数:策略

根据视频的长度或持续时间以及采样参数,采样器可能会尝试采样视频结束之后的帧。policy参数定义了如何用有效帧替换这些无效帧。

from torchcodec.samplers import clips_at_random_timestamps

end_of_video = decoder.metadata.end_stream_seconds
print(f"{end_of_video = }")
end_of_video = 13.8
torch.manual_seed(0)
clips = clips_at_random_timestamps(
    decoder,
    num_clips=1,
    num_frames_per_clip=5,
    seconds_between_frames=0.4,
    sampling_range_start=end_of_video - 1,
    sampling_range_end=end_of_video,
    policy="repeat_last",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.6800, 13.6800, 13.6800]], dtype=torch.float64)

我们在上面看到视频的结束时间是13.8秒。采样器尝试在时间戳[13.28, 13.68, 14.08, …]处采样帧,但14.08是一个无效的时间戳,超出了视频的结束时间。使用默认的“repeat_last”策略,采样器只需重复13.68秒的最后一帧来构建剪辑。

另一种策略是“wrap”:采样器会环绕剪辑并根据需要重复前几个有效帧:

torch.manual_seed(0)
clips = clips_at_random_timestamps(
    decoder,
    num_clips=1,
    num_frames_per_clip=5,
    seconds_between_frames=0.4,
    sampling_range_start=end_of_video - 1,
    sampling_range_end=end_of_video,
    policy="wrap",
)
clips.pts_seconds
tensor([[13.2800, 13.6800, 13.2800, 13.6800, 13.2800]], dtype=torch.float64)

默认情况下,sampling_range_end的值会自动设置,以确保采样器不会尝试在视频结束之后采样帧:默认值确保剪辑在结束之前足够早地开始。这意味着默认情况下,策略参数很少起作用,大多数用户可能不需要过多担心它。

脚本总运行时间: (0 分钟 0.593 秒)

Gallery generated by Sphinx-Gallery