tf.gather_nd 在 Keras 模型中使用时似乎有问题('>' 在 'NoneType' 和 'int' 的实例之间不受支持)

tf.gather_nd seems to be problematic when used in a Keras model ( '>' not supported between instances of 'NoneType' and 'int' )

我正在尝试运行下面这个模型:

import tensorflow as tf
import numpy as np

from tensorflow.keras.models import Model 
from tensorflow.keras.layers import Input 
from tensorflow.keras.layers import Reshape 
from tensorflow.keras.layers import Convolution2D 
from tensorflow.keras.layers import MaxPooling2D 
from tensorflow.keras.layers import Dense 
from tensorflow.keras.layers import Flatten


# Question snippet: builds a small two-input Keras model and fits it on a
# single sample. The logic is the asker's; statements that had been run
# together on one line are split so the snippet is valid Python.
inputs = Input(shape=(80,80,8), name="images" )

conv_1 = Convolution2D(32, (5, 5), padding="same", activation='relu')(inputs)
conv_1 = MaxPooling2D(strides=(2,2))(conv_1)

conv_2 = Convolution2D(64, (5, 5), padding="same", activation='relu')(conv_1)
# NOTE(review): this pools conv_1 again, so the Convolution2D output assigned
# to conv_2 on the previous line is discarded -- (conv_2) was likely intended.
conv_2 = MaxPooling2D(strides=(2,2))(conv_1)
conv_2_flat = Flatten()(conv_2)

dense_1 = Dense(512, activation='relu')(conv_2_flat)
y_pred = Dense(3, name='prediction')(dense_1)
# `act` is computed but not used by the model built below.
act = tf.math.argmax(y_pred, 1)

# Second model input: integer indices that select entries of y_pred.
enum_action = Input(shape=(2), dtype=tf.int32, name="enum_act")
# tf.gather_nd is called directly on the Keras tensors here; this call is
# the subject of the question.
gathered_layer = tf.gather_nd(y_pred, enum_action)

model = Model(inputs=[inputs,enum_action], outputs=gathered_layer)
# NOTE(review): `opt` is not defined anywhere in this snippet; it has to be
# created before this line for the script to get this far.
model.compile(optimizer=opt, loss='mean_squared_error',metrics=['accuracy'])

# One dummy sample: an 80x80x8 "image" plus one [row, column] index pair.
image = np.arange(0, 80*80*8)
enum = np.array([[0,0]])
image = image.reshape(1, 80,80,8)
y_true = np.array([[12]])

model.fit([image,enum], y_true)

但我不断收到此错误消息:

'>' not supported between instances of 'NoneType' and 'int'('>' 在 'NoneType' 和 'int' 的实例之间不受支持)

知道我的模型有什么问题吗?

这样试试:把 tf.gather_nd 放进一个 Lambda 层里调用,并用 tf.expand_dims 给结果补上最后一维,使模型输出的形状与 y_true 一致。

# Answer snippet: same network as in the question, but tf.gather_nd is
# wrapped in a Lambda layer so it runs as part of the Keras graph.

# Lambda is used below but is not among the layers imported at the top of
# the file, so it is imported here to keep the snippet runnable.
from tensorflow.keras.layers import Lambda


def _gather_selected(tensors):
    """Pick the entries of the predictions addressed by the index pairs.

    `tensors` is the two-element list passed to the Lambda layer:
    [y_pred, enum_action]. The gathered values get a trailing axis of
    size 1 appended via tf.expand_dims.
    """
    predictions, indices = tensors
    return tf.expand_dims(tf.gather_nd(predictions, indices), -1)


inputs = Input(shape=(80, 80, 8), name="images")

conv_1 = Convolution2D(32, (5, 5), padding="same", activation='relu')(inputs)
conv_1 = MaxPooling2D(strides=(2, 2))(conv_1)

conv_2 = Convolution2D(64, (5, 5), padding="same", activation='relu')(conv_1)
# Pool conv_2 itself; pooling conv_1 here would overwrite conv_2 and leave
# the second convolution out of the model graph.
conv_2 = MaxPooling2D(strides=(2, 2))(conv_2)
conv_2_flat = Flatten()(conv_2)

dense_1 = Dense(512, activation='relu')(conv_2_flat)
y_pred = Dense(3, name='prediction')(dense_1)

# shape must be a tuple: (2) is just the int 2, (2,) is a one-element tuple.
enum_action = Input(shape=(2,), dtype=tf.int32, name="enum_act")
gathered_layer = Lambda(_gather_selected)([y_pred, enum_action])

model = Model(inputs=[inputs, enum_action], outputs=gathered_layer)
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])
model.summary()

# One dummy sample: an 80x80x8 "image" plus one [row, column] index pair
# into y_pred, and a target with one value per sample.
image = np.arange(0, 80 * 80 * 8).reshape(1, 80, 80, 8)
enum = np.array([[0, 0]])
y_true = np.array([[0]])

model.fit([image, enum], y_true)