使用字典替换 NumPy 数组中的值会产生不明确的结果,这是为什么?
Substituting values in NumPy array using a dictionary is giving ambiguous results, why is that?
所以,我有一个包含一些词的数组,我正在尝试执行单热编码。
假设输入是 AI DSA DSA AI ML ML AI DS DS AI C AI ML ML C
这是我的代码:
def apply_one_hot_encoding(X):
dic = {}
k = sorted(list(set(X)))
for i in range(len(k)):
arr = ['0' for i in range(len(k))]
arr[i] = '1'
dic[k[i]] = ''.join(arr)
for i in range(len(X)):
t = dic[X[i]]
X[i] = t
return X
if __name__ == "__main__":
X = np.array(list(input().split()))
one_hot_encoded_array = apply_one_hot_encoding(X)
for i in one_hot_encoded_array:
print(*i)
现在,我希望输出如下:
1 0 0 0 0
0 0 0 1 0
0 0 1 0 0
但我得到的是:
1 0 0
0 0 1
1 0 0
如果我将 t 值附加到另一个列表并且 return 那个,它给出了正确的结果。
为什么在直接替换的情况下分配的值被修剪为仅 3 个字符?
问题是由于Numpy数组的dtype
(数据类型)引起的。
当您使用 print(X.dtype)
检查上述程序中 numy 数组的数据类型时,它显示数据类型为 <U3
,numpy 数组中的每个元素只能包含三个字符 X
.
由于输入数组包含五个类别,数组的dtype
可以通过X = np.array(list(input().split()), dtype='<U5')
变为<U5
,numpy数组中的每个元素最多可以容纳五个字符X
.
更正后的代码是,
def apply_one_hot_encoding(X):
dic = {}
k = sorted(list(set(X)))
for i in range(len(k)):
arr = ['0' for i in range(len(k))]
arr[i] = '1'
dic[k[i]] = ''.join(arr)
for i in range(len(X)):
t = dic[X[i]]
X[i] = t
return X
if __name__ == "__main__":
X = np.array(list(input().split()),dtype = '<U5')
one_hot_encoded_array = apply_one_hot_encoding(X)
for i in one_hot_encoded_array:
print(*i)
当您将值存储在单独的 numpy 数组中时,不需要上述方法,因为 numpy 会根据字符串的大小自动更改数据类型,
所以,我有一个包含一些词的数组,我正在尝试执行单热编码。
假设输入是 AI DSA DSA AI ML ML AI DS DS AI C AI ML ML C
这是我的代码:
def apply_one_hot_encoding(X):
dic = {}
k = sorted(list(set(X)))
for i in range(len(k)):
arr = ['0' for i in range(len(k))]
arr[i] = '1'
dic[k[i]] = ''.join(arr)
for i in range(len(X)):
t = dic[X[i]]
X[i] = t
return X
if __name__ == "__main__":
X = np.array(list(input().split()))
one_hot_encoded_array = apply_one_hot_encoding(X)
for i in one_hot_encoded_array:
print(*i)
现在,我希望输出如下:
1 0 0 0 0
0 0 0 1 0
0 0 1 0 0
但我得到的是:
1 0 0
0 0 1
1 0 0
如果我将 t 值附加到另一个列表并且 return 那个,它给出了正确的结果。
为什么在直接替换的情况下分配的值被修剪为仅 3 个字符?
问题是由于Numpy数组的dtype
(数据类型)引起的。
当您使用 print(X.dtype)
检查上述程序中 numy 数组的数据类型时,它显示数据类型为 <U3
,numpy 数组中的每个元素只能包含三个字符 X
.
由于输入数组包含五个类别,数组的dtype
可以通过X = np.array(list(input().split()), dtype='<U5')
变为<U5
,numpy数组中的每个元素最多可以容纳五个字符X
.
更正后的代码是,
def apply_one_hot_encoding(X):
dic = {}
k = sorted(list(set(X)))
for i in range(len(k)):
arr = ['0' for i in range(len(k))]
arr[i] = '1'
dic[k[i]] = ''.join(arr)
for i in range(len(X)):
t = dic[X[i]]
X[i] = t
return X
if __name__ == "__main__":
X = np.array(list(input().split()),dtype = '<U5')
one_hot_encoded_array = apply_one_hot_encoding(X)
for i in one_hot_encoded_array:
print(*i)
当您将值存储在单独的 numpy 数组中时,不需要上述方法,因为 numpy 会根据字符串的大小自动更改数据类型,