向量化函数 - 标量变量的索引无效

Vectorize a function - Invalid index to scalar variable

我是 python 的新手。 我发现了一篇关于矢量化的有趣文章,所以我开始研究它。 虽然我能够做到这一点:

 def cost(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b

cost_vector= np.vectorize(cost)
print(z([1,2,3],[3,4,5]))

output: [4 6 8]

我不能这样做:

ww = [[1,2,3,4,5,6],[2,2,3,4,5,6],[3,2,3,4,5,6],[4,2,3,4,5,6],[5,2,3,4,5,6],[6,2,3,4,5,6]]

def cost(ww, a, b):
    if a > b:
        return ww[a][b]
    else:
        return ww[b][a]

z = np.vectorize(cost)
print(z(ww, [1,2,3], [3,4,5]))

output: IndexError: invalid index to scalar variable.

我不知道如何映射到我的数组

谢谢

您的代码的问题是 np.vectorize() 试图分解所有参数,包括 ww。 根据 documentation 您需要通过 exclude 参数排除它,例如:

import numpy as np


ww = [[1, 2, 3, 4, 5, 6], [2, 2, 3, 4, 5, 6], [3, 2, 3, 4, 5, 6],
      [4, 2, 3, 4, 5, 6], [5, 2, 3, 4, 5, 6], [6, 2, 3, 4, 5, 6]]


def cost(ww, a, b):
    if a > b:
        return ww[a][b]
    else:
        return ww[b][a]


v_cost = np.vectorize(cost, excluded={0})
print(v_cost(ww, [1, 2, 3], [3, 4, 5]))
# [2 3 4]

请注意,您可以在 NumPy 中执行此操作,而无需 np.vectorize() 装饰函数。 您只需要确保 ww 是一个 NumPy 数组并使用 np.where() 两次:

import numpy as np


ww = [[1, 2, 3, 4, 5, 6], [2, 2, 3, 4, 5, 6], [3, 2, 3, 4, 5, 6],
      [4, 2, 3, 4, 5, 6], [5, 2, 3, 4, 5, 6], [6, 2, 3, 4, 5, 6]]


def cost(ww, a, b):
    return np.array(ww)[np.where(a > b, a, b), np.where(a > b, b, a)]


print(cost(ww, [1, 2, 3], [3, 4, 5]))
# [2 3 4]