如何沿着单个pytorch张量的维度连接?

How to concatenate along a dimension of a single pytorch tensor?

我写了一个自定义的 pytorch Dataset__getitem__() 函数 return 一个形状为 (250, 150) 的张量,然后我用 DataLoader 生成了一个批处理批量大小为 10 的数据。我的意图是将形状为 (2500, 150) 的批次作为这 10 个张量沿维度 0 的串联,但 DataLoader 的输出具有形状 (10, 250, 150)。如何将 DataLoader 的输出转换为形状 (2500, 150) 作为沿维度 0 的串联?

PyTorch DataLoader 将始终在第 0 个索引处添加额外的批次维度。所以,如果你得到一个形状为 (10, 250, 150) 的张量,你可以用

简单地重塑它
# x is of shape (10, 250, 150)
x_ = x.view(-1, 150)
# x_ is of shape (2500, 150)

或者,更正确地说,您可以为数据加载器

提供自定义 整理器
def custom_collate(batch):
    # each item in batch is (250, 150) as returned by __getitem__
    return torch.cat(batch, 0)

dl = DataLoader(dataset, batch_size=10, collate_fn=custom_collate, ...)

这将在 dataloder 本身中创建适当大小的张量,因此不需要使用 .view() 进行任何 post 处理。