Shortcuts

torch.nn.functional.gumbel_softmax

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)[源代码]

从Gumbel-Softmax分布中采样(链接1 链接2)并可选择离散化。

Parameters
  • logits (Tensor) – […, num_features] 未归一化的对数概率

  • tau (float) – 非负标量温度

  • hard (bool) – 如果True,返回的样本将被离散化为独热向量,但在自动求导中将被视为软样本

  • dim (int) – 计算softmax的维度。默认值:-1。

Returns

从Gumbel-Softmax分布中采样的张量,形状与logits相同。 如果hard=True,返回的样本将是一热编码,否则它们将是概率分布,在dim上求和为1。

Return type

张量

注意

此函数由于历史原因存在,可能会在未来从nn.Functional中移除。

注意

对于hard的主要技巧是执行y_hard - y_soft.detach() + y_soft

它实现了两件事: - 使输出值完全为one-hot (因为我们先加上然后减去y_soft值) - 使梯度等于y_soft的梯度 (因为我们去除了所有其他梯度)

Examples::
>>> logits = torch.randn(20, 32)
>>> # 使用重参数化技巧采样软分类:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # 使用“Straight-through”技巧采样硬分类:
>>> F.gumbel_softmax(logits, tau=1, hard=True)
优云智算