How to get a sub 2D tensor with a specific value condition in PyTorch?
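I want to copy a 2D torch tensor into a destination tensor that keeps each row's values only up to and including the first occurrence of 202, with the remaining items set to zero, like this:
source_t = tensor([[101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
                   [101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
                   [101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202]])
destination_t = tensor([[101, 2001, 2034, 1045, 202, 0, 0, 0, 0],
                        [101, 1999, 2808, 202, 0, 0, 0, 0, 0],
                        [101, 2012, 3832, 4027, 3454, 202, 0, 0, 0]])
How can I do this?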
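I came up with a solution that works and is fairly efficient.
I made a slightly more complex source tensor, adding rows that have 202 in different places (or not at all):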
import copy
import torch
source_t = torch.tensor([[101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
                         [101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
                         [101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
                         [101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202],
                         [101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]
                         ])
First, we need to find the first occurrence of 202 in each row. We can find all occurrences and then pick the first one per row:
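# (row, column) pairs of every 202, in row-major order
index_202 = (source_t == 202).nonzero(as_tuple=False).numpy()
rows_for_replace = list()
columns_to_replace = list()
elements = source_t.shape[1]  # number of columns
current_ind = 0
# visit every hit, including the last one
while current_ind < len(index_202):
    current = index_202[current_ind]
    element_ind = current[1] + 1  # first column to the right of this 202
    rows_for_replace.extend([current[0]] * (elements - element_ind))
    while element_ind < elements:
        columns_to_replace.append(element_ind)
        element_ind += 1
    # skip a second 202 in the same row, if there is one
    if current_ind + 1 < len(index_202) and current[0] == index_202[current_ind + 1][0]:
        current_ind += 1
    current_ind += 1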
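After this, we have all the indices that should be replaced with zeros: 4 elements in the first row, 5 in the second, 3 in the fourth, and none in the third and fifth rows.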
rows_for_replace, columns_to_replace
([0, 0, 0, 0, 1, 1, 1, 1, 1, 3, 3, 3], [5, 6, 7, 8, 4, 5, 6, 7, 8, 6, 7, 8])
Then we just copy the source tensor and set the new values:
destination_t = copy.deepcopy(source_t)
destination_t[rows_for_replace, columns_to_replace] = 0
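The assignment above works because indexing a tensor with two equal-length lists pairs them element by element, so only the listed (row, column) cells are touched. A minimal sketch of that behaviour, using a made-up tensor `demo` and made-up index lists that are not part of the original answer:

```python
import torch

# Hypothetical 3x4 tensor, used only to illustrate paired indexing.
demo = torch.arange(12).reshape(3, 4)

rows = [0, 0, 2]
cols = [1, 3, 0]

patched = demo.clone()    # clone() copies the data, so demo stays untouched
patched[rows, cols] = 0   # zeroes cells (0, 1), (0, 3) and (2, 0) only

print(patched)
```

For a plain tensor like this one, `clone()` is the more common way to copy than `copy.deepcopy`.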
Summary:
source_t
tensor([[ 101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
[ 101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
[ 101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]])
destination_t
tensor([[ 101, 2001, 2034, 1045, 202, 0, 0, 0, 0],
[ 101, 1999, 2808, 202, 0, 0, 0, 0, 0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
[ 101, 2012, 3832, 4027, 3454, 202, 0, 0, 0],
[ 101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]])
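For comparison, the same result can probably be reached without any Python loop by building a mask from a running count of 202s. This is only a sketch of an alternative, not the method used above; the helper name `zero_after_first` and its `value` parameter are made up for illustration.

```python
import torch

def zero_after_first(t: torch.Tensor, value: int = 202) -> torch.Tensor:
    """Return a copy of t where everything right of the first `value` in each row is zero."""
    hits = t.eq(value).long()                 # 1 where t == value, else 0
    seen_before = hits.cumsum(dim=1) - hits   # hits strictly to the left of each cell
    return t.masked_fill(seen_before > 0, 0)  # rows without `value` are left unchanged

source_t = torch.tensor([[101, 2001, 2034, 1045, 202, 3454, 3453, 1234, 202],
                         [101, 1999, 2808, 202, 17658, 3454, 202, 0, 0],
                         [101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 2020],
                         [101, 2012, 3832, 4027, 3454, 202, 3454, 9987, 202],
                         [101, 2012, 3832, 4027, 3454, 2020, 3454, 9987, 202]])

destination_t = zero_after_first(source_t)
print(destination_t)
```

Subtracting `hits` from the running count keeps the first 202 itself, because at that cell the count of earlier hits is still zero.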
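I think there is a better solution. It requires every row to contain a 202; otherwise you will have to remove the rows where that is not the case.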
import torch
t = torch.tensor([[101,2001,2034,1045,202,3454,3453,1234,202],
[101,1999,2808,202,17658,3454,202,0,0],
[101,2012,3832,4027,3454,202,3454,9987,202]])
out = t.clone() # make copy
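Check where the tensor equals 202, convert the booleans to int, and take the argmax of each row. This gives the column where the first 1 appears, which corresponds to the first 202.
Then loop over each row: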
cols = t.eq(202).int().argmax(1)
k = t.shape[1] # number of columns
for idx, c in enumerate(cols):
    if c + 1 < k:
        out[idx, c+1:] = 0  # make all values right of c equal to zero
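The reason every row needs a 202 is that argmax returns 0 for a row of all zeros, so a row without 202 would have everything after its first column wiped. Rather than dropping such rows, one possible guard is to check `any(dim=1)` first. This is a sketch under that assumption; the small tensor `t` below is invented for illustration.

```python
import torch

# Hypothetical input: the second row contains no 202.
t = torch.tensor([[101, 202, 7, 202],
                  [101, 5, 6, 7]])

hits = t.eq(202)
cols = hits.int().argmax(dim=1)  # first 202 per row; 0 when the row has none
has_hit = hits.any(dim=1)        # True only for rows that contain a 202

out = t.clone()
for idx, (c, found) in enumerate(zip(cols.tolist(), has_hit.tolist())):
    if found:
        out[idx, c + 1:] = 0     # slicing past the last column is a no-op

print(out)
```

Rows without a 202 are copied through unchanged, which matches the behaviour of the first answer.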