How to reverse one-hot encoding?
I have some labels that look like this: 'ABC1234'. I one-hot encoded them with the following code:
from numpy import argmax
# define input string
def my_onehot_encoded(label):
    # define universe of possible input values
    characters = '0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'
    # define a mapping of chars to integers
    char_to_int = dict((c, i) for i, c in enumerate(characters))
    int_to_char = dict((i, c) for i, c in enumerate(characters))
    # integer encode input data
    integer_encoded = [char_to_int[char] for char in label]
    # one hot encode
    onehot_encoded = list()
    for value in integer_encoded:
        character = [0 for _ in range(len(characters))]
        character[value] = 1
        onehot_encoded.append(character)
    return onehot_encoded
This gives me one-hot encoded labels of shape (7, 35).
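Because each row of the encoding contains exactly one 1, the encoding can be reversed by taking the position of that 1 in each row (argmax) and looking it up in an index-to-character mapping like the int_to_char dict that my_onehot_encoded already builds but never uses. A minimal round-trip sketch, assuming the same 35-character alphabet; encode and decode are hypothetical helper names, not part of the original code:

```python
import numpy as np

# Same alphabet as in the question: digits plus A-Z without 'O' (35 symbols).
CHARACTERS = '0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'
CHAR_TO_INT = {c: i for i, c in enumerate(CHARACTERS)}
INT_TO_CHAR = {i: c for i, c in enumerate(CHARACTERS)}


def encode(label):
    """One-hot encode a label into a (len(label), 35) array of 0/1 values."""
    onehot = np.zeros((len(label), len(CHARACTERS)), dtype=int)
    for position, char in enumerate(label):
        onehot[position, CHAR_TO_INT[char]] = 1
    return onehot


def decode(onehot):
    """Reverse encode(): find the index of the 1 in each row and map it back."""
    indices = np.argmax(onehot, axis=1)  # one class index per character position
    return ''.join(INT_TO_CHAR[int(i)] for i in indices)


encoded = encode('ABC1234')
print(encoded.shape)    # (7, 35)
print(decode(encoded))  # ABC1234
```

Since np.argmax also accepts nested Python lists, decode should work directly on the list of lists returned by my_onehot_encoded.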
I then built a model to predict the labels, and I use this code to predict the label of a single image:
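from skimage.io import imread
from skimage.transform import resize
import numpy as np
img = imread('/content/gdrive/My Drive/2017-IWT4S-CarsReId_LP-dataset/2_4.png')
img = resize(img,(224,224))
# note: resize() already converts uint8 input to floats in [0, 1]; dividing by 255
# again is only correct if the training images were preprocessed the same way
img = img*1./255
img = np.reshape(img,[1,224,224,3])
# model is the trained model (not shown); classes has shape (1, 7, 35)
classes = model.predict(img)
# most likely character index at each of the 7 positions -> shape (1, 7)
np.argmax(classes, axis=2)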
This gives me a vector of predicted classes. For the label above it is: array([[10, 11, 12, 1, 2, 3, 4]])
Now I want a function that decodes this array back into the original string label 'ABC1234'. How can I do that?
Use characters and iterate over the predicted output, using each predicted value as an index into characters:
characters = '0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'
output = [[10, 11, 12, 1, 2, 3, 4]]
res = []
for i in output:
    res_str = ''
    for j in i:
        # each predicted index selects one character of the alphabet
        res_str = res_str + characters[j]
    res.append(res_str)
print(res)
'''
Output:
['ABC1234']
'''
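The same loop can be written more compactly with str.join, which builds each string in one call instead of concatenating it character by character. A sketch that still works on plain Python lists or on the NumPy array returned by np.argmax; decode_rows is a hypothetical name:

```python
characters = '0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'


def decode_rows(output, alphabet=characters):
    """Map every row of predicted indices to a string via the alphabet."""
    return [''.join(alphabet[j] for j in row) for row in output]


output = [[10, 11, 12, 1, 2, 3, 4]]
print(decode_rows(output))  # ['ABC1234']
```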
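Building strings with a nested loop and appending one character at a time is inefficient.
A simpler solution is to convert characters to a NumPy array and use each whole output row as an index: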
import numpy as np

characters = '0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'
characters = np.array(list(characters))
outputs = np.array([[10, 11, 12, 1, 2, 3, 4]])
# characters[row] picks all characters of one label at once (fancy indexing)
labels = [''.join(characters[row]) for row in outputs]
# ['ABC1234']
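NumPy fancy indexing also accepts the whole 2-D index array at once, so decoding can start directly from the model's score output, assumed here to have shape (batch, 7, 35) as in the question. A sketch under that assumption; probs_to_labels is a hypothetical helper, and the scores below are synthetic one-hot rows built with np.eye, not real model output:

```python
import numpy as np

CHARACTERS = np.array(list('0123456789ABCDEFGHIJKLMNPQRSTUVWXYZ'))


def probs_to_labels(probs):
    """Convert (batch, positions, 35) class scores into a list of strings."""
    indices = np.argmax(probs, axis=-1)  # (batch, positions) integer indices
    chars = CHARACTERS[indices]          # (batch, positions) array of characters
    return [''.join(row) for row in chars]


# Synthetic "predictions" for two labels: row i of np.eye(35) is one-hot for class i.
targets = np.array([[10, 11, 12, 1, 2, 3, 4],
                    [23, 24, 25, 9, 8, 7, 6]])
probs = np.eye(len(CHARACTERS))[targets]  # shape (2, 7, 35)
print(probs_to_labels(probs))  # ['ABC1234', 'NPQ9876']
```

This keeps all per-character work inside NumPy; only the final join per label runs in Python.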