gather函数中参数维度的影响
Impact of the parameter dimension in gather function
我正在尝试使用pytorch中的gather函数,但无法理解dim
参数的作用。
代码:
t = torch.Tensor([[1,2],[3,4]])
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]])))
输出:
1 2
3 2
[torch.FloatTensor of size 2x2]
维度设置为 1:
print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])))
输出变为:
1 1
4 3
[torch.FloatTensor of size 2x2]
如何,gather
函数实际有效?
我了解了收集功能的工作原理。
t = torch.Tensor([[1,2],[3,4]])
index = torch.LongTensor([[0,0],[1,0]])
torch.gather(t, 0, index)
由于 dimension
为零,因此输出将为:
| t[index[0, 0], 0] t[index[0, 1], 1] |
| t[index[1, 0], 0] t[index[1, 1], 1] |
如果dimension
设置为1,输出将变为:
| t[0, index[0, 0]] t[0, index[0, 1]] |
| t[1, index[1, 0]] t[1, index[1, 1]] |
所以公式是:
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
参考:http://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather
只需添加到现有答案,gather
的一个应用是沿指定维度收集分数。
比如我们有这样的设置:
- 3 classes 和 5 个例子
- 每个class分配一个分数,对每个例子都这样做
- Objective是收集标签
y
表示的分数
代码如下
torch.manual_seed(0)
num_examples = 5
num_classes = 3
scores = torch.randn(5, 3)
#print of scores
scores: tensor([[ 1.5410, -0.2934, -2.1788],
[ 0.5684, -1.0845, -1.3986],
[ 0.4033, 0.8380, -0.7193],
[-0.4033, -0.5966, 0.1820],
[-0.8567, 1.1006, -1.0712]])
y = torch.LongTensor([1, 2, 1, 0, 2])
res = scores.gather(1, y.view(-1, 1)).squeeze()
输出:
#print of gather results
tensor([-0.2934, -1.3986, 0.8380, -0.4033, -1.0712])
我正在尝试使用pytorch中的gather函数,但无法理解dim
参数的作用。
代码:
t = torch.Tensor([[1,2],[3,4]])
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]])))
输出:
1 2
3 2
[torch.FloatTensor of size 2x2]
维度设置为 1:
print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])))
输出变为:
1 1
4 3
[torch.FloatTensor of size 2x2]
如何,gather
函数实际有效?
我了解了收集功能的工作原理。
t = torch.Tensor([[1,2],[3,4]])
index = torch.LongTensor([[0,0],[1,0]])
torch.gather(t, 0, index)
由于 dimension
为零,因此输出将为:
| t[index[0, 0], 0] t[index[0, 1], 1] |
| t[index[1, 0], 0] t[index[1, 1], 1] |
如果dimension
设置为1,输出将变为:
| t[0, index[0, 0]] t[0, index[0, 1]] |
| t[1, index[1, 0]] t[1, index[1, 1]] |
所以公式是:
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
参考:http://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather
只需添加到现有答案,gather
的一个应用是沿指定维度收集分数。
比如我们有这样的设置:
- 3 classes 和 5 个例子
- 每个class分配一个分数,对每个例子都这样做
- Objective是收集标签
y
表示的分数
代码如下
torch.manual_seed(0)
num_examples = 5
num_classes = 3
scores = torch.randn(5, 3)
#print of scores
scores: tensor([[ 1.5410, -0.2934, -2.1788],
[ 0.5684, -1.0845, -1.3986],
[ 0.4033, 0.8380, -0.7193],
[-0.4033, -0.5966, 0.1820],
[-0.8567, 1.1006, -1.0712]])
y = torch.LongTensor([1, 2, 1, 0, 2])
res = scores.gather(1, y.view(-1, 1)).squeeze()
输出:
#print of gather results
tensor([-0.2934, -1.3986, 0.8380, -0.4033, -1.0712])