第一次指定拼接的维度 dim =2,结果的维度是 [2, 3, 3]。后面指定拼接的维度 dim =0,由于原来的 tensor 已经有了维度 0,因此会把 tensor 往后移动一个维度变为 [1,2,3],再拼接变为 [3,2,3]。
切分
torch.chunk()
torch.chunk(input, chunks, dim=0)
功能:将张量按照维度 dim 进行平均切分。若不能整除,则最后一份张量小于其他张量。
input: 要切分的张量
chunks: 要切分的份数
dim: 要切分的维度
代码示例:
a = torch.ones((2, 7)) # 7 list_of_tensors = torch.chunk(a, dim=1, chunks=3) # 3 for idx, t in enumerate(list_of_tensors): print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
输出为:
第1个张量:tensor([[1., 1., 1.], [1., 1., 1.]]), shape is torch.Size([2, 3]) 第2个张量:tensor([[1., 1., 1.], [1., 1., 1.]]), shape is torch.Size([2, 3]) 第3个张量:tensor([[1.], [1.]]), shape is torch.Size([2, 1])
split_size_or_sections: 为 int 时,表示每一份的长度,如果不能被整除,则最后一份张量小于其他张量;为 list 时,按照 list 元素作为每一个分量的长度切分。如果 list 元素之和不等于切分维度 (dim) 的值,就会报错。
dim: 要切分的维度
代码示例:
t = torch.ones((2, 5)) list_of_tensors = torch.split(t, [2, 1, 2], dim=1) for idx, t in enumerate(list_of_tensors): print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
结果为:
第1个张量:tensor([[1., 1.], [1., 1.]]), shape is torch.Size([2, 2]) 第2个张量:tensor([[1.], [1.]]), shape is torch.Size([2, 1]) 第3个张量:tensor([[1., 1.], [1., 1.]]), shape is torch.Size([2, 2])
t = torch.randint(0, 9, size=(3, 3)) mask = t.le(5) # ge is mean greater than or equal/ gt: greater than le lt # 取出大于 5 的数 t_select = torch.masked_select(t, mask) print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))
#把 c * h * w 变换为 h * w * c t = torch.rand((2, 3, 4)) t_transpose = torch.transpose(t, dim0=1, dim1=2) # c*h*w h*w*c print("t shape:{}\nt_transpose shape: {}".format(t.shape, t_transpose.shape))
结果为:
t shape:torch.Size([2, 3, 4]) t_transpose shape: torch.Size([2, 4, 3])