如何使用矢量化快速解码单热编码的 NumPy 矩阵?

How do I decode a one-hot encoded NumPy matrix in a fast manner using vectorization?

给定一个形状为 (height, width) 的图像矩阵,其值在 uint8 范围内,它被单热编码(转换为分类)为 (height, width, n) 的形状,其中 n 是可能类别的数量,在本例中为 3,导致 (height, width, 3) 的形状,我想撤消分类转换并获得 (height, width) 的原始形状。以下解决方案有效,但可以更快:

def decode(image):
    image = image

    height = image.shape[0]
    width = image.shape[1]

    decoded_image = numpy.ndarray(shape=(height, width), dtype=numpy.uint8)

    for i in range(0, height):
        for j in range(0, width):
            decoded_image[i][j] = numpy.argmax(image[i][j])

    return decoded_image

我想要一个解决方案,使用 NumPy vectorization,不需要较慢的 Python for loop

感谢您的任何建议。

您似乎想对数组的最后一个维度进行缩减,尤其是 numpy.argmax。幸运的是,这个 numpy 函数接受一个 axis 关键字,所以你应该能够在一次调用中完成同样的事情:

decoded_image = numpy.argmax(image, axis=2)