Python 中张量的全变正则化

Total Variation Regularization for Tensors in Python

Formula

您好,我正在尝试为张量或更准确的多通道图像实现总变差函数。我发现对于上面的Total Variation(图中),有这样的源代码:

def compute_total_variation_loss(img, weight):      
    tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
    tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()    
    return weight * (tv_h + tv_w)

因为,我是 python 的初学者,我不明白索引是如何在图像中引用 i 和 j 的。我还想为 c 添加总变化(除了 i 和 j),但我不知道哪个索引指的是 c。

或者为了更简洁,如何在python中写出以下等式: enter image description here

此函数假定批处理图像。所以 img 是一个维度为 (B, C, H, W) 的 4 维张量(B 是批次中的图像数量,C 颜色通道的数量,H高度和 W 宽度)。

因此,img[0, 1, 2, 3]是第一个图像中第二种颜色(RGB中的绿色)的像素(2, 3)

在Python(以及 Numpy 和 PyTorch)中,可以使用符号 i:j 选择 切片 元素,这意味着元素 i, i + 1, i + 2, ..., j - 1 被选中。在您的示例中,: 表示 所有元素 1: 表示 除第一个 :-1 之外的所有元素表示 除最后一个 之外的所有元素(负索引向后检索元素)。请参考“NumPy 中的切片”教程。

所以 img[:,:,1:,:] - img[:,:,:-1,:] 相当于(一批)图像减去它们本身垂直移动一个像素,或者,在你的符号中 X(i + 1, j, k) - X(i, j, k)。然后将张量平方 (.pow(2)) 并求和 (.sum())。请注意,在这种情况下,总和也在批次上,因此您收到的是批次的总变化,而不是每个图像的总变化。