如何计算 PyTorch 中点集和线之间的成对距离?

How to compute pairwise distance between point set and lines in PyTorch?

点集 A 是一个 Nx3 矩阵,从两个点集 BC 具有相同的大小 Mx3 我们可以得到它们之间的 BC 行。现在我想计算从 A 中的每个点到 BC 中的每条线的距离。 BMx3CMx3,那么线是从对应行的点开始的,所以BC是一个Mx3矩阵。基本方法计算如下:

D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
    p = A[i]  # 1x3
    for j in range(M):
        p1 = B[j] # 1x3
        p2 = C[j] # 1x3
        D[i,j] = torch.norm(torch.cross(p1 - p2, p - p1)) / torch.norm(p1 - p2) 

有没有更快的方法来完成这项工作?谢谢。

您可以通过这样做删除 for 循环(它应该以内存为代价加速,除非 MN 很小):

diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines

当然,你不需要一步一步来。我只是想弄清楚变量名。

注意:如果你不提供dimtorch.cross,它会使用第一个dim=3,这会给出错误结果如果 N=3(来自 docs):

If dim is not given, it defaults to the first dimension found with the size 3.

如果您想知道,可以查看 here 为什么我选择 expand 而不是 repeat