TensorFlow multiple outputs classification error

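In a reinforcement learning algorithm, I need to build and train a neural network with multiple outputs, where each output is a 7-dimensional probability vector. I use categorical cross-entropy as the loss function. However, training the network fails with an error.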

The network is:

import numpy as np
from tensorflow import keras
num_actions = 7  # size of each output probability vector
input_actor = keras.layers.Input(shape=[4], name='states')
# Two shared hidden layers
layer_1 = keras.layers.Dense(30, activation="relu", kernel_initializer=keras.initializers.he_normal())(input_actor)
layer_2 = keras.layers.Dense(30, activation="relu", kernel_initializer=keras.initializers.he_normal())(layer_1)
# Four softmax heads, each producing a 7-dimensional probability vector
out_1 = keras.layers.Dense(num_actions, activation='softmax')(layer_2)
out_2 = keras.layers.Dense(num_actions, activation='softmax')(layer_2)
out_3 = keras.layers.Dense(num_actions, activation='softmax')(layer_2)
out_4 = keras.layers.Dense(num_actions, activation='softmax')(layer_2)
actor = keras.Model(inputs=[input_actor], outputs=[out_1, out_2, out_3, out_4])
# One categorical cross-entropy loss per output head
actor.compile(loss=['categorical_crossentropy'] * 4, optimizer=keras.optimizers.Adam())
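For intuition about what this compiled model minimizes, here is a NumPy-only sketch. It assumes Keras's default behavior when loss_weights is not given: each head's categorical cross-entropy is multiplied by the per-sample weight, averaged over the batch, and the four head losses are summed. The clipping constant and the helper names categorical_crossentropy and total_loss are illustrative assumptions, not the Keras source.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-sample cross-entropy; y_true and y_pred have shape (batch, num_classes)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

def total_loss(targets, predictions, sample_weight):
    """Sum over output heads of the sample-weighted loss, averaged over the batch.

    targets, predictions: lists with one (batch, num_classes) array per head.
    sample_weight: array of shape (batch,).
    """
    total = 0.0
    for y_true, y_pred in zip(targets, predictions):
        per_sample = categorical_crossentropy(y_true, y_pred) * sample_weight
        total += per_sample.sum() / len(per_sample)  # average over the batch
    return total

num_actions = 7
actions = [4, 0, 1, 4]  # one chosen action per output head
# One-hot targets, each with a batch dimension: shape (1, 7)
targets = [np.eye(num_actions)[[a]] for a in actions]
# Uniform predicted distributions, for illustration only: shape (1, 7) per head
predictions = [np.full((1, num_actions), 1.0 / num_actions) for _ in actions]

loss = total_loss(targets, predictions, sample_weight=np.array([-1.0]))
print(loss)  # approximately -7.7836, i.e. -4 * log(7)
```

With uniform predictions every head contributes log 7, so a sample weight of -1 gives a total of -4 log 7. Because the weight is negative, minimizing this quantity pushes the probabilities of the chosen actions down rather than up.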

Training on a single sample:

s = np.arange(1,5).reshape(1,-1)
out = [keras.utils.to_categorical(a, 7).tolist() for a in iter([4,0,1,4])]
actor.train_on_batch(s, out , sample_weight=[-1.])

This results in the following error:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics, return_dict)
   1693                                                     class_weight)
   1694       train_function = self.make_train_function()
-> 1695       logs = train_function(iterator)
   1696 
   1697     if reset_metrics:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    805       # In this case we have created variables on the first call, so we run the
    806       # defunned version which is guaranteed to never create variables.
--> 807       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    808     elif self._stateful_fn is not None:
    809       # Release the lock early so that multiple threads can perform the call

TypeError: 'NoneType' object is not callable
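A quick way to see what is wrong with the targets is to compare their shapes with what each (None, 7) softmax head expects. The sketch below uses np.eye(7)[a] as a stand-in for keras.utils.to_categorical(a, 7), which for a scalar a also returns a 1-D array of shape (7,); the substitution is an assumption made only so the snippet runs without TensorFlow.

```python
import numpy as np

num_actions = 7
actions = [4, 0, 1, 4]

# What the failing code passes: one flat 7-element list per head, with no batch axis.
bad_targets = [np.eye(num_actions)[a].tolist() for a in actions]
print([np.shape(t) for t in bad_targets])  # [(7,), (7,), (7,), (7,)]

# What each softmax head expects: (batch_size, 7); here batch_size is 1.
good_targets = [np.eye(num_actions)[a].reshape(1, -1) for a in actions]
print([t.shape for t in good_targets])  # [(1, 7), (1, 7), (1, 7), (1, 7)]
```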

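The fix is to give every target a batch dimension, i.e. reshape each vector in the output list to (None, 7) to match its softmax head; for a single sample that means (1, 7). For a scalar a, keras.utils.to_categorical(a, 7) returns a 1-D array of shape (7,), so the targets above have no batch axis.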

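s = np.arange(1, 5).reshape(1, -1)  # one state, shape (1, 4)
# One-hot target per head, reshaped to (1, 7) so it has a batch dimension
out = [keras.utils.to_categorical(a, 7).reshape(1, -1) for a in [4, 0, 1, 4]]
# A negative sample weight negates this sample's loss
actor.train_on_batch(s, out, sample_weight=[-1])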
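The same idea should extend to training on a minibatch of N transitions: each head then needs a target array of shape (N, 7). The helper build_targets below is hypothetical, written with NumPy only, and assumes the chosen actions are stored as an (N, 4) integer array with one column per output head.

```python
import numpy as np

def build_targets(actions, num_actions=7):
    """Hypothetical helper: turn an (N, num_heads) integer array of chosen actions
    into a list of one-hot target arrays, one per output head, each of shape
    (N, num_actions)."""
    actions = np.asarray(actions)
    eye = np.eye(num_actions, dtype=np.float32)
    return [eye[actions[:, k]] for k in range(actions.shape[1])]

# Three transitions, each with one action per head (4 heads)
batch_actions = [[4, 0, 1, 4],
                 [2, 6, 3, 0],
                 [1, 1, 5, 2]]
targets = build_targets(batch_actions)
print(len(targets), targets[0].shape)  # 4 (3, 7)
```

Under these assumptions, a states array of shape (3, 4) and a weights array of shape (3,) would then be passed as actor.train_on_batch(states, targets, sample_weight=weights).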