如何找到占用内存的元素数量以扩展 pytorch 张量视图?
How do I find the number of elements taking up memory for an expanded view of a pytorch tensor?
Tensor.expand()
returns 底层张量的新视图,但实际上并没有为扩展视图分配更多内存。
如果我有一个张量是调用 expand()
(未知次数)的结果,我怎么知道在内存中实际为张量分配了多少个单元格(在我的实际使用中-案例,我真的只关心知道那个数字是否是 1)?
是否有类似我所说的 elements_in_memory
用于以下内容?:
import torch
t = torch.tensor(4.0)
t2 = t.expand(3, 4)
t3 = t2.unsqueeze(0).expand(5, 3, 4)
# I'm looking for something like this (which doesn't work)
assert t.elements_in_memory == 1
assert t2.elements_in_memory == 1
assert t3.elements_in_memory == 1
我尝试过的一些事情:
t.data_ptr
指的是内存中底层张量的第一个元素,所以t.data_ptr == t2.data_ptr
,但这并没有告诉我有多少个元素。
看来t.storage().size()
就是我想要的。
Each tensor has an associated torch.Storage, which holds its data. The tensor class also provides multi-dimensional, strided view of a storage and defines numeric operations on it.
Tensor.storage()
returns 对用于张量的存储的引用:
import torch
t = torch.tensor(4.0)
t2 = t.expand(3, 4)
t3 = t2.unsqueeze(0).expand(5, 3, 4)
assert t.storage().size() == 1
assert t2.storage().size() == 1
assert t3.storage().size() == 1
t4 = torch.ones(3, 4)
t5 = t4.unsqueeze(0).expand(5, 3, 4)
assert t4.storage().size() == 12
assert t5.storage().size() == 12
请注意,底层存储可能还包含 更多 元素,而不是某些特定视图公开的元素(这与我的用例无关)。例如 torch.ones(10)[3:6].storage().size() == 10
.
Tensor.expand()
returns 底层张量的新视图,但实际上并没有为扩展视图分配更多内存。
如果我有一个张量是调用 expand()
(未知次数)的结果,我怎么知道在内存中实际为张量分配了多少个单元格(在我的实际使用中-案例,我真的只关心知道那个数字是否是 1)?
是否有类似我所说的 elements_in_memory
用于以下内容?:
import torch
t = torch.tensor(4.0)
t2 = t.expand(3, 4)
t3 = t2.unsqueeze(0).expand(5, 3, 4)
# I'm looking for something like this (which doesn't work)
assert t.elements_in_memory == 1
assert t2.elements_in_memory == 1
assert t3.elements_in_memory == 1
我尝试过的一些事情:
t.data_ptr
指的是内存中底层张量的第一个元素,所以t.data_ptr == t2.data_ptr
,但这并没有告诉我有多少个元素。
看来t.storage().size()
就是我想要的。
Each tensor has an associated torch.Storage, which holds its data. The tensor class also provides multi-dimensional, strided view of a storage and defines numeric operations on it.
Tensor.storage()
returns 对用于张量的存储的引用:
import torch
t = torch.tensor(4.0)
t2 = t.expand(3, 4)
t3 = t2.unsqueeze(0).expand(5, 3, 4)
assert t.storage().size() == 1
assert t2.storage().size() == 1
assert t3.storage().size() == 1
t4 = torch.ones(3, 4)
t5 = t4.unsqueeze(0).expand(5, 3, 4)
assert t4.storage().size() == 12
assert t5.storage().size() == 12
请注意,底层存储可能还包含 更多 元素,而不是某些特定视图公开的元素(这与我的用例无关)。例如 torch.ones(10)[3:6].storage().size() == 10
.