向量化函数 - 标量变量的索引无效
Vectorize a function - Invalid index to scalar variable
我是 python 的新手。
我发现了一篇关于矢量化的有趣文章,所以我开始研究它。
虽然我能够做到这一点:
def cost(a, b):
"Return a-b if a>b, otherwise return a+b"
if a > b:
return a - b
else:
return a + b
cost_vector= np.vectorize(cost)
print(z([1,2,3],[3,4,5]))
output: [4 6 8]
我不能这样做:
ww = [[1,2,3,4,5,6],[2,2,3,4,5,6],[3,2,3,4,5,6],[4,2,3,4,5,6],[5,2,3,4,5,6],[6,2,3,4,5,6]]
def cost(ww, a, b):
if a > b:
return ww[a][b]
else:
return ww[b][a]
z = np.vectorize(cost)
print(z(ww, [1,2,3], [3,4,5]))
output: IndexError: invalid index to scalar variable.
我不知道如何映射到我的数组
谢谢
您的代码的问题是 np.vectorize()
试图分解所有参数,包括 ww
。
根据 documentation 您需要通过 exclude
参数排除它,例如:
import numpy as np
ww = [[1, 2, 3, 4, 5, 6], [2, 2, 3, 4, 5, 6], [3, 2, 3, 4, 5, 6],
[4, 2, 3, 4, 5, 6], [5, 2, 3, 4, 5, 6], [6, 2, 3, 4, 5, 6]]
def cost(ww, a, b):
if a > b:
return ww[a][b]
else:
return ww[b][a]
v_cost = np.vectorize(cost, excluded={0})
print(v_cost(ww, [1, 2, 3], [3, 4, 5]))
# [2 3 4]
请注意,您可以在 NumPy 中执行此操作,而无需 np.vectorize()
装饰函数。
您只需要确保 ww
是一个 NumPy 数组并使用 np.where()
两次:
import numpy as np
ww = [[1, 2, 3, 4, 5, 6], [2, 2, 3, 4, 5, 6], [3, 2, 3, 4, 5, 6],
[4, 2, 3, 4, 5, 6], [5, 2, 3, 4, 5, 6], [6, 2, 3, 4, 5, 6]]
def cost(ww, a, b):
return np.array(ww)[np.where(a > b, a, b), np.where(a > b, b, a)]
print(cost(ww, [1, 2, 3], [3, 4, 5]))
# [2 3 4]
我是 python 的新手。 我发现了一篇关于矢量化的有趣文章,所以我开始研究它。 虽然我能够做到这一点:
def cost(a, b):
"Return a-b if a>b, otherwise return a+b"
if a > b:
return a - b
else:
return a + b
cost_vector= np.vectorize(cost)
print(z([1,2,3],[3,4,5]))
output:
[4 6 8]
我不能这样做:
ww = [[1,2,3,4,5,6],[2,2,3,4,5,6],[3,2,3,4,5,6],[4,2,3,4,5,6],[5,2,3,4,5,6],[6,2,3,4,5,6]]
def cost(ww, a, b):
if a > b:
return ww[a][b]
else:
return ww[b][a]
z = np.vectorize(cost)
print(z(ww, [1,2,3], [3,4,5]))
output:
IndexError: invalid index to scalar variable.
我不知道如何映射到我的数组
谢谢
您的代码的问题是 np.vectorize()
试图分解所有参数,包括 ww
。
根据 documentation 您需要通过 exclude
参数排除它,例如:
import numpy as np
ww = [[1, 2, 3, 4, 5, 6], [2, 2, 3, 4, 5, 6], [3, 2, 3, 4, 5, 6],
[4, 2, 3, 4, 5, 6], [5, 2, 3, 4, 5, 6], [6, 2, 3, 4, 5, 6]]
def cost(ww, a, b):
if a > b:
return ww[a][b]
else:
return ww[b][a]
v_cost = np.vectorize(cost, excluded={0})
print(v_cost(ww, [1, 2, 3], [3, 4, 5]))
# [2 3 4]
请注意,您可以在 NumPy 中执行此操作,而无需 np.vectorize()
装饰函数。
您只需要确保 ww
是一个 NumPy 数组并使用 np.where()
两次:
import numpy as np
ww = [[1, 2, 3, 4, 5, 6], [2, 2, 3, 4, 5, 6], [3, 2, 3, 4, 5, 6],
[4, 2, 3, 4, 5, 6], [5, 2, 3, 4, 5, 6], [6, 2, 3, 4, 5, 6]]
def cost(ww, a, b):
return np.array(ww)[np.where(a > b, a, b), np.where(a > b, b, a)]
print(cost(ww, [1, 2, 3], [3, 4, 5]))
# [2 3 4]