火炬挤压和批量尺寸

Torch squeeze and the batch dimension

这里有人知道 torch.squeeze 函数是否尊重批处理(例如第一个)维度吗?从一些内联代码看来它不是..但也许其他人比我更了解内部工作原理。

顺便说一句,潜在的问题是我有形状为 (n_batch, channel, x, y, 1) 的张量。我想用一个简单的函数删除最后一个维度,这样我就得到了 (n_batch, channel, x, y).

的形状

当然可以进行整形,甚至可以选择最后一个轴。但是我想将这个功能嵌入到一个层中,以便我可以轻松地将它添加到 ModuleListSequence 对象中。

编辑:刚刚发现对于 Tensorflow (2.5.0),函数 tf.linalg.diag 确实尊重批量维度。仅供参考,它可能因您使用的功能而异

没有!挤压不尊重批量维度。如果您在批量维度可能为 1 时使用挤压,这可能是一个错误来源。经验法则是只有 类 和 torch.nn 中的函数默认遵守批量维度。

过去这让我很头疼。 我建议使用 reshape 或仅将 squeeze 与可选的输入维度参数一起使用。 在您的情况下,您可以使用 .squeeze(4) 仅删除最后一个维度.这样就不会发生意外。没有输入维度的挤压导致了我意想不到的结果,特别是当

  1. 模型的输入形状可能不同
  2. 批量大小可能不同
  3. nn.DataParallel 正在使用(在这种情况下,特定实例的批量大小可能会减少到 1)

已接受的答案足以解决问题 - squeeze 最后一个维度。但是,我有维度 (batch, 1280, 1, 1) 的张量并且想要 (batch, 1280)Squeeze 函数不允许这样做 - squeeze(tensor, 1).shape -> (batch, 1280, 1, 1)squeeze(tensor, 2).shape -> (batch, 1280, 1)。我本可以使用 squeeze 两次,但你知道,美学 :)。

对我有帮助的是 torch.flatten(tensor, start_dim = 1) -> (batch, 1280)。微不足道,但我忘记了。不过警告,这个函数我会创建一个副本而不是视图,所以要小心。

https://pytorch.org/docs/stable/generated/torch.flatten.html