输出时不兼容的形状 * actions_one_hot
Incompatible shapes when output * actions_one_hot
我正在尝试实现一个播放 Doom (vizdoom)
的深度 Q 网络
但是(从昨天开始)我一直被一个热编码的问题及其后果所困扰:事实上,我有 3 种可能的编码方式
[[True, False, False], [False, True, False], [False, False, True]]
尺寸 = [Batch_size, 3]
当我 one_hot 编码这个动作数组时,我获得了一个这个大小的数组 [BatchSize, 3, 3]
当我想计算我的 Q 值估计时的结果:
Q = tf.reduce_sum(tf.multiply(self.output, self.actions_one_hot), axis=1)
tf.multiply(self.output, self.actions_one_hot)
产生错误:
InvalidArgumentError: Incompatible shapes: [10,3] vs. [10,3,3]
[[Node: DQNetwork/Mul = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](DQNetwork/dense/BiasAdd, DQNetwork/one_hot)]]
我知道这 2 个有不兼容的形状要相乘但我不明白我必须做什么才能使它们兼容。
更清楚this is the notebook with each part explained:
我确定我犯了一个非常愚蠢的错误,但我没有看到它。
感谢您的帮助!
您必须使形状与 tf.multiply
兼容,因为函数是 element-wise 乘法。
但是,我认为您可能对 one_hot
做错了什么。通常,one_hot
函数会将数字转换为单热矩阵。假设您的操作 space 中有 3 个可能的操作,它们是 (0,1,2),one hot 函数会将其转换为 [[1,0,0],[0,1,0],[0,0,1]]
。
问题是您将 one_hot 向量发送到另一个 one_hot 函数。如果直接发送动作,两个张量的形状将相同。
长话短说,您正在使用 one_hot 函数两次。如果你已经有一个 [True, False, False] 类型的向量,你已经有一个 one_hot.
我正在尝试实现一个播放 Doom (vizdoom)
的深度 Q 网络但是(从昨天开始)我一直被一个热编码的问题及其后果所困扰:事实上,我有 3 种可能的编码方式
[[True, False, False], [False, True, False], [False, False, True]]
尺寸 = [Batch_size, 3]
当我 one_hot 编码这个动作数组时,我获得了一个这个大小的数组 [BatchSize, 3, 3]
当我想计算我的 Q 值估计时的结果:
Q = tf.reduce_sum(tf.multiply(self.output, self.actions_one_hot), axis=1)
tf.multiply(self.output, self.actions_one_hot)
产生错误:
InvalidArgumentError: Incompatible shapes: [10,3] vs. [10,3,3]
[[Node: DQNetwork/Mul = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](DQNetwork/dense/BiasAdd, DQNetwork/one_hot)]]
我知道这 2 个有不兼容的形状要相乘但我不明白我必须做什么才能使它们兼容。
更清楚this is the notebook with each part explained:
我确定我犯了一个非常愚蠢的错误,但我没有看到它。
感谢您的帮助!
您必须使形状与 tf.multiply
兼容,因为函数是 element-wise 乘法。
但是,我认为您可能对 one_hot
做错了什么。通常,one_hot
函数会将数字转换为单热矩阵。假设您的操作 space 中有 3 个可能的操作,它们是 (0,1,2),one hot 函数会将其转换为 [[1,0,0],[0,1,0],[0,0,1]]
。
问题是您将 one_hot 向量发送到另一个 one_hot 函数。如果直接发送动作,两个张量的形状将相同。
长话短说,您正在使用 one_hot 函数两次。如果你已经有一个 [True, False, False] 类型的向量,你已经有一个 one_hot.