numpy.apply_along_axis 截断字符串,因为它推断出错误的数据类型 '<U1'

numpy.apply_along_axis truncates strings because it infers wrong dtype '<U1'

我不知道如何return dtype U3 的字符串

我想:

  1. apply_along_axis 到 my_array

  2. 对于每一行,return一个字符串

def my_function(x):
    return x[2]
my_array = np.array([[1,1,"A"],[1,1,"BBB"], [1,1,"CCC"]])
np.apply_along_axis(my_function, axis=1, arr=my_array)

我希望输出: array(['A', 'BBB', 'CCC'], dtype='<U3') 但实际输出是 array(['A', 'B', 'C'], dtype='<U1')

因为第一个元素 ('A') 具有固定大小的 U1,每个下一个元素都被截断为 U1 ('BBB' -> 'B')。

您知道如何使用例如 dtype U3 将代码更改为字符串吗?

试试这个(虽然可能应该有更好的方法):

import numpy as np

def my_function(x):
    return np.array(x[2], dtype='<U3')

my_array = np.array([[1,1,"A"],[1,1,"BBB"], [1,1,"CCC"]])
np.apply_along_axis(my_function, axis=1, arr=my_array)

对于这个特定的用例,您可以使用切片,即

my_array[:, 2]

并完全避免 apply_along_axis。但我同意从函数的第一次应用中推断类型很麻烦。还有一个issue

顺便说一句:数组中的数字被转换为字符串,但这会导致 <U21 的类型不是最佳类型。如果你直接把它们变成字符串,你会得到 <U3.