Pytorch:'float' 和 'function' 实例之间不支持“<=”

Pytorch: '<=' not supported between instances of 'float' and 'function'

我正在尝试计算并集交集 (IOU) 分数。这是我的代码实现,效果很好。

def IoU(predict: torch.Tensor, target: torch.Tensor):

    i = (predict & target).float().sum()
    u = (predict | target).float().sum()
    x = i/u
    IOU = x.item()

    return IoU

但是当我 运行 我的单元测试时:

def test_IoU1():
    pred = torch.tensor([[1, 0], [1, 0]])
    target = torch.tensor([[1, 0], [1, 1]])
    
    iou = IoU(pred,target)
    
    assert 0.66 <= iou
    assert iou <= 2/3

我得到:

 TypeError: '<=' not supported between instances of 'float' and 'function'

如何在不更改单元测试的情况下解决此问题?谢谢

在这个函数中

def IoU(predict: torch.Tensor, target: torch.Tensor):

    i = (predict & target).float().sum()
    u = (predict | target).float().sum()
    x = i/u
    IOU = x.item()
    
    return IoU

您正在 returning IoU 函数名称,我想您需要 return IOU。所以正确的方法是 -

def IoU(predict: torch.Tensor, target: torch.Tensor):

    i = (predict & target).float().sum()
    u = (predict | target).float().sum()
    x = i/u
    IOU = x.item()
    
    return IOU