tensorflow - tf.confusion_matrix() throws error ValueError: Shape (2, 2048, 2) must have rank 2
tensorflow - tf.confusion_matrix() throws error ValueError: Shape (2, 2048, 2) must have rank 2
我尝试确定我的神经网络模型的 confusion_matrix
使用 google tensorflow 在 python 中编写。
通过使用这段代码:
cm = tf.zeros(shape=[2,2], dtype=tf.int32)
for i in range(0, validation_data.shape[0], batch_size_validation):
batched_val_data = np.array(validation_data[i:i+batch_size_validation, :, :], dtype='float')
batched_val_labels = np.array(validation_labels[i:i+batch_size_validation, :], dtype='float')
batched_val_data = batched_val_data.reshape((-1, n_chunks, chunk_size))
_acc, _c, _p = sess.run([accuracy, correct, pred], feed_dict=({x:batched_val_data, y:batched_val_labels}))
#batched_val_labels.shape ==> (2048, 2)
#_p.shape ==> (2048, 2)
#this piece of code throws the error!
cm = tf.confusion_matrix(labels=batched_val_labels, predictions=_p)
我收到以下错误:
ValueError: 形状 (2, 2048, 2) 必须具有等级 2
至少你应该知道验证标签数组 batched_val_labels 是一个 one hot array。
有人可以帮我吗?提前致谢!
问题是我使用的是 one hot array。
按照此说明操作:Tensorflow confusion matrix using one-hot code
我修改了这段代码:
cm = tf.confusion_matrix(labels=batched_val_labels, predictions=_p)
进入:
cm = tf.confusion_matrix(labels=tf.argmax(batched_val_labels, 1), predictions=tf.argmax(_p, 1))
我尝试确定我的神经网络模型的 confusion_matrix 使用 google tensorflow 在 python 中编写。 通过使用这段代码:
cm = tf.zeros(shape=[2,2], dtype=tf.int32)
for i in range(0, validation_data.shape[0], batch_size_validation):
batched_val_data = np.array(validation_data[i:i+batch_size_validation, :, :], dtype='float')
batched_val_labels = np.array(validation_labels[i:i+batch_size_validation, :], dtype='float')
batched_val_data = batched_val_data.reshape((-1, n_chunks, chunk_size))
_acc, _c, _p = sess.run([accuracy, correct, pred], feed_dict=({x:batched_val_data, y:batched_val_labels}))
#batched_val_labels.shape ==> (2048, 2)
#_p.shape ==> (2048, 2)
#this piece of code throws the error!
cm = tf.confusion_matrix(labels=batched_val_labels, predictions=_p)
我收到以下错误: ValueError: 形状 (2, 2048, 2) 必须具有等级 2
至少你应该知道验证标签数组 batched_val_labels 是一个 one hot array。 有人可以帮我吗?提前致谢!
问题是我使用的是 one hot array。 按照此说明操作:Tensorflow confusion matrix using one-hot code
我修改了这段代码:
cm = tf.confusion_matrix(labels=batched_val_labels, predictions=_p)
进入:
cm = tf.confusion_matrix(labels=tf.argmax(batched_val_labels, 1), predictions=tf.argmax(_p, 1))