torcheval.metrics.Cat¶
- class torcheval.metrics.Cat(*, dim: int = 0, device: device | None = None)¶
沿维度 dim 连接所有输入张量。其功能版本是
torch.cat(input)。所有输入到
Cat.update()的张量必须具有相同的形状(除了在连接维度上)或者为空。零维张量不是
Cat.update()的有效输入。torch.flatten()可以在传入Cat.update()之前将零维张量展平为一维张量。示例:
>>> import torch >>> from torcheval.metrics import Cat >>> metric = Cat(dim=1) >>> metric.update(torch.tensor([[1, 2], [3, 4]])) >>> metric.compute() tensor([[1, 2], [3, 4]])) >>> metric.update(torch.tensor([[5, 6], [7, 8]]))).compute() tensor([[1, 2, 5, 6], [3, 4, 7, 8]])) >>> metric.reset() >>> metric.update(torch.tensor([0])).compute() tensor([0])
- __init__(*, dim: int = 0, device: device | None = None) None¶
初始化一个Cat指标对象。
- Parameters:
dim – 沿哪个维度进行连接,如
torch.cat()所示。
方法
__init__(*[, dim, device])初始化一个Cat指标对象。
compute()返回连接的输入。
load_state_dict(state_dict[, strict])从state_dict加载度量状态变量。
merge_state(metrics)实现此方法以将当前度量的状态变量更新为当前度量和输入度量的合并状态。
reset()将度量状态变量重置为其默认值。
state_dict()将度量状态变量保存在state_dict中。
to(device, *args, **kwargs)将度量状态变量中的张量移动到设备。
update(input)实现此方法以更新您的指标类的状态变量。
属性
deviceMetric.to()的最后一个输入设备。