keras模型的奇怪分析结果:越复杂越快

Weird profiling result on keras models: the more complex the faster

所以我目前正在尝试找出哪种深度学习框架最适合处理傅立叶变换。到目前为止,我正在使用 kerastensorflow 后端,但我注意到 fft 有点慢(参见 this issue on Github)。

所以最近想直接和pytorch比较速度。因为我想做的不仅仅是简单的傅立叶变换,我尝试添加一些操作来做一个更全面的基准测试,我注意到对于 keras,添加操作减少了计算时间。

这是最小的工作示例(基本上是在 2D 中进行逆傅里叶变换,通过获取图像的模块完成,并且可能介于 "decomplexification" 和 "recomplexification" 之间):

import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
from keras.layers import Input, Lambda, concatenate
from keras.models import Model
import numpy as np
import tensorflow as tf
from tensorflow.signal import ifft2d
import time

def concatenate_real_imag(x):
    x_real = Lambda(tf.math.real)(x)
    x_imag = Lambda(tf.math.imag)(x)
    return concatenate([x_real, x_imag])

def to_complex(x):
    return tf.complex(x[0], x[1])

def complex_from_half(x, n, output_shape):
    return Lambda(lambda x: to_complex([x[..., :n], x[..., n:]]), output_shape=output_shape)(x)

def weird_model(conc_then_com=False):
    input_size = (320, None, 1)
    kspace_input = Input(input_size, dtype='complex64', name='kspace_input')
    inv_kspace = Lambda(ifft2d, output_shape=input_size)(kspace_input)
    if conc_then_com:
        inv_kspace = concatenate_real_imag(kspace_input)
        inv_kspace = complex_from_half(inv_kspace, 1, input_size)
    abs_inv_kspace = Lambda(tf.math.abs)(inv_kspace)
    model = Model(inputs=kspace_input, outputs=abs_inv_kspace)
    model.compile(
        optimizer='adam',
        loss='mse',
    )
    return model

# fake data
data_x = np.random.rand(35, 320, 320, 1) + 1j * np.random.rand(35, 320, 320, 1)
data_y = np.random.rand(35, 320, 320, 1)

start = time.time()
r = weird_model(conc_then_com=True).predict_on_batch(data_x)
end = time.time()
duration = end - start
print(f'For the prediction with the complex model it took {duration}')

start = time.time()
r = weird_model(conc_then_com=False).predict_on_batch(data_x)
end = time.time()
duration = end - start
print(f'For the prediction with the simple model it took {duration}')

start = time.time()
weird_model(conc_then_com=True).fit(
    x=data_x,
    y=data_y,
    batch_size=35,
    epochs=1,
    verbose=2,
    shuffle=False,
)
end = time.time()
duration = end - start
print(f'For the fitting with the complex model it took {duration}')

start = time.time()
weird_model(conc_then_com=False).fit(
    x=data_x,
    y=data_y,
    batch_size=35,
    epochs=1,
    verbose=2,
    shuffle=False,
)
end = time.time()
duration = end - start
print(f'For the fitting with the simple model it took {duration}')

给出以下时间(或多或少):

For the prediction with the complex model it took 0.24
For the prediction with the simple model it took 3.98
For the fitting with the complex model it took 0.28
For the fitting with the simple model it took 4.01

我不知道发生了什么。

实际上,这只是一个错字: inv_kspace = concatenate_real_imag(kspace_input) 应该是 inv_kspace = concatenate_real_imag(inv_kspace)