tensor.gather(dim, index)和torch.gather(input, dim, index)两者没有本质差别。
这里挑tensor.gather(dim, index)来讲。
官网解释
中文解释:
输入dim和index,index和tensor的维度数目一样,比如都是3个维度的数组(A,B,C)这种。
dim指明你的索引是在第几维,index要求必须是和输入tensor相同维度的张量,返回的是这些索引对应的值,返回张量的size与index相同。
对于你的index中的元素值,它有自己的索引,此时要指定是某一个维度dim,将这个元素自己的索引中对应dim的维度改变为该元素值,其他维度上的值不变,然后根据这个新的索引在tensor中取索引值。
官网的3D情况说明(看完下面的举例你会觉得很清晰):
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
a = torch.randn((3,4))
print('a\n',a)
i1 = torch.tensor([[1,2],[1,1],[0,2],[1,1]])
i2 = torch.tensor([[1,2],[1,1],[0,2]])
print('ans1\n',a.gather(0,i1))
print('ans2\n',a.gather(1,i2))
output:
atensor([[-0.2655, 0.3619, -0.7515, -0.8025],[ 0.5486, 0.0390, -0.3317, -0.1171],[ 2.1218, -1.6343, -1.0830, 1.3824]])
ans1tensor([[ 0.5486, -1.6343],[ 0.5486, 0.0390],[-0.2655, -1.6343],[ 0.5486, 0.0390]])
ans2tensor([[ 0.3619, -0.7515],[ 0.0390, 0.0390],[ 2.1218, -1.0830]])
以索引i1为例,他的首行元素是[1,2],
1在这个数组中的索引是[0][0],
2在这个数组中的索引是[0][1],
选择的dim=0,所以在元素索引的0维上改变为元素值:
1的索引[0][0]->[1][0],然后在变量a中索引到了值a[1][0]=0.5486
2的索引[0][1]->[2][1],然后在变量a中索引到了值a[2][1]=-1.6343
所以这个ans1我们可以写为:
tensor([[a[1][0], a[2][1]],[a[1][0], a[1][1]],[a[0][0], a[2][1]],[a[1][0], a[1][1]]])可以对比一下刚刚的数据:
i1 = torch.tensor([[1,2], [1,1],[0,2],[1,1]])a = tensor([[-0.2655, 0.3619, -0.7515, -0.8025],[ 0.5486, 0.0390, -0.3317, -0.1171],[ 2.1218, -1.6343, -1.0830, 1.3824]])ans1 = tensor( [[ 0.5486, -1.6343],[ 0.5486, 0.0390],[-0.2655, -1.6343],[ 0.5486, 0.0390]])
若索引i2为例,他的首行元素是[1,2],
1在这个数组中的索引是[0][0],
2在这个数组中的索引是[0][1],
选择的dim=1,所以在元素索引的1维上改变为元素值:
1的索引[0][0]->[0][1],然后在变量a中索引到了值a[0][1]=0.3619
2的索引[0][1]->[0][2],然后在变量a中索引到了值a[0][2]=-0.7515
所以这个ans2我们可以写为:
tensor([[a[0][1], a[0][2]],[a[1][1], a[1][1]],[a[2][0], a[2][2]]])可以对比一下刚刚的数据:
a = tensor([[-0.2655, 0.3619, -0.7515, -0.8025],[ 0.5486, 0.0390, -0.3317, -0.1171],[ 2.1218, -1.6343, -1.0830, 1.3824]])i2 = torch.tensor([[1,2],[1,1],[0,2]])ans2 = tensor([[ 0.3619, -0.7515],[ 0.0390, 0.0390],[ 2.1218, -1.0830]])
a = torch.randn((3,5,3))
print('a\n',a)
i1 = torch.tensor([[[1,2],[1,1],[0,2],[1,1]]])
print('ans\n',a.gather(2,i1))
output:
a
tensor([[[-0.0114, -1.0284, -0.5340],[ 0.5844, 1.4223, 0.4038],[ 0.0575, 1.0408, 0.4988],[ 0.3994, -0.0080, 0.5033],[-1.3644, 0.4155, -0.6559]],[[ 1.7330, 0.2755, -0.9000],[-0.2527, 0.5685, 1.6011],[ 2.0909, -0.4134, -1.2176],[ 0.8040, 1.1630, 0.3964],[-0.6463, 0.2030, -0.8429]],[[ 1.0368, -0.7876, 1.3825],[ 1.5968, -1.1934, 0.9004],[-0.6002, -0.8837, -2.1700],[-0.9114, -0.1575, 1.3854],[-0.0854, 0.5144, 0.0932]]])
ans
tensor([[[-1.0284, -0.5340],[ 1.4223, 1.4223],[ 0.0575, 0.4988],[-0.0080, -0.0080]]])
相同的道理,他的第三行元素是[0,2],
0在这个数组中的索引是[0][2][0],
2在这个数组中的索引是[0][2][1],
选择的dim=2,所以在元素索引的2维上改变为元素值:
0的索引[0][2][0]->[0][2][0],然后在变量a中索引到了值a[0][2][0]=0.0575
2的索引[0][2][1]->[0][2][2],然后在变量a中索引到了值a[0][2][2]=0.4988
其他的话就一样了,这里不再赘述。
这时再看官网的说明,很清晰了吧(水字数)!
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
假设我现在有一个神经网络的输出是点云数据。
例如这个输出的维度是output.shape=[128,1000,3] ,
可以理解为batch_size为128,然后每个batch有1000个点,每个点的坐标是 xyz 3个特征。
然后我现在有一个索引 index.shape=[128,3000](这3000个索引肯定是有重复的,我要做的就是根据每个batch中这3000个索引把output中的值放入到这个结果result中。
遍历所以batch,将output每个batch中对应index的值赋值给result的每个batch。
result = torch.empty((128,3000,3))
for i in range(128):result[i]=output[i][index[i]]
索引取得是xyz,可以先扩展一维,然后复制3份,取dim为1,表示现在result[i][j][k] = output[i][index[i][j][k]][k]
即对于每个batch,都把每行换成了此行所索引的output中的对应行值。
index=index.unsqueeze(-1).expand(-1,-1,3) #(128,3000,3)
result = output.gather(dim=1,index=index)