TensorFlow - 转换为浮点类型的矩阵的矩阵乘法需要很长时间,为什么?

TensorFlow - Matrix multiplication of matrices cast to float type takes very long time , why?

tensorflow中的以下矩阵乘法2.x需要很长时间才能执行

    a = tf.random.uniform(shape=(9180, 3049))
    b = tf.random.uniform(shape=(3049, 1913))
    a = tf.cast(a ,tf.float16)
    b = tf.cast(b ,tf.float16)
    tf.matmul(a,b)

但是如果我简单地使用下面的方法,它很快

    a = tf.random.uniform(shape=(9180, 3049))
    b = tf.random.uniform(shape=(3049, 1913))
    tf.matmul(a,b)

为什么会这样?出于某种目的,我需要将张量转换为浮点数。

实际上,在您的这两种情况下,您都在尝试浮点值的矩阵乘法。在第一种情况下,您使用的是 float16,在第二种情况下,您使用的是 float32。

import tensorflow as tf
import time
a = tf.random.uniform(shape=(9180, 3049), seed = 10)
b = tf.random.uniform(shape=(3049, 1913), seed = 10)

第一运行

x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)

x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)

输出:

184.76319313049316
0.0

第二次 运行 重启我的内核后。

x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)

x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)

输出:

183.03942680358887
1.0335445404052734

现在,如果我再次 运行 相同的代码,即使在更改 a 和 b 的值后也无需再次重新启动内核。

x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)

x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)

输出:

0.0
0.0

所以本质上不是TensorFlow的问题。 Tensorflow 以图的形式执行。当您第一次 运行 它时,它会使用提到的数据结构初始化图形并优化它以进行进一步计算。看看this.

中的最终评论

因此您第二次执行操作会更快