防止 Tensosflow 数据集在多次 model.predict 调用时重置生成器

Prevent Tensosflow Dataset from resetting the generator on multiple model.predict calls

我正在使用 tensorflow 数据集 from_generator 方法在不同批次上使用 CNN 模型进行预测。但我想在每次批量预测后添加一些额外的逻辑。具体来说,我想汇总不同的结果。

这是我的生成器函数:

def gen_predict(img_no):
  img_data = nib.load('./testing-images/10' + '%02d' %img_no + '_3.nii.gz').get_fdata()
  patch_size = 23
  dist_center = (patch_size - 1) // 2
  l, b, h = img_data.shape
  for zc in range(dist_center, h - dist_center - 1):
    for yc in range(dist_center, b - dist_center - 1):
      for xc in range(dist_center, l - dist_center - 1):    
        print(xc,yc,zc) 
        xl, yl, zl = (xc - dist_center, yc - dist_center, zc - dist_center)
        xr, yr, zr = (xc + dist_center, yc + dist_center, zc + dist_center)
        cartesianCoordinate = np.array([xc, yc, zc])
        spectralCoordinates = np.array([0, 0, 0])
        X = (np.array(img_data[xl:(xr + 1), yl:(yr + 1), zl:(zr + 1)]), np.concatenate((cartesianCoordinate, spectralCoordinates)).reshape((6,1)))
        yield (X,)

问题在于,在每次预测调用之后,生成器都会重置,而在下一次预测调用时,它会给出对同一组数据的预测。这是我的代码:

dataset_pred = tf.data.Dataset.from_generator(lambda: gen_predict(3), ((tf.float32, tf.float32),), output_shapes=((tf.TensorShape([23,23,23]), tf.TensorShape([6,1])),))
dataset_pred = dataset_pred.batch(BS)
for i in range(num_batches):
  temp_pred = np.array(model.predict(dataset_pred, batch_size=BS, steps=1))
  ## aggregate the temp_pred result ##

我想用额外的逻辑模仿 model.predict(dataset_pred, batch_size=BS, steps=num_batches) 的行为。另外,由于 num_batches.

太大,我无法存储此调用的结果

编辑: 我已经添加了答案。但非常感谢任何有助于提高效率的帮助。

我找到了答案。基本上,我们可以将相应的生成器存储在一个变量中,然后使用 lambda 使其可调用。这不会重置生成器。

cur_gen = gen_predict(img_no)
dataset_pred = tf.data.Dataset.from_generator(lambda: cur_gen, ((tf.float32, tf.float32),), output_shapes=((tf.TensorShape([23,23,23]), tf.TensorShape([6,1])),))
dataset_pred = dataset_pred.batch(BS)
for i in range(num_batches):
  temp_pred = np.array(model.predict(dataset_pred, batch_size=BS, steps=1))
  ## aggregate the temp_pred result ##