How do I find the number of elements taking up memory for an expanded view of a pytorch tensor?

Tensor.expand() returns a new view of the underlying tensor, but does not actually allocate any additional memory for the expanded view.
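(A quick way to see this, as a minimal sketch: expand() gives every broadcast dimension a stride of 0, and the view shares the original data pointer.)

import torch

t = torch.tensor(4.0)
t2 = t.expand(3, 4)

# The expanded view shares the scalar's memory; broadcast dims get stride 0
print(t2.stride())                    # (0, 0)
print(t.data_ptr() == t2.data_ptr())  # True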

If I have a tensor that is the result of calling expand() some unknown number of times, how do I find out how many cells are actually allocated in memory for that tensor (in my actual use case, I really only care about whether that number is 1)?

Is there something like the elements_in_memory I describe below?:

import torch

t = torch.tensor(4.0)
t2 = t.expand(3, 4)
t3 = t2.unsqueeze(0).expand(5, 3, 4)

# I'm looking for something like this (which doesn't work) 
assert t.elements_in_memory == 1
assert t2.elements_in_memory == 1
assert t3.elements_in_memory == 1 
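(Note that Tensor.numel() is not it: it counts the elements of the logical view, not the cells actually backed by memory, so it reports 12 and 60 for t2 and t3 above.)

print(t.numel())   # 1
print(t2.numel())  # 12 -- logical size of the view, not memory actually used
print(t3.numel())  # 60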

Some things I've tried:

It looks like t.storage().size() is what I want.

From the torch.Tensor documentation:

Each tensor has an associated torch.Storage, which holds its data. The tensor class also provides multi-dimensional, strided view of a storage and defines numeric operations on it.

Tensor.storage() returns a reference to the storage used for the tensor:

import torch

t = torch.tensor(4.0)
t2 = t.expand(3, 4)
t3 = t2.unsqueeze(0).expand(5, 3, 4)

assert t.storage().size() == 1
assert t2.storage().size() == 1
assert t3.storage().size() == 1

t4 = torch.ones(3, 4)
t5 = t4.unsqueeze(0).expand(5, 3, 4)

assert t4.storage().size() == 12
assert t5.storage().size() == 12
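One caveat: in newer PyTorch releases, Tensor.storage() returns a TypedStorage and emits a deprecation warning. If that applies to your version, the same count can be computed from the untyped storage. A minimal sketch, where elements_in_memory is my own hypothetical helper name:

import torch

def elements_in_memory(t: torch.Tensor) -> int:
    # Hypothetical helper: number of elements actually backed by memory,
    # i.e. total storage bytes divided by the per-element size.
    return t.untyped_storage().nbytes() // t.element_size()

t = torch.tensor(4.0)
t3 = t.expand(3, 4).unsqueeze(0).expand(5, 3, 4)

assert elements_in_memory(t) == 1
assert elements_in_memory(t3) == 1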

Note that the underlying storage may contain more elements than a particular view exposes (this doesn't matter for my use case). For example, torch.ones(10)[3:6].storage().size() == 10.
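For the narrow use case of checking whether a tensor really occupies just one memory cell, a sketch that sidesteps storage entirely is to inspect the strides: dimensions produced by expand() have stride 0, so if every dimension either has size 1 or stride 0, all indices map to the same storage offset. The helper name below is my own:

import torch

def is_single_element_view(t: torch.Tensor) -> bool:
    # Hypothetical helper: True if the view addresses exactly one memory cell.
    # Each dimension must have size 1 or stride 0 (as produced by expand()).
    return all(size == 1 or stride == 0
               for size, stride in zip(t.shape, t.stride()))

assert is_single_element_view(torch.tensor(4.0).expand(3, 4))
assert not is_single_element_view(torch.ones(3, 4))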