TensorFlow 1.7 + Keras 和数据集:对象没有属性 'ndim'

TensorFlow 1.7 + Keras and datasets: Object has no attribute 'ndim'

调用 keras 时出现以下错误 model.fit()

AttributeError: 'RepeatDataset' object has no attribute 'ndim'

我正在使用 TensorFlow 1.7 和 Keras。不幸的是,我必须使用 TF 1.7。知道发生了什么事吗?代码,改编 来自 tensorflow 演示:

import tensorflow as tf
from IPython import embed
from tensorflow.python import keras
from tensorflow.python.keras import layers

model = tf.keras.Sequential()
model.add(layers.Dense(64, input_shape=(32,), activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

model.compile(
    optimizer=tf.train.AdamOptimizer(0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'])

import numpy as np

# Generate random data using numpy
def random_one_hot_labels(shape):
    n, n_class = shape
    classes = np.random.randint(0, n_class, n)
    labels = np.zeros((n, n_class))
    labels[np.arange(n), classes] = 1
    return labels

data = np.random.random((1000, 32))
labels = random_one_hot_labels((1000, 10))

datasetA = tf.data.Dataset.from_tensor_slices((data, labels))
datasetB = datasetA.batch(32)
dataset = datasetB.repeat()

model.fit(
    dataset, 
    epochs=10,
    steps_per_epoch=30
)

出现此错误是因为 repeat() 正在返回一个生成器,而您正将其传递给 fitfit 需要一个已定义 ndim 的 numpy 数组。稍后添加了对带有 fit 的生成器的支持。尝试使用现已弃用的 fit_generator 代替:

model.fit_generator(
    dataset, 
    epochs=10,
    steps_per_epoch=30
)

另请注意,如果没有任何参数,repeat() 将使用 -1,这可能是也可能不是您正在寻找的行为。 repeat(1)repeat(2) 之类的内容可能就是您要查找的内容。截至 1.7 发布时 RepeatDataset 的来源:

class RepeatDataset(Dataset):
  """A `Dataset` that repeats its input several times."""

  def __init__(self, input_dataset, count):
    """See `Dataset.repeat()` for details."""
    super(RepeatDataset, self).__init__()
    self._input_dataset = input_dataset
    if count is None:
      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
    else:
      self._count = ops.convert_to_tensor(
          count, dtype=dtypes.int64, name="count")

我试过重现它,但要安装正确的版本需要我付出更多的努力。

如果这不起作用,尝试手动遍历数据集生成器并首先从中创建一个 numpy 数组,然后将其传递给 fit 可能是值得的。我不确定 1.7 中是否有 Keras 方法可以做到这一点,但如果你必须走那条路,this answer 可能会有用。