PyTorch DataLoader 为 TorchVision MNIST 添加额外维度

PyTorch DataLoader adding extra dimension for TorchVision MNIST

我对 PyTorch 还很陌生,一直在尝试使用 DataLoader class。 当我尝试加载 MNIST 数据集时,DataLoader 似乎在批处理维度之后添加了一个附加维度。我不确定是什么导致了这种情况。

import torch
from torchvision.datasets import MNIST
from torchvision import transforms

if __name__ == '__main__':
    mnist_train = MNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
    first_x = mnist_train.data[0]
    print(first_x.shape)  # expect to see [28, 28], actual [28, 28]

    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=200)
    batch_x, batch_y = next(iter(train_loader))  # get first batch
    print(batch_x.shape)  # expect to see [200, 28, 28], actual [200, 1, 28, 28]
    # Where is the extra dimension of 1 from?

任何人都可以阐明这个问题吗?

我猜这是输入图像的通道数。所以基本上就是

batch_x.shape = Batch-size, No of channels, Height of the image, Width of the image