如何在列表中找到张量的索引?

How to find the index of a tensor in a list?

我想在列表 li 中找到最小张量的索引(通过某个关键函数)。 所以我做了 min 然后 li.index(min_el)。我的 MWE 建议张量以某种方式不适用于 index.

import torch
li=[torch.ones(1,1), torch.zeros(2,2)]
li.index(li[0])
0
li.index(li[1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../local/lib/python2.7/site-packages/torch/tensor.py", line 330, in __eq__
    return self.eq(other)
RuntimeError: inconsistent tensor size at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:2679

我当然可以制作自己的索引函数,首先检查大小,然后再检查元素。例如

def index(list, element):
    for i,el in enumerate(list):
        if el.size() == element.size():
            diff = el - element
            if (1- diff.byte()).all():
               return i
    return -1

我只是想知道,为什么 index 不起作用? 是否有一种聪明的方法可以做到这一点,而不是我所缺少的?

您可以使用 enumerate 和对每个元组的第二个元素进行操作的键函数直接找到索引。例如,如果您的关键函数比较每个张量的第一个元素,您可以使用

ix, _ = min(enumerate(li), key=lambda x: x[1][0, 0])

我认为 index 对第一个元素成功的原因是 Python 可能做了一些等同于 x is value or x == value 的事情,其中​​ x 是当前元素在列表中。自 value is value 以来,失败的相等比较从未发生过。这就是 li.index(li[0]) 有效但失败的原因:

y = torch.ones(1,1)
li.index(y)