Numpy 3d array - 规范化行

Numpy 3d array - normalize rows

我一直致力于使用水平 dark/bright 线对图像数据进行标准化。

请在下面找到我的最小工作示例。 假设有一个 3d 数组(三个图像):

import numpy as np
a= np.arange(3)#
a= np.vstack([a,a+3,a+6])
a= np.repeat(a,3,axis=0)#.reshape((3,3,3))
a = a.reshape(3,3,3)

例如,图像 #2 a[:,:,1] 包含

[[1 1 1]
 [4 4 4]
 [7 7 7]]

并显示水平条纹。为了消除条纹,我正在计算行的中位数、总体中位数和差异矩阵。

row_median = np.median(a, axis=1)
bg_median= np.tile(np.median(a, axis=[0,1]),(3,1))
difference_matrix= bg_median-row_median

随后,我遍历图像并将差异矩阵应用于所有图像。

for i in range(len(a)):
    a[:,:,i] = a[:,:,i] + np.tile(difference_matrix[:,i],(3,1)).T

这给出了期望的结果,例如,在 a[:,:,1] 中:

[[4 4 4]
 [4 4 4]
 [4 4 4]]

对于大图像和图像堆栈,此过程非常慢。 我将不胜感激通过使用 broadcasting.

来提高我的代码性能的任何评论和提示

编辑: 跟进

a = a.astype('float64')
diff = np.median(a, axis=[0,1]) - np.median(a, axis=0))
a += diff[None,:]

为我解决了这个问题。

方法 #1

利用 broadcasting 使用 None/np.newaxis 扩展维度而不是平铺那些中间数组,以节省内存并因此实现性能。效率。因此,更改将是 -

diff = np.median(a, axis=[0,1]) - np.median(a, axis=1)
a += diff[:,None]

这处理了引擎盖下的维度扩展。

方法 #2

或者,一种更明确的跟踪 dims 的方法是在执行数据缩减时保留它们,从而避免使用 None 进行最终的暗淡扩展。因此,我们可以将 keepdims 用作 True -

diff = np.median(a,axis=(0,1),keepdims=True) - np.median(a, axis=1,keepdims=True)
a += diff