如何计算 PyTorch 中点集和线之间的成对距离?
How to compute pairwise distance between point set and lines in PyTorch?
点集 A
是一个 Nx3
矩阵,从两个点集 B
和 C
具有相同的大小 Mx3
我们可以得到它们之间的 BC
行。现在我想计算从 A
中的每个点到 BC
中的每条线的距离。 B
是Mx3
,C
是Mx3
,那么线是从对应行的点开始的,所以BC
是一个Mx3
矩阵。基本方法计算如下:
D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
p = A[i] # 1x3
for j in range(M):
p1 = B[j] # 1x3
p2 = C[j] # 1x3
D[i,j] = torch.norm(torch.cross(p1 - p2, p - p1)) / torch.norm(p1 - p2)
有没有更快的方法来完成这项工作?谢谢。
您可以通过这样做删除 for
循环(它应该以内存为代价加速,除非 M
和 N
很小):
diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines
当然,你不需要一步一步来。我只是想弄清楚变量名。
注意:如果你不提供dim
到torch.cross
,它会使用第一个dim=3
,这会给出错误结果如果 N=3
(来自 docs):
If dim is not given, it defaults to the first dimension found with the size 3.
如果您想知道,可以查看 here 为什么我选择 expand
而不是 repeat
。
点集 A
是一个 Nx3
矩阵,从两个点集 B
和 C
具有相同的大小 Mx3
我们可以得到它们之间的 BC
行。现在我想计算从 A
中的每个点到 BC
中的每条线的距离。 B
是Mx3
,C
是Mx3
,那么线是从对应行的点开始的,所以BC
是一个Mx3
矩阵。基本方法计算如下:
D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
p = A[i] # 1x3
for j in range(M):
p1 = B[j] # 1x3
p2 = C[j] # 1x3
D[i,j] = torch.norm(torch.cross(p1 - p2, p - p1)) / torch.norm(p1 - p2)
有没有更快的方法来完成这项工作?谢谢。
您可以通过这样做删除 for
循环(它应该以内存为代价加速,除非 M
和 N
很小):
diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines
当然,你不需要一步一步来。我只是想弄清楚变量名。
注意:如果你不提供dim
到torch.cross
,它会使用第一个dim=3
,这会给出错误结果如果 N=3
(来自 docs):
If dim is not given, it defaults to the first dimension found with the size 3.
如果您想知道,可以查看 here 为什么我选择 expand
而不是 repeat
。