How to reverse one-hot encoding?

I have some labels that look like this: 'ABC1234'. I one-hot encoded them with the following code:

def my_onehot_encoded(label):
    # define universe of possible input values
    characters = '0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'
    # define a mapping of chars to integers
    char_to_int = dict((c, i) for i, c in enumerate(characters))
    int_to_char = dict((i, c) for i, c in enumerate(characters))
    # integer encode input data
    integer_encoded = [char_to_int[char] for char in label]
    # one hot encode
    onehot_encoded = list()
    for value in integer_encoded:
        character = [0 for _ in range(len(characters))]
        character[value] = 1
        onehot_encoded.append(character)

    return onehot_encoded

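This gives me one-hot encoded labels of shape (7, 35): one row per character, one column per symbol in characters (10 digits plus 25 letters, since 'O' is not included).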
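
Reversing the encoding itself is just the mirror image: in each row, the position of the 1 is the character's index. A minimal sketch, assuming NumPy; onehot_encode and onehot_decode are hypothetical names, and onehot_encode is only a compact stand-in that produces the same matrix as my_onehot_encoded:

```python
import numpy as np

# Same alphabet as in the question (no letter 'O', hence 35 symbols)
characters = '0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'
char_to_int = {c: i for i, c in enumerate(characters)}

def onehot_encode(label):
    """Hypothetical stand-in for my_onehot_encoded: pick identity-matrix rows."""
    return np.eye(len(characters), dtype=int)[[char_to_int[c] for c in label]]

def onehot_decode(onehot):
    """Hypothetical inverse: the column holding the 1 in each row is the char index."""
    indices = np.argmax(np.asarray(onehot), axis=1)
    return ''.join(characters[i] for i in indices)

encoded = onehot_encode('ABC1234')
print(encoded.shape)           # (7, 35)
print(onehot_decode(encoded))  # ABC1234
```

Because onehot_decode calls np.asarray first, it also accepts the plain list of lists returned by my_onehot_encoded.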

I then built a model to predict the labels. I use this code to predict the label of an image:

from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math

img = imread('/content/gdrive/My Drive/2017-IWT4S-CarsReId_LP-dataset/2_4.png')
img = resize(img,(224,224))
img = img*1./255
img = np.reshape(img,[1,224,224,3])

classes = model.predict(img)

np.argmax(classes, axis=2)

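This gives me an array of predicted class indices. For the label 'ABC1234' the result is: array([[ 10, 11, 12, 1, 2, 3, 4]]). I now want a function that decodes this array back into the original string label 'ABC1234'. How can I do this?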

Use characters as a lookup table and index into it with each predicted value:

characters = '0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'
output = [[ 10, 11, 12, 1, 2, 3, 4]]
res = []
for i in output:
    res_str = ''
    for j in i:
        res_str = res_str + str(characters[j])
    res.append(res_str)
res

'''
Output:
['ABC1234']
'''
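
The same lookup can be written more compactly with str.join, which builds each label in one step instead of concatenating one character at a time. A small sketch; indexing a str already returns a str, so the str() call is not needed:

```python
characters = '0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'
output = [[10, 11, 12, 1, 2, 3, 4]]

# One string per output row; each index j is looked up in characters
res = [''.join(characters[j] for j in row) for row in output]
print(res)  # ['ABC1234']
```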

Using nested loops like this and appending one character at a time seems inefficient.

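A simpler solution is to turn characters into a NumPy array and use each whole output row as an index array (NumPy fancy indexing), which looks up all characters of a row at once: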

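import numpy as np

characters = '0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'
characters = np.array(list(characters))  # array of single-character strings
outputs = np.array([[10, 11, 12, 1, 2, 3, 4]])
labels = [''.join(characters[row]) for row in outputs]  # characters[row] picks all chars of a row at once
# ['ABC1234']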
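
Fancy indexing also works on the whole 2D index array, so the argmax step and the lookup can be folded into one helper that takes the raw model output. A sketch, assuming the model returns scores of shape (batch, 7, 35) as implied by np.argmax(classes, axis=2); decode_predictions is a hypothetical name and the scores below are made up for illustration:

```python
import numpy as np

characters = np.array(list('0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'))

def decode_predictions(probs):
    """Hypothetical helper: (batch, seq_len, n_classes) scores -> list of strings."""
    indices = np.argmax(probs, axis=-1)   # (batch, seq_len) class indices
    chars = characters[indices]           # (batch, seq_len) array of 1-char strings
    return [''.join(row) for row in chars]

# Made-up scores for one image: the highest score sits on the classes of 'ABC1234'
fake = np.full((1, 7, len(characters)), 0.01)
fake[0, np.arange(7), [10, 11, 12, 1, 2, 3, 4]] = 0.9
print(decode_predictions(fake))  # ['ABC1234']
```

With this helper, decode_predictions(model.predict(img)) would return the label strings directly, one per image in the batch.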