Normalising images using mean and std of a dataset
I am using the following code snippet to compute the mean and std of the images in the Cityscapes dataset so that I can normalise them:
def compute_mean_std(dataloader):
    pop_mean = []
    pop_std = []
    for i, (img, mask, rgb_mask) in enumerate(dataloader):
        numpy_image = img.cpu().numpy()
        batch_mean = np.mean(numpy_image, axis=(0, 2, 3))
        pop_mean.append(batch_mean)
        # print(batch_mean.shape)
        batch_std = np.mean(numpy_image, axis=(0, 2, 3))
        pop_std.append(batch_std)
        # print(batch_std.shape)
    pop_mean = np.array(pop_mean).mean(axis=0)
    pop_std = np.array(pop_std).std(axis=0)
    print(pop_mean.shape)
    print(pop_std.shape)
    return (pop_mean, pop_std)
This code gives me the following mean and std:
MEAN = [0.28660315, 0.32426634, 0.28302112]
STD = [0.00310452, 0.00292714, 0.00296411]
But when I normalise the images with these values and then compute the mean and std of the normalised images, they are not close to 0 and 1.
Is this the right way to compute the mean and std over the whole dataset and normalise the images?
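Your formula is incorrect. You cannot take the mean of each batch, then take the standard deviation of those means, and expect the result to be the standard deviation of the whole dataset. Try something like: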
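import torch

total = 0.0
totalsq = 0.0
count = 0
for data, *_ in dataloader:
    data = data.double()  # accumulate in float64 to limit rounding error
    count += data.numel()  # number of values in this batch
    total += data.sum()
    totalsq += (data ** 2).sum()

mean = total / count
var = (totalsq / count) - (mean ** 2)  # Var[X] = E[X^2] - (E[X])^2
std = torch.sqrt(var)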
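Note that the snippet above produces a single scalar mean and std across all channels. If you want one value per channel, as in your MEAN and STD lists, the same running-sum idea can be applied along the channel axis. The following is only a minimal sketch under some assumptions: batches are shaped (N, C, H, W), `channel_mean_std` is a hypothetical helper name, and the random arrays stand in for a real dataloader.

```python
import numpy as np


def channel_mean_std(batches):
    """Per-channel mean and std over an iterable of (N, C, H, W) arrays."""
    total = None      # running per-channel sum of pixel values
    total_sq = None   # running per-channel sum of squared pixel values
    count = 0         # number of pixels seen per channel
    for batch in batches:
        x = np.asarray(batch, dtype=np.float64)  # float64 limits rounding error
        s = x.sum(axis=(0, 2, 3))
        sq = (x ** 2).sum(axis=(0, 2, 3))
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        count += x.shape[0] * x.shape[2] * x.shape[3]
    if count == 0:
        raise ValueError("no data")
    mean = total / count
    var = total_sq / count - mean ** 2   # Var[X] = E[X^2] - (E[X])^2
    std = np.sqrt(np.maximum(var, 0.0))  # clamp tiny negative rounding errors
    return mean, std


# Synthetic stand-in for a dataloader: 5 batches of 4 RGB images, 8x8 pixels.
rng = np.random.default_rng(0)
batches = [rng.random((4, 3, 8, 8)) for _ in range(5)]
mean, std = channel_mean_std(batches)

# Normalising with these statistics should give mean ~0 and std ~1 per channel.
data = np.concatenate(batches)
normalised = (data - mean[None, :, None, None]) / std[None, :, None, None]
print(normalised.mean(axis=(0, 2, 3)), normalised.std(axis=(0, 2, 3)))
```

With a real dataloader like the one in the question, you could pass a generator such as `(img.cpu().numpy() for img, *_ in dataloader)` instead of the synthetic list.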