按升序对张量列表进行排序

Sorting a tensor list in ascending order

我正在开发一个面部比较应用程序,它会给我最接近目标面部的 n 个面部。

我已经用 dlib/face_recognition 完成了这个,因为它使用了 numpy 数组,但是我现在正在尝试对 facenet/pytorch 和 运行 做同样的事情,因为它使用张量.

我已经创建了一个嵌入数据库,我正在给函数一张图片来与它们进行比较。我想要的是它从最低距离到最高距离对列表进行排序,并给我最低的 5 个左右的结果。

这是我正在处理的用于比较的代码。此时我正在给它提供一张照片并要求它与嵌入数据库进行比较。

def face_match(img_path, data_path): # img_path= location of photo, data_path= location of data.pt 
    # getting embedding matrix of the given img
    img_path = (os.getcwd()+'/1.jpg')
    img = Image.open(img_path)
    face = mtcnn(img) # returns cropped face and probability
    emb = resnet(face.unsqueeze(0)).detach() # detech is to make required gradient false

    saved_data = torch.load('data.pt') # loading data.pt file
    embedding_list = saved_data[0] # getting embedding data
    name_list = saved_data[1] # getting list of names
    dist_list = [] # list of matched distances, minimum distance is used to identify the person
    
    for idx, emb_db in enumerate(embedding_list):
        dist = torch.dist(emb, emb_db)
        dist_list.append(dist)
    
    namestodistance = list(zip(name_list,dist_list))
    
    print(namestodistance)

face_match('1.jpg', 'data.pt')

这导致我按照名称的字母顺序以 (Adam Smith, tensor(1.2123432))、Brian Smith、tensor(0.6545464) 等的形式给出所有名称以及它们与目标照片的距离。如果 'tensor' 不是每个条目的一部分我认为对它进行排序是没有问题的。我不太明白为什么将其附加到条目中。我可以通过在 dist_list 末尾添加 [0:5] 来将其减少到最好的 5 个,但我不知道如何对列表进行排序,我认为问题在于每个词中的张量条目。

我试过了 for idx, emb_db in enumerate(embedding_list): dist = torch.dist(emb, emb_db) sorteddist = torch.sort(dist) 但无论出于何种原因,这只是 returns 一个距离值,而且它不是最小的。

idx_min = dist_list.index(min(dist_list)),这很好地给出了最低值,然后使用 namelist[idx_min] 将其与名称匹配,因此给出了最佳匹配,但我想要最好的 5 个匹配顺序而不仅仅是最佳匹配。

有人能解决这个问题吗?

不幸的是,我无法测试您的代码,但对我来说,您似乎正在对 python 元组列表进行操作。您可以使用键对其进行排序:

namestodistance = [('Alice', .1), ('Bob', .3), ('Carrie', .2)]
names_top = sorted(namestodistance, key=lambda x: x[1])
print(names_top[:2])

当然你必须将key中的匿名函数修改为return一个可排序的值而不是例如torch.tensor.

这可以通过使用 key = lambda x: x[1].item() 来完成。

编辑:为了回答评论中出现的问题,我们可以稍微重构一下我们的代码。即

namestodistance = list(map(lambda x: (x[0], x[1].item()), namestodistance)
names_top = sorted(namestodistance, key=lambda x: x[1])
print(names_top[:2])