从具有任意数量尾随维数的pytorch中的张量获取矩阵

obtain matrix from a tensor in pytorch with arbitrary number of trailing dimensions

我有一个任意维数的 pytorch 张量:...X,Y,Z

我想要一个函数,这样我给一个数字 C,我得到 ...,C,Y,Z

my_matrix = [:,:,C,:,:]

但我不知道 C 之前有多少尾随维度,我看到了一个使用切片元组的答案,但似乎可以让它工作。

我认为 ellipsis 可以胜任:

t = torch.randn(2, 3, 6, 5, 9, 3)
t[..., 4, :, :]

u = torch.randn(11, 4, 2, 7)
u[..., 2, :, :].shape