如何使用 numpy 或 Theano 将矩阵的整数值用作另一个矩阵的索引?

How to use integer values of a matrix as index for another matrix using numpy or Theano?

我有以下 4 个相同形状的矩阵:(1) 包含整数值的矩阵 I,(2) 包含整数值的矩阵 J,(3) 矩阵 D 包含浮点值和 (4) 矩阵 V 包含浮点值。

我想用这4个矩阵按以下方式构造一个"output"矩阵:

  1. 要查找输出矩阵的元素 ij 的值,请查找矩阵 I 的所有等于 [=22= 的单元格(元素) ] 和矩阵 J 中等于 j.
  2. 的所有单元格(元素)
  3. 仅使用满足两个条件的单元格(记住矩阵 IJ 具有相同的形状)。
  4. 在 "selected" 个单元格中搜索具有最小值 D 的单元格。
  5. 取找到的单元格(具有最小 D 值)并检查它在矩阵中的值 V

通过这种方式,我们找到输出矩阵的 ij 元素的值。我为所有 isjs 做这件事。

我想使用 numpy 或 Theano 解决这个问题。

当然我可以遍历所有 i_s 和 j_s 但我认为(希望)应该有更有效的方法。

已添加

应要求,我举个例子:

这里是矩阵 I:

 0   1   2
 1   1   0
 0   0   2

这里是矩阵 J:

 1   1   1
 1   2   1
 0   1   0

这里是矩阵 D:

 1.2   3.4   2.2
 2.2   4.3   2.3
 7.1   6.1   2.7

最后我们得到矩阵 V:

 1.1   8.1   9.1
 3.1   7.1   2.1
 0.1   5.1   3.1

如您所见,所有 4 个矩阵都具有相同的形状 (3 x 4),但它们可以具有其他形状(例如 2 x 5)。最主要的是所有 4 个矩阵的形状都相同。

我们可以看到矩阵 I 的值是从 0 到 2,所以输出矩阵应该有 3 行。同理,我们可以得出输出矩阵应该有3列(因为矩阵J的值也是0到2)。

让我们先找到输出矩阵的元素(0, 1)。在 I 矩阵中,以下单元格(用 x 标记)包含 0.

 x   .   .
 .   .   x
 x   x   .

矩阵J中以下元素包含1:

 x   x   x
 x   .   x
 .   x   .

这两组单元格的交集是:

 x   .   .
 .   .   x
 .   x   .

对应的距离为:

 1.2    .     . 
  .     .    2.3
  .    6.1    . 

所以,最小距离位于左上角。结果我们从矩阵的左上角取值V(这个值为1.1)。

这就是我们找到输出矩阵的 (0,1) 元素值的方式。我们对所有可能的索引组合(总共有 3 x 3 = 9)组合执行相同的过程。对于某些组合,我们找不到任何值,在这种情况下,我们将值设置为 nan.

这是使用 broadcasting -

的矢量化方法
# Get mask of matching elements against the iterators
m,n = I.shape
Imask = I == np.arange(m)[:,None,None,None]
Jmask = J == np.arange(n)[:,None,None]

# Get the mask of intersecting ones
mask = Imask & Jmask

# Get D intersection masked array
Dvals = np.where(mask,D,np.inf)

# Get argmin along merged last two axes. Index into flattened V for final o/p
out = V.ravel()[Dvals.reshape(m,n,-1).argmin(-1)]

样本输入、输出-

In [136]: I = np.array([[0,1,2],[1,1,0],[0,0,2]])
     ...: J = np.array([[1,1,1],[1,2,1],[0,1,0]])
     ...: D = np.array([[1.2, 3.4, 2.2],[2.2, 4.3, 2.3],[7.1, 6.1, 2.7]])
     ...: V = np.array([[1.1 , 8.1, 9.1],[3.1, 7.1, 2.1],[0.1, 5.1, 3.1]])
     ...: 

In [144]: out
Out[144]: 
array([[ 0.1,  1.1,  1.1], # To verify : v[0,1] = 1.1
       [ 1.1,  3.1,  7.1],
       [ 3.1,  9.1,  1.1]])