PyTorch:如何在随机裁剪图像时对张量进行归一化?

PyTorch: How to normalize a tensor when the image is cropped randomly?

假设我们正在使用 CIFAR-10 dataset 并且我们想要应用一些数据扩充并另外对张量进行归一化。这是一些可重现的代码

from torchvision import transforms, datasets
import matplotlib.pyplot as plt
trafo = transforms.Compose([transforms.Pad(padding = 4, fill = 0, padding_mode = "constant"), 
                            transforms.RandomHorizontalFlip(p=0.5),
                            transforms.RandomCrop(size = (32, 32)), 
                            transforms.ToTensor(), 
                            transforms.Normalize(mean = (0.0, 0.0, 0.0), std = (1.0, 1.0, 1.0))]
                          )

cifar10_full = datasets.CIFAR10(root = "CIFAR-10", train = True, transform = trafo, target_transform = None, download = True)

到目前为止我选择的归一化对张量没有任何作用,因为我将 meanstd 分别设置为 01。根据 torchvision.transforms.Normalize 的文档,提供的均值和标准差适用于输入的每个通道。但是,问题是由于某些随机翻转和裁剪均值,我无法计算每个通道的均值。因此,我的想法大致是这样的

trafo_1 = transforms.Compose([transforms.Pad(padding = 4, fill = 0, padding_mode = "constant"), 
                            transforms.RandomHorizontalFlip(p=0.5),
                            transforms.RandomCrop(size = (32, 32)), 
                            transforms.ToTensor() 
                          )

cifar10_full = datasets.CIFAR10(root = "CIFAR-10", train = True, transform = trafo_1, target_transform = None, download = True)

现在我可以计算输入的每个通道的平均值,然后我想再次对张量进行归一化。 但是,我不能简单地使用 transforms.Normalize(),因为 cifar10_full 不再是原始数据集,但是我该如何继续呢?(一种解决方案是简单地修复随机生成器的种子,即使用 torch.manual_seed(0),但我现在想避免这种情况...)

均值和标准差不是针对每个张量,而是来自整个数据集。您尝试做什么并不重要,您只需要一个足以表示整个数据的比例,没有确切的均值或标准差,这些都是随机操作,只需使用均值和标准差来自实际数据,这几乎是标准。

首先,尝试计算数据集的均值和标准差(尝试随机抽样),并将其用于归一化。

# Calculate the mean, std of the complete dataset
import glob
import cv2
import numpy as np 
import tqdm
import random

# calculating 3 channel mean and std for image dataset

means = np.array([0, 0, 0], dtype=np.float32)
stds = np.array([0, 0, 0], dtype=np.float32)
total_images = 0
randomly_sample = 5000
for f in tqdm.tqdm(random.sample(glob.glob("dataset_path/**.jpg", recursive = True), randomly_sample)):
    img = cv2.imread(f)
    means += img.mean(axis=(0,1))
    stds += img.std(axis=(0,1))
    total_images += 1
means = means / (total_images * 255.)
stds = stds / (total_images * 255.)
print("Total images: ", total_images)
print("Means: ", means)
print("Stds: ", stds)

只是一个简单的场景,您是否认为在实际测试或推理中您的图像也会以这种方式增强,可能不会,您将拥有与数据的干净版本的均值和标准差密切匹配的干净图像, 所以计算 mean 和 std 是没有用的(你可以取几个随机样本),除非你想应用 TTA。

如果你也想应用 TTA,那么你可以继续 运行 对图像进行一些增强,进行随机采样并取这些图像的平均值和标准差。