是否有一个函数可以在满足要求的pytorch张量中选择最大值
Is there a function for picking the max value in a pytorch tensor which follows a requirenment
我正在使用强化学习和 pytorch 在 python 中编写一个 othello 机器人。在程序中,我扫描棋盘以寻找合法的动作。 AI 应该选择正确概率最高的着法,并且根据之前的计算是合法的。在这里我需要一个像这样工作的函数:
a = torch.tensor([1,2,3,4,5])
b = torch.tensor([True, True, False, True, False], dtype=bool)
print(torch.somefunction(a,b))
输出应该是a中最大值的id,在本例中是3.这个函数存在吗?如果没有,还有其他方法吗?
假设你的张量中至少有一个非负值,你将它乘以掩码本身以删除排序中排除的值:
>>> torch.argmax(a*b)
tensor(3)
如果不是这种情况,您仍然可以通过使用 torch.where
将排除的值替换为一些会被 argmax 忽略的值(例如 a.min()
):
>>> torch.where(b, a, a.min()).argmax()
tensor(3)
我正在使用强化学习和 pytorch 在 python 中编写一个 othello 机器人。在程序中,我扫描棋盘以寻找合法的动作。 AI 应该选择正确概率最高的着法,并且根据之前的计算是合法的。在这里我需要一个像这样工作的函数:
a = torch.tensor([1,2,3,4,5])
b = torch.tensor([True, True, False, True, False], dtype=bool)
print(torch.somefunction(a,b))
输出应该是a中最大值的id,在本例中是3.这个函数存在吗?如果没有,还有其他方法吗?
假设你的张量中至少有一个非负值,你将它乘以掩码本身以删除排序中排除的值:
>>> torch.argmax(a*b)
tensor(3)
如果不是这种情况,您仍然可以通过使用 torch.where
将排除的值替换为一些会被 argmax 忽略的值(例如 a.min()
):
>>> torch.where(b, a, a.min()).argmax()
tensor(3)