使用 PYTORCH 制作个人 Dataloader

Make personnal Dataloader with PYTORCH

我正在寻找创建一个具有特定格式的个人数据加载器以使用 Pytorch 库,有人知道我该怎么做吗?我已遵循 Pytorch 教程,但找不到答案!

我需要一个 DataLoader 来生成以下格式的元组: (Bx3xHxW FloatTensor x, BxHxW LongTensor y, BxN LongTensor y_cls) 其中 x - 一批输入图像, y - 一批 groung truth seg maps, y_cls - 维数 N 的一维张量批次:N 总数 classes, y_cls[i, T] = 1 如果 class T 出现在图像 i 中,否则为 0

我希望有人能解决这个问题.. :) 谢谢!

你只需要有一个从 torch.utils.data.Dataset 派生的数据库,其中 __getitem__(index) returns 一个你想要的类型的元组 (x, y, y_cls),pytorch 会处理一切否则。

from torch.utils import data

class MyTupleDataset(data.Dataset):
  def __init__(self):
    super(MyTupleDataset, self).__init__()
    # init your dataset here...

  def __getitem__(index):
    x = torch.Tensor(3, H, W)  # batch dim is handled by the data loader
    y = torch.Tensor(H, W).to(torch.long)
    y_cls = torch.Tensor(N).to(torch.long)
    return x, y, y_cls

就是这样。为 pytorch 的 torch.utils.data.DataLoader 提供 MyTupleDataset 就大功告成了。