LSTM#
- class pytorch_forecasting.models.nn.rnn.LSTM(input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0.0, bidirectional: bool = False, proj_size: int = 0, device=None, dtype=None)[来源]#
- class pytorch_forecasting.models.nn.rnn.LSTM(*args, **kwargs)
基础:
RNN
,LSTM
可以处理零长度序列的LSTM
方法
handle_no_encoding
(hidden_state, ...)在没有编码的地方屏蔽隐藏状态。
初始化一个 hidden_state。
repeat_interleave
(hidden_state, n_samples)将hidden_state复制n_samples次。
- handle_no_encoding(hidden_state: Tuple[Tensor, Tensor] | Tensor, no_encoding: BoolTensor, initial_hidden_state: Tuple[Tensor, Tensor] | Tensor) Tuple[Tensor, Tensor] | Tensor [来源]#
在没有编码的地方屏蔽隐藏状态。
- Parameters:
hidden_state (HiddenState) – 需要替换某些条目的隐藏状态
no_encoding (torch.BoolTensor) – 需要替换的位置
initial_hidden_state (HiddenState) – 用于替换的隐状态
- Returns:
适当情况下传播初始隐藏状态的隐藏状态
- Return type:
隐藏状态
初始化一个 hidden_state。
- Parameters:
x (torch.Tensor) – 网络输入
- Returns:
默认(零类)隐藏状态
- Return type:
隐藏状态