Shortcuts

torch.multinomial

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor

返回一个张量,其中每一行包含从位于张量input相应行的多项分布(更严格的定义是多元的,更多细节请参考torch.distributions.multinomial.Multinomial)中采样的num_samples个索引。

注意

input 的行不需要总和为一(在这种情况下,我们使用这些值作为权重),但必须是非负的、有限的并且具有非零总和。

索引按照每个样本的采样顺序从左到右排列(首先采样的样本放在第一列)。

如果 input 是一个向量,out 是一个大小为 num_samples 的向量。

如果 input 是一个有 m 行的矩阵,out 是一个形状为 (m×num_samples)(m \times \text{num\_samples}) 的矩阵。

如果 replacement 是 True,则样本是带放回地抽取。

如果没有,则在不放回的情况下进行抽取,这意味着当为某一行抽取一个样本索引后,该索引不能再次为该行抽取。

注意

当不放回地抽取时,num_samples 必须小于 input 中非零元素的数量(或者如果 input 是一个矩阵,则为每行中非零元素的最小数量)。

Parameters
  • 输入 (张量) – 包含概率的输入张量

  • num_samples (int) – 要抽取的样本数量

  • replacement (bool, 可选) – 是否进行有放回抽样

Keyword Arguments
  • 生成器 (torch.Generator, 可选) – 用于采样的伪随机数生成器

  • 输出 (张量, 可选) – 输出张量。

示例:

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # 创建一个权重张量
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 4) # 错误!
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,
not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])
优云智算