作者:话说的爱 | 来源:互联网 | 2024-11-16 15:40
torch.multinomial(input, num_samples, replacement=False, out=None)
是 PyTorch 中的一个重要函数,用于从输入张量中按权重进行采样,并返回采样的索引。
该函数的主要参数如下:
input
: 输入张量,表示每个元素的权重。
num_samples
: 每一行的采样次数。
replacement
: 是否允许有放回的采样,默认为 False(无放回)。
out
: 可选参数,用于指定输出张量。
输入张量可以被视为一个权重矩阵,其中每个元素表示其在该行中的权重。如果某个元素的权重为 0,则在其他非零元素未被完全采样之前,该元素不会被选中。
num_samples 参数指定了每行的采样次数,该值不能超过每行的元素数量,否则会引发错误。
replacement 参数决定了采样方式是有放回还是无放回。如果设置为 True,则表示有放回采样;如果设置为 False,则表示无放回采样。
以下是一些官方示例:
>>> weights = torch.tensor([0, 10, 3, 0]) # 创建一个权重张量
>>> torch.multinomial(weights, 4)
tensor([1, 2, 1, 2])
>>> torch.multinomial(weights, 4, replacement=True)
tensor([1, 2, 1, 2])
在上述例子中,输入张量为 [0, 10, 3, 0],表示第 0 和第 3 个元素的权重为 0。因此,在其他非零元素未被完全采样之前,这些权重为 0 的元素不会被选中。
对于无放回采样(replacement=False
),第一次调用 torch.multinomial(weights, 4)
时,可能的结果只有两种:[1, 2, 0, 0] 和 [2, 1, 0, 0],其中 [1, 2, 0, 0] 更常见,因为第 1 个元素的权重较大,被选中的概率更高。当第 1 和第 2 个元素被采样完后,剩下的两个权重为 0 的元素才会被选中。
对于有放回采样(replacement=True
),第二次调用 torch.multinomial(weights, 4, replacement=True)
时,只会出现 1 和 2 这两个元素,因为有放回采样不会选中权重为 0 的元素。
如果输入的是二维张量,则返回的也是一个二维张量,其行数与输入张量的行数相同,列数为 num_samples
,即每一行都会进行 num_samples
次采样,采样方式与一维张量相同。
参考链接:
https://blog.csdn.net/monchin/article/details/79787621