如何使用 numpy 来加速计算质心的代码?

How to use numpy to speed up code that calculates center of mass?

我编写了一小段代码 - 给定 n 个指定质量的对象和随时间变化的矢量坐标 - 将计算质心。我认为代码看起来很笨重(它使用 3 个 for 循环),并且想知道是否有 numpy 方法来矢量化(或至少加速)这个方法。请注意,对于此任务可能会避免使用 class Body,但会在此处未显示的其他相关代码中使用。

import numpy as np

class Body():

    def __init__(self, mass, position):
        self.mass = mass
        self.position = position

    def __str__(self):
        return '\n .. mass:\n{}\n\n .. position:\n{}\n'.format(self.mass, self.position)

三个对象被初始化。

mass = 100 # same for all 3 objects
ndim = 3 # 3 dimensional space
nmoments = 10 # 10 moments in time

## initialize bodies
nelems = ndim * nmoments
x = np.arange(nelems).astype(int).reshape((nmoments, ndim))
A = Body(mass, position=x)
B = Body(mass, position=x / 2.)
C = Body(mass, position=x * 2.)
bodies = [A, B, C]
total_mass = sum([body.mass for body in bodies])

# print("\n ** A **\n{}\n".format(A))
# print("\n ** B **\n{}\n".format(B))
# print("\n ** C **\n{}\n".format(C))

## get center of mass
center_of_mass = []
for dim in range(ndim):
    coms = []
    for moment in range(nmoments):
        numerator = 0
        for body in bodies:
            numerator += body.mass * body.position[moment, dim]
        com = numerator / total_mass
        coms.append(com)
    center_of_mass.append(coms)
center_of_mass = np.array(center_of_mass).T

# print("\n .. center of mass:\n{}\n".format(center_of_mass))

为了验证代码是否有效,上面代码中的 print 语句输出以下内容:

 ** A **

 .. mass:
100

 .. position:
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]
 [12 13 14]
 [15 16 17]
 [18 19 20]
 [21 22 23]
 [24 25 26]
 [27 28 29]]



 ** B **

 .. mass:
100

 .. position:
[[ 0.   0.5  1. ]
 [ 1.5  2.   2.5]
 [ 3.   3.5  4. ]
 [ 4.5  5.   5.5]
 [ 6.   6.5  7. ]
 [ 7.5  8.   8.5]
 [ 9.   9.5 10. ]
 [10.5 11.  11.5]
 [12.  12.5 13. ]
 [13.5 14.  14.5]]



 ** C **

 .. mass:
100

 .. position:
[[ 0.  2.  4.]
 [ 6.  8. 10.]
 [12. 14. 16.]
 [18. 20. 22.]
 [24. 26. 28.]
 [30. 32. 34.]
 [36. 38. 40.]
 [42. 44. 46.]
 [48. 50. 52.]
 [54. 56. 58.]]



 .. center of mass:
[[ 0.          1.16666667  2.33333333]
 [ 3.5         4.66666667  5.83333333]
 [ 7.          8.16666667  9.33333333]
 [10.5        11.66666667 12.83333333]
 [14.         15.16666667 16.33333333]
 [17.5        18.66666667 19.83333333]
 [21.         22.16666667 23.33333333]
 [24.5        25.66666667 26.83333333]
 [28.         29.16666667 30.33333333]
 [31.5        32.66666667 33.83333333]]

使用 numpy 会加快速度并使代码更清晰。我不是 n 体问题的专家,所以我希望遵循算法 OK,结果看起来是一样的。所有循环都隐含在 numpy 中。

# *****  From the question  *****
import numpy as np

class Body():   

    def __init__(self, mass, position):
        self.mass = mass
        self.position = position

    def __str__(self):
        return '\n .. mass:\n{}\n\n .. position:\n{}\n'.format(self.mass, self.position)

mass = 100 # same for all 3 objects
ndim = 3 # 3 dimensional space
nmoments = 10 # 10 moments in time

## initialize bodies
nelems = ndim * nmoments
x = np.arange(nelems).astype(int).reshape((nmoments, ndim))
A = Body(mass, position=x)
B = Body(mass, position=x / 2.)
C = Body(mass, position=x * 2.)
bodies = [A, B, C]

# **** End of code from the question  ****

# Fill the numpy arrays
np_mass = np.array( [ body.mass for body in bodies ])[ :,None, None ]
# the [:, None, None] turns np_mass into a 3D array for correct broadcasting

np_pos = np.array( [ body.position for body in bodies ])    # 3D 

np_mass.shape
# (3, 1, 1) # (n_bodies, 1, 1 ) - The two 'spare' dimensions force the broadcasting to be along the correct axes

np_pos.shape
# (3, 10, 3) # ( n_bodies, nmoments, ndims )

total_mass = np_mass.sum()  # Sum the three masses
numerator = (np_mass * np_pos).sum(axis=0)   # sum np_mass * np_pos along the body (0) axis. 
com = numerator / total_mass                 # divide by total_mass

# Could be a oneliner
# com = (np_mass * np_pos).sum(axis=0) / np.mass.sum()

print(com)

# array([[ 0.        ,  1.16666667,  2.33333333],
#        [ 3.5       ,  4.66666667,  5.83333333],
#        [ 7.        ,  8.16666667,  9.33333333],
#        [10.5       , 11.66666667, 12.83333333],
#        [14.        , 15.16666667, 16.33333333],
#        [17.5       , 18.66666667, 19.83333333],
#        [21.        , 22.16666667, 23.33333333],
#        [24.5       , 25.66666667, 26.83333333],
#        [28.        , 29.16666667, 30.33333333],
#        [31.5       , 32.66666667, 33.83333333]])
center_of_mass = (A.mass * A.position + B.mass * B.position + C.mass * C.position) / total_mass