torch.Tensor.scatter_(dim, index, src, reduce=None)
理解scatter操作:
tensor_A.scatter_(dim, index, tensor_B)
: tensor_B的每个元素,都按照 index 被scatter(可以理解为填充)到目标tensor_A中。
(1) index和源tensor_B维度一致;
(2) tensor_A一般是全零的张量,其某些特定位置的值由 tensor_B 中的值填充。
(3) 注意如何根据index选取tensor_B中的值:
对于2-D tensor:
if dim=0, tensor_A[index[i][j]][j] = tensor_B[i][j];if dim=1, tensor_A[i][index[i][j]] = tensor_B[i][j];
对于3-D tensor:
if dim = 0,tensor_A[index[i][j][k]][j][k] = tensor_B[i][j][k]
if dim = 1,tensor_A[i][index[i][j][k]][k] = tensor_B[i][j][k]
if dim = 2,tensor_A[i][j][index[i][j][k]] = tensor_B[i][j][k]
举例:
如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!