热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

深入理解PyTorch的multinomial函数

本文详细解析了PyTorch中的torch.multinomial函数,包括其参数、功能及使用示例。该函数用于从输入张量中按权重进行采样,并返回采样的索引。

torch.multinomial(input, num_samples, replacement=False, out=None) 是 PyTorch 中的一个重要函数,用于从输入张量中按权重进行采样,并返回采样的索引。

该函数的主要参数如下:

  • input: 输入张量,表示每个元素的权重。
  • num_samples: 每一行的采样次数。
  • replacement: 是否允许有放回的采样,默认为 False(无放回)。
  • out: 可选参数,用于指定输出张量。

输入张量可以被视为一个权重矩阵,其中每个元素表示其在该行中的权重。如果某个元素的权重为 0,则在其他非零元素未被完全采样之前,该元素不会被选中。

num_samples 参数指定了每行的采样次数,该值不能超过每行的元素数量,否则会引发错误。

replacement 参数决定了采样方式是有放回还是无放回。如果设置为 True,则表示有放回采样;如果设置为 False,则表示无放回采样。

以下是一些官方示例:

>>> weights = torch.tensor([0, 10, 3, 0]) # 创建一个权重张量
>>> torch.multinomial(weights, 4)
tensor([1, 2, 1, 2])
>>> torch.multinomial(weights, 4, replacement=True)
tensor([1, 2, 1, 2])

在上述例子中,输入张量为 [0, 10, 3, 0],表示第 0 和第 3 个元素的权重为 0。因此,在其他非零元素未被完全采样之前,这些权重为 0 的元素不会被选中。

对于无放回采样(replacement=False),第一次调用 torch.multinomial(weights, 4) 时,可能的结果只有两种:[1, 2, 0, 0] 和 [2, 1, 0, 0],其中 [1, 2, 0, 0] 更常见,因为第 1 个元素的权重较大,被选中的概率更高。当第 1 和第 2 个元素被采样完后,剩下的两个权重为 0 的元素才会被选中。

对于有放回采样(replacement=True),第二次调用 torch.multinomial(weights, 4, replacement=True) 时,只会出现 1 和 2 这两个元素,因为有放回采样不会选中权重为 0 的元素。

如果输入的是二维张量,则返回的也是一个二维张量,其行数与输入张量的行数相同,列数为 num_samples,即每一行都会进行 num_samples 次采样,采样方式与一维张量相同。

参考链接:
https://blog.csdn.net/monchin/article/details/79787621


推荐阅读
author-avatar
话说的爱
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有