讲清楚tensor.gather(dim,index)和torch.gather(input, dim, index),举例,应用
创始人
2025-05-28 08:56:46

目录

  • 前言
  • 正题
  • 举例,维度为2
  • 举例,维度为3
  • 应用
    • 常规思路
    • 用gather

前言

tensor.gather(dim, index)torch.gather(input, dim, index)两者没有本质差别。
这里挑tensor.gather(dim, index)来讲。

正题

官网解释

中文解释:
输入dim和index,index和tensor的维度数目一样,比如都是3个维度的数组(A,B,C)这种。
dim指明你的索引是在第几维,index要求必须是和输入tensor相同维度的张量,返回的是这些索引对应的值,返回张量的size与index相同。
对于你的index中的元素值,它有自己的索引,此时要指定是某一个维度dim,将这个元素自己的索引中对应dim的维度改变为该元素值,其他维度上的值不变,然后根据这个新的索引在tensor中取索引值。

官网的3D情况说明(看完下面的举例你会觉得很清晰):

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

举例,维度为2

a = torch.randn((3,4))
print('a\n',a)
i1 = torch.tensor([[1,2],[1,1],[0,2],[1,1]])
i2 = torch.tensor([[1,2],[1,1],[0,2]])
print('ans1\n',a.gather(0,i1))
print('ans2\n',a.gather(1,i2)) 
output:
atensor([[-0.2655,  0.3619, -0.7515, -0.8025],[ 0.5486,  0.0390, -0.3317, -0.1171],[ 2.1218, -1.6343, -1.0830,  1.3824]])
ans1tensor([[ 0.5486, -1.6343],[ 0.5486,  0.0390],[-0.2655, -1.6343],[ 0.5486,  0.0390]])
ans2tensor([[ 0.3619, -0.7515],[ 0.0390,  0.0390],[ 2.1218, -1.0830]])

以索引i1为例,他的首行元素是[1,2],
1在这个数组中的索引是[0][0],
2在这个数组中的索引是[0][1],
选择的dim=0,所以在元素索引的0维上改变为元素值:
1的索引[0][0]->[1][0],然后在变量a中索引到了值a[1][0]=0.5486
2的索引[0][1]->[2][1],然后在变量a中索引到了值a[2][1]=-1.6343
所以这个ans1我们可以写为:

tensor([[a[1][0], a[2][1]],[a[1][0], a[1][1]],[a[0][0], a[2][1]],[a[1][0], a[1][1]]])可以对比一下刚刚的数据:
i1 = torch.tensor([[1,2],   [1,1],[0,2],[1,1]])a = tensor([[-0.2655,  0.3619, -0.7515, -0.8025],[ 0.5486,  0.0390, -0.3317, -0.1171],[ 2.1218, -1.6343, -1.0830,  1.3824]])ans1 = tensor( [[ 0.5486, -1.6343],[ 0.5486,  0.0390],[-0.2655, -1.6343],[ 0.5486,  0.0390]])

若索引i2为例,他的首行元素是[1,2],
1在这个数组中的索引是[0][0],
2在这个数组中的索引是[0][1],
选择的dim=1,所以在元素索引的1维上改变为元素值:
1的索引[0][0]->[0][1],然后在变量a中索引到了值a[0][1]=0.3619
2的索引[0][1]->[0][2],然后在变量a中索引到了值a[0][2]=-0.7515
所以这个ans2我们可以写为:

tensor([[a[0][1], a[0][2]],[a[1][1], a[1][1]],[a[2][0], a[2][2]]])可以对比一下刚刚的数据:
a = tensor([[-0.2655,  0.3619, -0.7515, -0.8025],[ 0.5486,  0.0390, -0.3317, -0.1171],[ 2.1218, -1.6343, -1.0830,  1.3824]])i2 = torch.tensor([[1,2],[1,1],[0,2]])ans2 = tensor([[ 0.3619, -0.7515],[ 0.0390,  0.0390],[ 2.1218, -1.0830]])

举例,维度为3

a = torch.randn((3,5,3))
print('a\n',a)
i1 = torch.tensor([[[1,2],[1,1],[0,2],[1,1]]])
print('ans\n',a.gather(2,i1))
output:
a
tensor([[[-0.0114, -1.0284, -0.5340],[ 0.5844,  1.4223,  0.4038],[ 0.0575,  1.0408,  0.4988],[ 0.3994, -0.0080,  0.5033],[-1.3644,  0.4155, -0.6559]],[[ 1.7330,  0.2755, -0.9000],[-0.2527,  0.5685,  1.6011],[ 2.0909, -0.4134, -1.2176],[ 0.8040,  1.1630,  0.3964],[-0.6463,  0.2030, -0.8429]],[[ 1.0368, -0.7876,  1.3825],[ 1.5968, -1.1934,  0.9004],[-0.6002, -0.8837, -2.1700],[-0.9114, -0.1575,  1.3854],[-0.0854,  0.5144,  0.0932]]])
ans
tensor([[[-1.0284, -0.5340],[ 1.4223,  1.4223],[ 0.0575,  0.4988],[-0.0080, -0.0080]]])

相同的道理,他的第三行元素是[0,2],
0在这个数组中的索引是[0][2][0],
2在这个数组中的索引是[0][2][1],
选择的dim=2,所以在元素索引的2维上改变为元素值:
0的索引[0][2][0]->[0][2][0],然后在变量a中索引到了值a[0][2][0]=0.0575
2的索引[0][2][1]->[0][2][2],然后在变量a中索引到了值a[0][2][2]=0.4988
其他的话就一样了,这里不再赘述。

这时再看官网的说明,很清晰了吧(水字数)!

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

应用

假设我现在有一个神经网络的输出是点云数据。
例如这个输出的维度是output.shape=[128,1000,3] ,
可以理解为batch_size为128,然后每个batch有1000个点,每个点的坐标是 xyz 3个特征。
然后我现在有一个索引 index.shape=[128,3000](这3000个索引肯定是有重复的,我要做的就是根据每个batch中这3000个索引把output中的值放入到这个结果result中。

常规思路

遍历所以batch,将output每个batch中对应index的值赋值给result的每个batch。

result = torch.empty((128,3000,3))
for i in range(128):result[i]=output[i][index[i]]

用gather

索引取得是xyz,可以先扩展一维,然后复制3份,取dim为1,表示现在result[i][j][k] = output[i][index[i][j][k]][k]
即对于每个batch,都把每行换成了此行所索引的output中的对应行值。

index=index.unsqueeze(-1).expand(-1,-1,3) #(128,3000,3)
result = output.gather(dim=1,index=index)

相关内容

热门资讯

ISO/TC 228年会首次在...   新华社杭州5月16日电(记者段菁菁)5月11日至16日,国际标准化组织旅游及相关服务技术委员会(...
一斤等于12只鸡,每天吃2颗“... 一斤等于12只鸡,每天吃2颗“纯阳果”,升阳补肾气! 一斤等于12只鸡,下面分享几道“纯阳果”的食...
原创 老... 天气越来越热,整个人像被抽走了力气,看啥都没胃口。小时候常听家里老人念叨:“夏天吃三瓜,不把医生找。...
原创 “... 夏季天热出汗多,身体里的钾元素也跟着悄悄溜走,人呀,就容易觉得身子发沉、没精神。老话说“夏补钾,体不...
咖啡店“长”满小学生,全靠这新... 一代人有一代人的“蛋”要领。(图源:小红书用户 风惊云) 文 | 茨圆 出品 | Vista天下次元...