# Source code for torch_frame.nn.decoder.decoder

from abc import ABC, abstractmethod
from typing import Any

from torch import Tensor
from torch.nn import Module


class Decoder(Module, ABC):
    r"""Base class for decoder that transforms the input column-wise PyTorch
    tensor into output tensor on which prediction head is applied.

    Subclasses must implement :meth:`forward`. Because the class derives from
    :class:`abc.ABC` and :meth:`forward` is marked abstract, :class:`Decoder`
    itself cannot be instantiated.
    """
    @abstractmethod
    def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Any:
        r"""Decode :obj:`x` of shape :obj:`[batch_size, num_cols, channels]`
        into an output tensor of shape :obj:`[batch_size, out_channels]`.

        Args:
            x (torch.Tensor): Input column-wise tensor of shape
                :obj:`[batch_size, num_cols, channels]`.
            args (Any): Extra arguments.
            kwargs (Any): Extra keyword arguments.
        """
        # Only reachable if a subclass explicitly calls ``super().forward``;
        # ``@abstractmethod`` already prevents instantiating ``Decoder``.
        raise NotImplementedError

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module.

        The base implementation does nothing; subclasses that own learnable
        parameters are expected to override it.
        """