LSTM#

class pytorch_forecasting.models.nn.rnn.LSTM(input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0.0, bidirectional: bool = False, proj_size: int = 0, device=None, dtype=None)[来源]#
class pytorch_forecasting.models.nn.rnn.LSTM(*args, **kwargs)

基础: RNN, LSTM

可以处理零长度序列的LSTM

方法

handle_no_encoding(hidden_state, ...)

在没有编码的地方屏蔽隐藏状态。

init_hidden_state(x)

初始化一个 hidden_state。

repeat_interleave(hidden_state, n_samples)

将hidden_state复制n_samples次。

handle_no_encoding(hidden_state: Tuple[Tensor, Tensor] | Tensor, no_encoding: BoolTensor, initial_hidden_state: Tuple[Tensor, Tensor] | Tensor) Tuple[Tensor, Tensor] | Tensor[来源]#

在没有编码的地方屏蔽隐藏状态。

Parameters:
  • hidden_state (HiddenState) – 需要替换某些条目的隐藏状态

  • no_encoding (torch.BoolTensor) – 需要替换的位置

  • initial_hidden_state (HiddenState) – 用于替换的隐状态

Returns:

适当情况下传播初始隐藏状态的隐藏状态

Return type:

隐藏状态

init_hidden_state(x: Tensor) Tuple[Tensor, Tensor] | Tensor[来源]#

初始化一个 hidden_state。

Parameters:

x (torch.Tensor) – 网络输入

Returns:

默认(零类)隐藏状态

Return type:

隐藏状态

repeat_interleave(hidden_state: Tuple[Tensor, Tensor] | Tensor, n_samples: int) Tuple[Tensor, Tensor] | Tensor[来源]#

将hidden_state复制n_samples次。

Parameters:
  • hidden_state (HiddenState) – 要重复的隐状态

  • n_samples (int) – 重复次数

Returns:

重复的隐藏状态

Return type:

隐藏状态