Decoding a video with VideoDecoder
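In this example, we'll learn how to decode a video using the VideoDecoder class.
First, some boilerplate: we'll download a video from the web and define a plotting utility. You can skip this part and jump straight to Creating a decoder.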
from typing import Optional
import torch
import requests
# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = requests.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
    raise RuntimeError(f"Failed to download video. {response.status_code = }.")
raw_video_bytes = response.content
def plot(frames: torch.Tensor, title: Optional[str] = None):
    try:
        from torchvision.utils import make_grid
        from torchvision.transforms.v2.functional import to_pil_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return
    plt.rcParams["savefig.bbox"] = 'tight'
    fig, ax = plt.subplots()
    ax.imshow(to_pil_image(make_grid(frames)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()
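Creating a decoder
We can now create a decoder from the raw (encoded) video bytes. Instead of downloading a video, you can of course use a local video file and pass its path as input.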
from torchcodec.decoders import VideoDecoder
# You can also pass a path to a local file!
decoder = VideoDecoder(raw_video_bytes)
The decoder has not decoded any frames yet, but we already have access to some metadata through the metadata attribute, which is a VideoStreamMetadata object.
print(decoder.metadata)
VideoStreamMetadata:
  num_frames: 345
  duration_seconds: 13.8
  average_fps: 25.0
  duration_seconds_from_header: 13.8
  bit_rate: 505790.0
  num_frames_from_header: 345
  num_frames_from_content: 345
  begin_stream_seconds: 0.0
  end_stream_seconds: 13.8
  codec: h264
  width: 640
  height: 360
  average_fps_from_header: 25.0
  stream_index: 0
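The fields printed above are related to one another. As a rough sanity check, and assuming a constant frame rate (which the 25.0 average suggests but the metadata alone does not guarantee), the duration should equal the frame count multiplied by the duration of one frame. The sketch below hard-codes the values from the output above instead of reading them from a decoder, so it runs without a video file.

```python
import math

# Values copied from the metadata printed above (not read from a decoder).
num_frames = 345
average_fps = 25.0
duration_seconds = 13.8

# Assumption: constant frame rate, so every frame lasts 1 / average_fps seconds.
frame_duration = 1 / average_fps
expected_duration = num_frames * frame_duration

print(f"{frame_duration = }")
print(f"{expected_duration = :.2f}")

# The derived duration should agree with the reported one.
assert math.isclose(expected_duration, duration_seconds)
```

For variable-frame-rate videos this relationship only holds approximately, which is one reason the metadata exposes both header-derived and content-derived fields.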
Decoding frames by indexing the decoder
first_frame = decoder[0] # using a single int index
every_twenty_frame = decoder[0 : -1 : 20] # using slices
print(f"{first_frame.shape = }")
print(f"{first_frame.dtype = }")
print(f"{every_twenty_frame.shape = }")
print(f"{every_twenty_frame.dtype = }")
first_frame.shape = torch.Size([3, 360, 640])
first_frame.dtype = torch.uint8
every_twenty_frame.shape = torch.Size([18, 3, 360, 640])
every_twenty_frame.dtype = torch.uint8
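Indexing the decoder returns frames as torch.Tensor objects.
By default, the shape of the frames is (N, C, H, W), where N is the batch size, C the number of channels, H the height, and W the width of the frames. The batch dimension N is only present when we decode more than one frame. The dimension order can be changed to N, H, W, C with the dimension_order parameter of VideoDecoder. Frames are always of dtype torch.uint8.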
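To see why the slice 0 : -1 : 20 yields a batch of 18 frames, it can help to resolve the slice into explicit indices. The sketch below applies standard Python slice semantics to the frame count reported in the metadata; slice_to_indices is a hypothetical helper written for this illustration, not part of torchcodec, and the assumption is that the decoder follows the same slicing rules as a Python sequence, which the shapes printed above are consistent with.

```python
def slice_to_indices(num_frames: int, s: slice) -> list[int]:
    """Hypothetical helper: resolve a slice into explicit frame indices,
    using the same rules Python applies to lists and ranges."""
    return list(range(*s.indices(num_frames)))


# 345 is the num_frames value reported in the metadata above.
indices = slice_to_indices(345, slice(0, -1, 20))

print(indices[:3], "...", indices[-1])
print(f"{len(indices) = }")
```

Note that a stop of -1 excludes the last frame (index 344), just as it would for a list.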
plot(first_frame, "First frame")

plot(every_twenty_frame, "Every 20 frame")

Iterating over frames
The decoder is a normal iterable object and can be iterated over like so:
for frame in decoder:
    assert (
        isinstance(frame, torch.Tensor)
        and frame.shape == (3, decoder.metadata.height, decoder.metadata.width)
    )
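Retrieving the pts and duration of frames
Indexing the decoder returns pure torch.Tensor objects. Sometimes it is useful to retrieve additional information about the frames, such as their pts (presentation timestamp) and their duration. This can be done with the get_frame_at() and get_frames_at() methods, which return Frame and FrameBatch objects respectively.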
last_frame = decoder.get_frame_at(len(decoder) - 1)
print(f"{type(last_frame) = }")
print(last_frame)
type(last_frame) = <class 'torchcodec._frame.Frame'>
Frame:
  data (shape): torch.Size([3, 360, 640])
  pts_seconds: 13.76
  duration_seconds: 0.04
other_frames = decoder.get_frames_at([10, 0, 50])
print(f"{type(other_frames) = }")
print(other_frames)
type(other_frames) = <class 'torchcodec._frame.FrameBatch'>
FrameBatch:
  data (shape): torch.Size([3, 3, 360, 640])
  pts_seconds: tensor([0.4000, 0.0000, 2.0000], dtype=torch.float64)
  duration_seconds: tensor([0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(last_frame.data, "Last frame")
plot(other_frames.data, "Other frames")
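Both Frame and FrameBatch have a data field, which contains the decoded tensor data. They also have pts_seconds and duration_seconds fields, which are single floats for Frame and 1-D torch.Tensor objects for FrameBatch (one value per frame in the batch).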
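The pts values printed above follow a simple pattern for this particular video. A minimal sketch, assuming a constant frame rate and a stream that starts at 0 seconds (both suggested by the metadata, neither guaranteed in general): the pts of a frame is its index divided by the frame rate. expected_pts_seconds is a hypothetical helper for illustration only; the decoder itself reports the timestamps stored in the stream.

```python
import math


def expected_pts_seconds(index: int, average_fps: float) -> float:
    """Hypothetical helper: pts of frame `index`, assuming a constant
    frame rate and a stream whose first frame is presented at 0 seconds."""
    return index / average_fps


average_fps = 25.0  # copied from the metadata printed earlier

# Indices used in the examples above: [10, 0, 50] and the last frame (344).
for index in [10, 0, 50, 344]:
    print(index, expected_pts_seconds(index, average_fps))

# These agree with the pts_seconds values reported by the decoder above.
assert math.isclose(expected_pts_seconds(344, average_fps), 13.76)
```

When the frame rate varies, this arithmetic breaks down and the pts_seconds field is the reliable source.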
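Using time-based indexing
So far, we have retrieved frames based on their index. We can also retrieve frames based on when they are played, using get_frame_played_at() and get_frames_played_at(), which return Frame and FrameBatch respectively.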
frame_at_2_seconds = decoder.get_frame_played_at(seconds=2)
print(f"{type(frame_at_2_seconds) = }")
print(frame_at_2_seconds)
type(frame_at_2_seconds) = <class 'torchcodec._frame.Frame'>
Frame:
  data (shape): torch.Size([3, 360, 640])
  pts_seconds: 2.0
  duration_seconds: 0.04
other_frames = decoder.get_frames_played_at(seconds=[10.1, 0.3, 5])
print(f"{type(other_frames) = }")
print(other_frames)
type(other_frames) = <class 'torchcodec._frame.FrameBatch'>
FrameBatch:
  data (shape): torch.Size([3, 3, 360, 640])
  pts_seconds: tensor([10.0800, 0.2800, 5.0000], dtype=torch.float64)
  duration_seconds: tensor([0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(frame_at_2_seconds.data, "Frame played at 2 seconds")
plot(other_frames.data, "Other frames")
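Notice that the requested timestamps 10.1 and 0.3 came back with pts values of 10.08 and 0.28: the returned frame is the one being displayed at that instant, not one that starts exactly there. The sketch below mimics that lookup under the assumption of a constant frame rate and a stream starting at 0 seconds. index_played_at is a hypothetical helper, not a torchcodec function, and the real decoder may resolve timestamps differently (for instance from the pts values stored in the stream).

```python
import math


def index_played_at(seconds: float, average_fps: float, num_frames: int) -> int:
    """Hypothetical helper: index of the frame on screen at `seconds`,
    i.e. the frame whose interval [pts, pts + duration) contains it.
    Assumes a constant frame rate and a stream starting at 0 seconds."""
    index = math.floor(seconds * average_fps)
    if not 0 <= index < num_frames:
        raise IndexError(f"{seconds} s is outside the stream")
    return index


average_fps = 25.0  # copied from the metadata printed earlier
num_frames = 345

for seconds in [10.1, 0.3, 5]:
    index = index_played_at(seconds, average_fps, num_frames)
    # Under the same assumptions, the frame's pts is index / average_fps.
    print(seconds, index, index / average_fps)
```

Under these assumptions, a timestamp at or beyond the end of the stream (13.8 seconds here) has no corresponding frame, so the helper raises an error rather than clamping to the last frame.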
Total running time of the script: (0 minutes 3.146 seconds)