Python 中张量的全变正则化
Total Variation Regularization for Tensors in Python
Formula
您好,我正在尝试为张量或更准确的多通道图像实现总变差函数。我发现对于上面的Total Variation(图中),有这样的源代码:
def compute_total_variation_loss(img, weight):
tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()
return weight * (tv_h + tv_w)
因为,我是 python 的初学者,我不明白索引是如何在图像中引用 i 和 j 的。我还想为 c 添加总变化(除了 i 和 j),但我不知道哪个索引指的是 c。
或者为了更简洁,如何在python中写出以下等式:
enter image description here
此函数假定批处理图像。所以 img
是一个维度为 (B, C, H, W)
的 4 维张量(B
是批次中的图像数量,C
颜色通道的数量,H
高度和 W
宽度)。
因此,img[0, 1, 2, 3]
是第一个图像中第二种颜色(RGB中的绿色)的像素(2, 3)
。
在Python(以及 Numpy 和 PyTorch)中,可以使用符号 i:j
选择 切片 元素,这意味着元素 i, i + 1, i + 2, ..., j - 1
被选中。在您的示例中,:
表示 所有元素 ,1:
表示 除第一个 和 :-1
之外的所有元素表示 除最后一个 之外的所有元素(负索引向后检索元素)。请参考“NumPy 中的切片”教程。
所以 img[:,:,1:,:] - img[:,:,:-1,:]
相当于(一批)图像减去它们本身垂直移动一个像素,或者,在你的符号中 X(i + 1, j, k) - X(i, j, k)
。然后将张量平方 (.pow(2)
) 并求和 (.sum()
)。请注意,在这种情况下,总和也在批次上,因此您收到的是批次的总变化,而不是每个图像的总变化。
Formula
您好,我正在尝试为张量或更准确的多通道图像实现总变差函数。我发现对于上面的Total Variation(图中),有这样的源代码:
def compute_total_variation_loss(img, weight):
tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()
return weight * (tv_h + tv_w)
因为,我是 python 的初学者,我不明白索引是如何在图像中引用 i 和 j 的。我还想为 c 添加总变化(除了 i 和 j),但我不知道哪个索引指的是 c。
或者为了更简洁,如何在python中写出以下等式: enter image description here
此函数假定批处理图像。所以 img
是一个维度为 (B, C, H, W)
的 4 维张量(B
是批次中的图像数量,C
颜色通道的数量,H
高度和 W
宽度)。
因此,img[0, 1, 2, 3]
是第一个图像中第二种颜色(RGB中的绿色)的像素(2, 3)
。
在Python(以及 Numpy 和 PyTorch)中,可以使用符号 i:j
选择 切片 元素,这意味着元素 i, i + 1, i + 2, ..., j - 1
被选中。在您的示例中,:
表示 所有元素 ,1:
表示 除第一个 和 :-1
之外的所有元素表示 除最后一个 之外的所有元素(负索引向后检索元素)。请参考“NumPy 中的切片”教程。
所以 img[:,:,1:,:] - img[:,:,:-1,:]
相当于(一批)图像减去它们本身垂直移动一个像素,或者,在你的符号中 X(i + 1, j, k) - X(i, j, k)
。然后将张量平方 (.pow(2)
) 并求和 (.sum()
)。请注意,在这种情况下,总和也在批次上,因此您收到的是批次的总变化,而不是每个图像的总变化。