火炬挤压和批量尺寸
Torch squeeze and the batch dimension
这里有人知道 torch.squeeze
函数是否尊重批处理(例如第一个)维度吗?从一些内联代码看来它不是..但也许其他人比我更了解内部工作原理。
顺便说一句,潜在的问题是我有形状为 (n_batch, channel, x, y, 1)
的张量。我想用一个简单的函数删除最后一个维度,这样我就得到了 (n_batch, channel, x, y)
.
的形状
当然可以进行整形,甚至可以选择最后一个轴。但是我想将这个功能嵌入到一个层中,以便我可以轻松地将它添加到 ModuleList
或 Sequence
对象中。
编辑:刚刚发现对于 Tensorflow (2.5.0),函数 tf.linalg.diag
确实尊重批量维度。仅供参考,它可能因您使用的功能而异
没有!挤压不尊重批量维度。如果您在批量维度可能为 1 时使用挤压,这可能是一个错误来源。经验法则是只有 类 和 torch.nn 中的函数默认遵守批量维度。
过去这让我很头疼。 我建议使用 reshape
或仅将 squeeze
与可选的输入维度参数一起使用。 在您的情况下,您可以使用 .squeeze(4)
仅删除最后一个维度.这样就不会发生意外。没有输入维度的挤压导致了我意想不到的结果,特别是当
- 模型的输入形状可能不同
- 批量大小可能不同
nn.DataParallel
正在使用(在这种情况下,特定实例的批量大小可能会减少到 1)
已接受的答案足以解决问题 - squeeze
最后一个维度。但是,我有维度 (batch, 1280, 1, 1)
的张量并且想要 (batch, 1280)
。 Squeeze
函数不允许这样做 - squeeze(tensor, 1).shape
-> (batch, 1280, 1, 1)
和 squeeze(tensor, 2).shape
-> (batch, 1280, 1)
。我本可以使用 squeeze
两次,但你知道,美学 :)。
对我有帮助的是 torch.flatten(tensor, start_dim = 1)
-> (batch, 1280)
。微不足道,但我忘记了。不过警告,这个函数我会创建一个副本而不是视图,所以要小心。
https://pytorch.org/docs/stable/generated/torch.flatten.html
这里有人知道 torch.squeeze
函数是否尊重批处理(例如第一个)维度吗?从一些内联代码看来它不是..但也许其他人比我更了解内部工作原理。
顺便说一句,潜在的问题是我有形状为 (n_batch, channel, x, y, 1)
的张量。我想用一个简单的函数删除最后一个维度,这样我就得到了 (n_batch, channel, x, y)
.
当然可以进行整形,甚至可以选择最后一个轴。但是我想将这个功能嵌入到一个层中,以便我可以轻松地将它添加到 ModuleList
或 Sequence
对象中。
编辑:刚刚发现对于 Tensorflow (2.5.0),函数 tf.linalg.diag
确实尊重批量维度。仅供参考,它可能因您使用的功能而异
没有!挤压不尊重批量维度。如果您在批量维度可能为 1 时使用挤压,这可能是一个错误来源。经验法则是只有 类 和 torch.nn 中的函数默认遵守批量维度。
过去这让我很头疼。 我建议使用 reshape
或仅将 squeeze
与可选的输入维度参数一起使用。 在您的情况下,您可以使用 .squeeze(4)
仅删除最后一个维度.这样就不会发生意外。没有输入维度的挤压导致了我意想不到的结果,特别是当
- 模型的输入形状可能不同
- 批量大小可能不同
nn.DataParallel
正在使用(在这种情况下,特定实例的批量大小可能会减少到 1)
已接受的答案足以解决问题 - squeeze
最后一个维度。但是,我有维度 (batch, 1280, 1, 1)
的张量并且想要 (batch, 1280)
。 Squeeze
函数不允许这样做 - squeeze(tensor, 1).shape
-> (batch, 1280, 1, 1)
和 squeeze(tensor, 2).shape
-> (batch, 1280, 1)
。我本可以使用 squeeze
两次,但你知道,美学 :)。
对我有帮助的是 torch.flatten(tensor, start_dim = 1)
-> (batch, 1280)
。微不足道,但我忘记了。不过警告,这个函数我会创建一个副本而不是视图,所以要小心。
https://pytorch.org/docs/stable/generated/torch.flatten.html