如何从另一个数组索引到张量 tensorflow

How do I index from another array into a tensor tensorflow

我正在尝试为 AI 中的问题编写深度 q 学习网络。我有一个函数 predict(),它产生一个形状为 (None, 3) 的张量,接受一个形状为 (None, 5) 的输入。 (None, 3)中的3对应的是每个状态下可以采取的每个动作的q值。现在,在训练步骤中,我必须多次调用 predict() 并使用结果计算成本并训练模型。为此,我还有另一个名为 current_actions 的数据数组,它是一个列表,其中包含在先前迭代中针对特定状态所采取的操作的索引。

需要发生的是 current_states_outputs 应该是从 predict() 的输出创建的张量,其中每一行仅包含一个 q 值(而不是 [= 的输出的三个) 11=]),选择哪个q值应该取决于current_actions.

对应的索引

例如current_states_output = [[1,2,3],[4,5,6],[7,8,9]]current_actions=[0,2,1],运算后的结果应该是[1,6,8](已更新)

我该怎么做?

我尝试了以下 -

    current_states_outputs = self.sess.run(self.prediction, feed_dict={self.X:current_states})
    current_states_outputs = np.array([current_states_outputs[a][current_actions[a]] for a in range(len(current_actions))])

我基本 运行 在 predict() 上的会话,并使用正常的 python 方法完成了所需的操作。但是因为这切断了成本与图表前几层的联系,所以无法进行任何训练。因此,我需要在 tensorflow 中执行此操作并将所有内容都保留为 tensorflow 张量本身。我该如何管理?

你可以试试,

tf.squeeze(tf.gather_nd(a,tf.stack([tf.range(b.shape[0])[...,tf.newaxis], b[...,tf.newaxis]], axis=2)))

示例代码:

a = tf.Variable(current_states_outputs)
b = tf.Variable(current_actions)
out = tf.squeeze(tf.gather_nd(a,tf.stack([tf.range(b.shape[0])[...,tf.newaxis], b[...,tf.newaxis]], axis=2)))
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
sess.run(out)

#output
[1, 6, 8]