如何沿着单个pytorch张量的维度连接?
How to concatenate along a dimension of a single pytorch tensor?
我写了一个自定义的 pytorch Dataset
和 __getitem__()
函数 return 一个形状为 (250, 150)
的张量,然后我用 DataLoader
生成了一个批处理批量大小为 10 的数据。我的意图是将形状为 (2500, 150)
的批次作为这 10 个张量沿维度 0 的串联,但 DataLoader
的输出具有形状 (10, 250, 150)
。如何将 DataLoader
的输出转换为形状 (2500, 150)
作为沿维度 0 的串联?
PyTorch DataLoader 将始终在第 0 个索引处添加额外的批次维度。所以,如果你得到一个形状为 (10, 250, 150)
的张量,你可以用
简单地重塑它
# x is of shape (10, 250, 150)
x_ = x.view(-1, 150)
# x_ is of shape (2500, 150)
或者,更正确地说,您可以为数据加载器
提供自定义 整理器
def custom_collate(batch):
# each item in batch is (250, 150) as returned by __getitem__
return torch.cat(batch, 0)
dl = DataLoader(dataset, batch_size=10, collate_fn=custom_collate, ...)
这将在 dataloder 本身中创建适当大小的张量,因此不需要使用 .view()
进行任何 post 处理。
我写了一个自定义的 pytorch Dataset
和 __getitem__()
函数 return 一个形状为 (250, 150)
的张量,然后我用 DataLoader
生成了一个批处理批量大小为 10 的数据。我的意图是将形状为 (2500, 150)
的批次作为这 10 个张量沿维度 0 的串联,但 DataLoader
的输出具有形状 (10, 250, 150)
。如何将 DataLoader
的输出转换为形状 (2500, 150)
作为沿维度 0 的串联?
PyTorch DataLoader 将始终在第 0 个索引处添加额外的批次维度。所以,如果你得到一个形状为 (10, 250, 150)
的张量,你可以用
# x is of shape (10, 250, 150)
x_ = x.view(-1, 150)
# x_ is of shape (2500, 150)
或者,更正确地说,您可以为数据加载器
提供自定义 整理器def custom_collate(batch):
# each item in batch is (250, 150) as returned by __getitem__
return torch.cat(batch, 0)
dl = DataLoader(dataset, batch_size=10, collate_fn=custom_collate, ...)
这将在 dataloder 本身中创建适当大小的张量,因此不需要使用 .view()
进行任何 post 处理。