基础

本节介绍了TorchOpt中有用的概念。

TorchOpt 类型

torchopt.base.GradientTransformation(init, ...)

一对纯函数,用于实现梯度变换。

torchopt.base.TransformInitFn(*args, **kwargs)

GradientTransformationinit() 步骤的可调用类型。

torchopt.base.TransformUpdateFn(*args, **kwargs)

一个可调用类型,用于GradientTransformationupdate()步骤。

PyTrees

PyTrees 是 TorchOpt 中的一个重要概念。 它们可以被视为向量的泛化。 它们是一种使用元组和字典来结构化参数或权重的方式。 TorchOpt 中的许多求解器都对 pytrees 有原生支持。

浮点数精度

TorchOpt 默认使用单精度(32位)浮点数(torch.float32)。 然而,对于某些算法,这可能不够。 双精度(64位)浮点数(torch.float64)可以通过在文件开头添加以下行来启用:

import torch

torch.set_default_dtype(torch.float64)