是否有一个函数可以在满足要求的pytorch张量中选择最大值

Is there a function for picking the max value in a pytorch tensor which follows a requirenment

我正在使用强化学习和 pytorch 在 python 中编写一个 othello 机器人。在程序中,我扫描棋盘以寻找合法的动作。 AI 应该选择正确概率最高的着法,并且根据之前的计算是合法的。在这里我需要一个像这样工作的函数:

a = torch.tensor([1,2,3,4,5])
b = torch.tensor([True, True, False, True, False], dtype=bool)
print(torch.somefunction(a,b))

输出应该是a中最大值的id,在本例中是3.这个函数存在吗?如果没有,还有其他方法吗?

假设你的张量中至少有一个非负值,你将它乘以掩码本身以删除排序中排除的值:

>>> torch.argmax(a*b)
tensor(3)

如果不是这种情况,您仍然可以通过使用 torch.where 将排除的值替换为一些会被 argmax 忽略的值(例如 a.min()):

>>> torch.where(b, a, a.min()).argmax()
tensor(3)