Keras/Tensorflow 中的一组矩阵之间的成对距离

Pairwise distance between a set of Matrices in Keras/Tensorflow

我想计算一组 Tensor 之间的成对距离(例如 4 张量)。每个矩阵都是 2D Tensor。我不知道如何以矢量化格式执行此操作。我写了以下 sudo-code 来确定我需要什么:

E.shape => [4,30,30]

  sum = 0
  for i in range(4):
    for j in range(4):
        res = calculate_distance(E[i],E[j]) # E[i] is one the 30*30 Tensor
        sum = sum + reduce_sum(res)

这是我最后一次尝试:

x_ = tf.expand_dims(E, 0)
y_ = tf.expand_dims(E, 1)
s = x_ - y_
P = tf.reduce_sum(tf.norm(s, axis=[-2, -1]))

此代码有效但我不知道如何批量执行此操作。例如,当 E.shape[BATCH_SIZE * 4 * 30 * 30] 时,我的代码无法正常工作,并且会发生内存不足的情况。我怎样才能有效地做到这一点?

编辑: 折腾了一天,终于找到解决办法了。它并不完美,但有效:

res = tf.map_fn(lambda x: tf.map_fn(lambda y: tf.map_fn(lambda z: tf.norm(z - x), x), x), E)
    res = tf.reduce_mean(tf.square(res))

Tensorflow 允许通过 tf.norm 函数计算 Frobenius 范数。对于二维矩阵,它等价于 1-norm.

以下解决方案未矢量化,并假定 E 中的第一个维度是静态已知的:

E = tf.random_normal(shape=[5, 3, 3], dtype=tf.float32)
F = tf.split(E, E.shape[0])
total = tf.reduce_sum([tf.norm(tensor=(lhs-rhs), ord=1, axis=(-2, -1)) for lhs in F for rhs in F])

更新:

同一代码的优化矢量化版本:

E = tf.random_normal(shape=[1024, 4, 30, 30], dtype=tf.float32)
lhs = tf.expand_dims(E, axis=1)
rhs = tf.expand_dims(E, axis=2)
total = tf.reduce_sum(tf.norm(tensor=(lhs - rhs), ord=1, axis=(-2, -1)))

内存问题:评估此代码时, tf.contrib.memory_stats.MaxBytesInUse() 报告内存消耗峰值为 73729792 = 74Mb,这表明开销相对适中(原始 lhs-rhs 张量为 59Mb)。你的 OOM 很可能是由于你计算 s = x_ - y_BATCH_SIZE 维度重复造成的,因为你的批量大小远远大于矩阵的数量(1024 vs 4)。

如果您的批量不是太大,您使用 expand_dims 的解决方案应该没问题。但是,鉴于您的原始伪代码在 range(4) 上循环,您可能应该扩展轴 1 和 2,而不是 0 和 1。

您可以检查张量的形状以确保您指定了正确的轴。例如,

batch_size = 8
E_np = np.random.rand(batch_size, 4, 30, 30)
E = K.variable(E_np)  # shape=(8, 4, 30, 30)

x_ = K.expand_dims(E, 1)
y_ = K.expand_dims(E, 2)
s = x_ - y_  # shape=(8, 4, 4, 30, 30)

distances = tf.norm(s, axis=[-2, -1])  # shape=(8, 4, 4)
P = K.sum(distances, axis=[-2, -1])  # shape=(8,)

现在 P 将是 8 个样本中每个样本的 4 个矩阵之间的成对距离之和。


您还可以验证 P 中的值是否与伪代码中计算的值相同:

answer = []
for batch_idx in range(batch_size):
    s = 0
    for i in range(4):
        for j in range(4):
            a = E_np[batch_idx, i]
            b = E_np[batch_idx, j]
            s += np.sqrt(np.trace(np.dot(a - b, (a - b).T)))
    answer.append(s)

print(answer)
[149.45960605637578, 147.2815068236368, 144.97487402393705, 146.04866735065312, 144.25537059201062, 148.9300986019226, 146.61229889228133, 149.34259789169045]

print(K.eval(P).tolist())
[149.4595947265625, 147.281494140625, 144.97488403320312, 146.04867553710938, 144.25537109375, 148.9300994873047, 146.6123046875, 149.34259033203125]