torch / lua:从 Tensor 中检索 n-best 子集

torch / lua: retrieving n-best subset from Tensor

我现在有下面的代码,它把每个问题的最高分数的索引存储在pred中,并把它转换成字符串。

我想对每个问题的 n 个最佳索引做同样的事情,而不仅仅是具有最高分数的单个索引,并将它们转换为字符串。我还想显示每个索引(或每个转换后的字符串)的分数。

所以 scores 必须排序,并且 pred 必须是多个 rows/columns 而不是 1 x nqs。 pred 中每个条目对应的 score 值必须是可检索的。

我对 lua/torch 语法一窍不通,如有任何帮助,我们将不胜感激。

nqs=dataset['question']:size(1);
scores=torch.Tensor(nqs,noutput);
qids=torch.LongTensor(nqs);
for i=1,nqs,batch_size do
    xlua.progress(i, nqs)
    r=math.min(i+batch_size-1,nqs);
    scores[{{i,r},{}}],qids[{{i,r}}]=forward(i,r);
end

tmp,pred=torch.max(scores,2);

answer=json_file['ix_to_ans'][tostring(pred[{i,1}])]
print(answer)

这是我的尝试,我使用一个简单的随机 scores 张量来演示它的行为:

> scores=torch.floor(torch.rand(4,10)*100)
> =scores
 9   1  90  12  62   1  62  86  46  27
 7   4   7   4  71  99  33  48  98  63
 82   5  73  84  61  92  81  99  65   9
 33  93  64  77  36  68  89  44  19  25
[torch.DoubleTensor of size 4x10]

现在,由于您想要每个问题(行)的 N 个最佳索引,让我们对张量的每一行进行排序:

> values,indexes=scores:sort(2)

现在,让我们看看 return 张量包含的内容:

> =values
  1   1   9  12  27  46  62  62  86  90
  4   4   7   7  33  48  63  71  98  99
  5   9  61  65  73  81  82  84  92  99
  19  25  33  36  44  64  68  77  89  93
  [torch.DoubleTensor of size 4x10]

> =indexes
  2   6   1   4  10   9   5   7   8   3
  2   4   1   3   7   8  10   5   9   6
  2  10   5   9   3   7   1   4   6   8
  9  10   1   5   8   3   6   4   7   2
  [torch.LongTensor of size 4x10]

如您所见,valuesi-th 行是 scoresi-th 行的排序版本(按升序),并且每一行在indexes给你相应的索引。

每个问题(行)您可以获得N最佳values/indexes

> N_best_indexes=indexes[{{},{indexes:size(2)-N+1,indexes:size(2)}}]
> N_best_values=values[{{},{values:size(2)-N+1,values:size(2)}}]

让我们看看给定示例的值,N=3:

> return N_best_indexes
 7  8  3
 5  9  6
 4  6  8
 4  7  2
[torch.LongTensor of size 4x3]

> return N_best_values
 62  86  90
 71  98  99
 84  92  99
 77  89  93
[torch.DoubleTensor of size 4x3]

所以,问题 jk-th 最佳值是 N_best_values[{{j},{values:size(2)-k+1}]],它在 scores 矩阵中的相应索引由此 row, column 给出值:

row=j
column=N_best_indexes[{{j},indexes:size(2)-k+1}}]. 

例如,第二个问题的第一个最佳值(k=1)是99,它位于2nd行和6th列。 16=]。你可以看到 values[{{2},values:size(2)}}]99indexes[{{2},{indexes:size(2)}}] 给你 6,这是 scores 矩阵中的列索引。

希望我能很好地解释我的解决方案。