使用 PyTorch 加载 FITS 图像
Loading FITS images with PyTorch
我正在尝试使用 PyTorch 创建 CNN,但我的图像需要从 FITS 格式而不是传统的 .png 或 .jpeg 等格式导入。
有没有一种方法可以使用 torch.utils.data.DataLoader 轻松完成此操作,或者在源代码中是否有一个地方可以放入一个子句来在加载时处理 FITS 文件?
我查看了文档,发现最相关的东西是 ToPILImage 转换器,它可以将张量或 ndarray 转换为 PIL 图像。
目前我正在使用如下图像加载例程:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision
batch_size = 4
transform = transforms.Compose(
[transforms.Resize((32,32)),
transforms.ToTensor(),
])
trainset = dset.ImageFolder(root="Documents/Image_data",transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)
torch.utils: https://pytorch.org/docs/master/data.html
更新:也许使用 torchvision.datasets.DatasetFolder 而不是 DataLoader,在我自己的 FITS 处理程序中插入会起作用吗?
尝试使用此 class 时出现以下错误:
AttributeError: module 'torchvision.datasets' has no attribute 'DatasetFolder'
此时 torchvision 是否真的支持 DatasetFolder?
您可以使用此方法将 FITS 图像导出为 pyplot.imsave() 支持的任何格式:
from astropy.io import fits
import matplotlib.pyplot as plt
image_data = fits.getdata(r"/path/to/image.fits")
plt.imsave("/path/to/image.png", image_data, cmap="gray")
通过阅读文档和代码的一些组合,我认为您不一定要使用 ImageFolder
,因为它对 FITS 一无所知。
相反,您应该尝试使用更通用的 DataSetFolder
class(实际上是 ImageFolder
的父级 class)。您将向它传递一个它应该处理的扩展列表(即 ['.fits']
和一个 "loader" 函数,它接受一个 FITS 文件,而且看起来应该 return 一个 PIL.Image
.
您甚至可以按照 ImageFolder
的示例制作自己的子class。例如
class FitsFolder(DatasetFolder):
EXTENSIONS = ['.fits']
def __init__(self, root, transform=None, target_transform=None,
loader=None):
if loader is None:
loader = self.__fits_loader
super(FitsFolder, self).__init__(root, loader, self.EXTENSIONS,
transform=transform,
target_transform=target_transform)
@staticmethod
def __fits_loader(filename):
data = fits.getdata(filename)
return Image.fromarray(data)
__fits_loader
的确切详细信息可能取决于您的 FITS 文件的详细信息。这个基本示例只使用高级 fits.getdata()
函数,它 return 是 FITS 文件中的第一个图像数组(一些 FITS 文件可能有许多包含许多图像的扩展名,或者有表格等)。所以那部分由你决定。
几周前我遇到了与@user8188120 相同的问题。从文件夹结构中读取标签时,使用@Iguananaut 的答案效果很好。如果有人偶然发现并需要从 csv 文件中读取,这也可能有效:
labels = []
transform = transforms.Compose([
# here go your transforms
])
class MyFitsDataset(data.Dataset):
def __init__(self, csv_path):
# Read the csv file
self.data_info = pd.read_csv(csv_path, header=None)
# First column contains the image paths
self.image_arr = np.asarray(self.data_info.iloc[:, 0])
# the rest contain the labels
self.label_arr = np.asarray(self.data_info.iloc[:, 1:]) # for multi-label
self.label_arr = np.asarray(self.data_info.iloc[:, 1]) # for single-label
labels.append(self.label_arr)
self.data_len = len(self.data_info.index)
def __getitem__(self, index):
single_image_name = self.image_arr[index]
data = pyfits.open(single_image_name, axes=2)
data = data[0].data.astype('float32')
data = data.reshape(IMG_WIDTH, IMG_HEIGHT, CHANNELS)
img = transform(data)
# Get label(class) of the image based on the pandas column
single_image_label = self.label_arr[index]
return (img, single_image_label)
def __len__(self):
return self.data_len
这也避免了使用 DatasetFolder
class,它在最新版本的 PyTorch 中仍然不可用。我希望这对某人有所帮助。
我正在尝试使用 PyTorch 创建 CNN,但我的图像需要从 FITS 格式而不是传统的 .png 或 .jpeg 等格式导入。
有没有一种方法可以使用 torch.utils.data.DataLoader 轻松完成此操作,或者在源代码中是否有一个地方可以放入一个子句来在加载时处理 FITS 文件?
我查看了文档,发现最相关的东西是 ToPILImage 转换器,它可以将张量或 ndarray 转换为 PIL 图像。
目前我正在使用如下图像加载例程:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision
batch_size = 4
transform = transforms.Compose(
[transforms.Resize((32,32)),
transforms.ToTensor(),
])
trainset = dset.ImageFolder(root="Documents/Image_data",transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)
torch.utils: https://pytorch.org/docs/master/data.html
更新:也许使用 torchvision.datasets.DatasetFolder 而不是 DataLoader,在我自己的 FITS 处理程序中插入会起作用吗?
尝试使用此 class 时出现以下错误:
AttributeError: module 'torchvision.datasets' has no attribute 'DatasetFolder'
此时 torchvision 是否真的支持 DatasetFolder?
您可以使用此方法将 FITS 图像导出为 pyplot.imsave() 支持的任何格式:
from astropy.io import fits
import matplotlib.pyplot as plt
image_data = fits.getdata(r"/path/to/image.fits")
plt.imsave("/path/to/image.png", image_data, cmap="gray")
通过阅读文档和代码的一些组合,我认为您不一定要使用 ImageFolder
,因为它对 FITS 一无所知。
相反,您应该尝试使用更通用的 DataSetFolder
class(实际上是 ImageFolder
的父级 class)。您将向它传递一个它应该处理的扩展列表(即 ['.fits']
和一个 "loader" 函数,它接受一个 FITS 文件,而且看起来应该 return 一个 PIL.Image
.
您甚至可以按照 ImageFolder
的示例制作自己的子class。例如
class FitsFolder(DatasetFolder):
EXTENSIONS = ['.fits']
def __init__(self, root, transform=None, target_transform=None,
loader=None):
if loader is None:
loader = self.__fits_loader
super(FitsFolder, self).__init__(root, loader, self.EXTENSIONS,
transform=transform,
target_transform=target_transform)
@staticmethod
def __fits_loader(filename):
data = fits.getdata(filename)
return Image.fromarray(data)
__fits_loader
的确切详细信息可能取决于您的 FITS 文件的详细信息。这个基本示例只使用高级 fits.getdata()
函数,它 return 是 FITS 文件中的第一个图像数组(一些 FITS 文件可能有许多包含许多图像的扩展名,或者有表格等)。所以那部分由你决定。
几周前我遇到了与@user8188120 相同的问题。从文件夹结构中读取标签时,使用@Iguananaut 的答案效果很好。如果有人偶然发现并需要从 csv 文件中读取,这也可能有效:
labels = []
transform = transforms.Compose([
# here go your transforms
])
class MyFitsDataset(data.Dataset):
def __init__(self, csv_path):
# Read the csv file
self.data_info = pd.read_csv(csv_path, header=None)
# First column contains the image paths
self.image_arr = np.asarray(self.data_info.iloc[:, 0])
# the rest contain the labels
self.label_arr = np.asarray(self.data_info.iloc[:, 1:]) # for multi-label
self.label_arr = np.asarray(self.data_info.iloc[:, 1]) # for single-label
labels.append(self.label_arr)
self.data_len = len(self.data_info.index)
def __getitem__(self, index):
single_image_name = self.image_arr[index]
data = pyfits.open(single_image_name, axes=2)
data = data[0].data.astype('float32')
data = data.reshape(IMG_WIDTH, IMG_HEIGHT, CHANNELS)
img = transform(data)
# Get label(class) of the image based on the pandas column
single_image_label = self.label_arr[index]
return (img, single_image_label)
def __len__(self):
return self.data_len
这也避免了使用 DatasetFolder
class,它在最新版本的 PyTorch 中仍然不可用。我希望这对某人有所帮助。