如何使用 torch.FloatTensor 的 if 语句 PyTorch

How to use if statement PyTorch using torch.FloatTensor

我正在尝试在我的 PyTorch 代码中使用 if 语句,使用 torch.FloatTensor 作为数据类型,以将其加速到 GPU.

这是我的代码:

import torch
import time
def fitness(x):
     return torch.pow(x, 2)

def velocity(v, gxbest, pxbest, pybest, x, pop):
     return torch.rand(pop).type(dtype)*v + \
            torch.rand(pop).type(dtype)*(pxbest - x) + \
            torch.rand(pop).type(dtype)*(gxbest.expand(x.size(0)) - x)

dtype = torch.cuda.FloatTensor 
def main():

    pop, xmax, xmin, niter = 300000, 50, -50, 100
    v                      = torch.rand(pop).type(dtype)
    x                      = (xmax-xmin)*torch.rand(pop).type(dtype)+xmin
    y                      = fitness(x)
    [miny, indexminy]      = y.min(0)
    gxbest                 = x[indexminy] 
    pxbest                 = x
    pybest                 = y

    for K in range(niter):

        vnext = velocity(v, gxbest, pxbest, pybest, x, pop)

        xnext = x + vnext
        ynext = fitness(x)
        [minynext, indexminynext]  = ynext.min(0)

        if (minynext < miny):
            miny   = minynext
            gxbest = xnext[indexminynext]

        indexpbest         = (ynext < pybest)
        pxbest[indexpbest] = xnext[indexpbest]
        pybest[indexpbest] = ynext[indexpbest]
        x                  = xnext
        v                  = vnext
main()

不幸的是,它不起作用。它给我一条错误消息,我无法弄清楚是什么问题。

RuntimeError: bool value of non-empty torch.cuda.ByteTensor objects is ambiguous

如何在 PyTorch 中使用 if?我试图将 cuda.Tensor 转换为一个 numpy 数组,但它也不起作用。

  minynext = minynext.cpu().numpy()
  miny = miny.cpu().numpy()

PS:我是否可以按照 efficient/faster 的方式编写代码?或者我应该改变一些东西以获得更快的结果?

当你比较 pyTorch 张量时,输出通常是 ByteTensor。此数据类型不适用于 if 语句。

改变if里面的条件:

if (minynext[0] < miny[0])

如果您查看以下简单示例:

import torch

a = torch.LongTensor([1])
b = torch.LongTensor([5])

print(a > b)

输出:

 0
[torch.ByteTensor of size 1]

比较张量 ab 结果是 torch.ByteTensor,这显然不等同于 boolean。因此,您可以执行以下操作。

print(a[0] > b[0]) # False

因此,您应该按如下方式更改您的 if 条件。

if (minynext[0] < miny[0])