使用 PyTorch 加载 FITS 图像

Loading FITS images with PyTorch

我正在尝试使用 PyTorch 创建 CNN,但我的图像需要从 FITS 格式而不是传统的 .png 或 .jpeg 等格式导入。

有没有一种方法可以使用 torch.utils.data.DataLoader 轻松完成此操作,或者在源代码中是否有一个地方可以放入一个子句来在加载时处理 FITS 文件?

我查看了文档,发现最相关的东西是 ToPILImage 转换器,它可以将张量或 ndarray 转换为 PIL 图像。

目前我正在使用如下图像加载例程:

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision

batch_size = 4

transform = transforms.Compose(
                   [transforms.Resize((32,32)),
                    transforms.ToTensor(),
                    ])

trainset = dset.ImageFolder(root="Documents/Image_data",transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)

天体:http://www.astropy.org/

火炬:https://pytorch.org/

torch.utils: https://pytorch.org/docs/master/data.html

更新:也许使用 torchvision.datasets.DatasetFolder 而不是 DataLoader,在我自己的 FITS 处理程序中插入会起作用吗?

尝试使用此 class 时出现以下错误:

AttributeError: module 'torchvision.datasets' has no attribute 'DatasetFolder'

此时 torchvision 是否真的支持 DatasetFolder?

您可以使用此方法将 FITS 图像导出为 pyplot.imsave() 支持的任何格式:

from astropy.io import fits
import matplotlib.pyplot as plt

image_data = fits.getdata(r"/path/to/image.fits")
plt.imsave("/path/to/image.png", image_data, cmap="gray")

通过阅读文档和代码的一些组合,我认为您不一定要使用 ImageFolder,因为它对 FITS 一无所知。

相反,您应该尝试使用更通用的 DataSetFolder class(实际上是 ImageFolder 的父级 class)。您将向它传递一个它应该处理的扩展列表(即 ['.fits'] 和一个 "loader" 函数,它接受一个 FITS 文件,而且看起来应该 return 一个 PIL.Image.

您甚至可以按照 ImageFolder 的示例制作自己的子class。例如

class FitsFolder(DatasetFolder):

    EXTENSIONS = ['.fits']

    def __init__(self, root, transform=None, target_transform=None,
                 loader=None):
        if loader is None:
            loader = self.__fits_loader

        super(FitsFolder, self).__init__(root, loader, self.EXTENSIONS,
                                         transform=transform,
                                         target_transform=target_transform)

    @staticmethod
    def __fits_loader(filename):
        data = fits.getdata(filename)
        return Image.fromarray(data)

__fits_loader 的确切详细信息可能取决于您的 FITS 文件的详细信息。这个基本示例只使用高级 fits.getdata() 函数,它 return 是 FITS 文件中的第一个图像数组(一些 FITS 文件可能有许多包含许多图像的扩展名,或者有表格等)。所以那部分由你决定。

几周前我遇到了与@user8188120 相同的问题。从文件夹结构中读取标签时,使用@Iguananaut 的答案效果很好。如果有人偶然发现并需要从 csv 文件中读取,这也可能有效:

labels = []
transform = transforms.Compose([
    # here go your transforms
    ])


class MyFitsDataset(data.Dataset):
    def __init__(self, csv_path):
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # the rest contain the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1:])  # for multi-label
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])  # for single-label
        labels.append(self.label_arr)
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        single_image_name = self.image_arr[index]

        data = pyfits.open(single_image_name, axes=2)
        data = data[0].data.astype('float32')
        data = data.reshape(IMG_WIDTH, IMG_HEIGHT, CHANNELS)

        img = transform(data)

        # Get label(class) of the image based on the pandas column
        single_image_label = self.label_arr[index]

        return (img, single_image_label)

    def __len__(self):
        return self.data_len

这也避免了使用 DatasetFolder class,它在最新版本的 PyTorch 中仍然不可用。我希望这对某人有所帮助。