Why is TensorFlow predicting all 0's or all 1's after training?
So my problem is that I ran through the beginner code from the TensorFlow tutorial and modified it for my needs, but when I do print sess.run(accuracy, feed_dict={x: x_test, y_: y_test}) it used to always print 1.0, and now it always guesses 0 and prints an accuracy of about 93%. When I use tf.argmin(y,1), tf.argmin(y_,1) instead, it guesses all 1's and yields roughly 7% accuracy. The two add up to 100%. I don't understand how tf.argmin can be guessing 1 while tf.argmax guesses 0. Obviously something is wrong with the code. Please take a look and let me know what I can do to fix it. I think the code goes wrong during training, but I could be wrong.
import tensorflow as tf
import numpy as np
from numpy import genfromtxt

data = genfromtxt('cs-training.csv', delimiter=',')    # Training data
test_data = genfromtxt('cs-test.csv', delimiter=',')   # Test data

# Features: every column after the first.
x_train = []
for i in data:
    x_train.append(i[1:])
x_train = np.array(x_train)

# Labels: one-hot encode the first column (0 -> [1., 0.], 1 -> [0., 1.]).
y_train = []
for i in data:
    if i[0] == 0:
        y_train.append([1., i[0]])
    else:
        y_train.append([0., i[0]])
y_train = np.array(y_train)

# Zero-fill NaNs left behind by genfromtxt. Note: this must be np.isnan;
# the bare isnan was never imported.
where_are_NaNs = np.isnan(x_train)
x_train[where_are_NaNs] = 0

x_test = []
for i in test_data:
    x_test.append(i[1:])
x_test = np.array(x_test)

y_test = []
for i in test_data:
    if i[0] == 0:
        y_test.append([1., i[0]])
    else:
        y_test.append([0., i[0]])
y_test = np.array(y_test)

where_are_NaNs = np.isnan(x_test)
x_test[where_are_NaNs] = 0

x = tf.placeholder("float", [None, 10])
W = tf.Variable(tf.zeros([10, 2]))
b = tf.Variable(tf.zeros([2]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder("float", [None, 2])

cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

print "...Training..."

g = 0
for i in range(len(x_train)):
    sess.run(train_step, feed_dict={x: [x_train[g]], y_: [y_train[g]]})
    g += 1
At this point, if I do print [x_train[g]] and print [y_train[g]], this is what the result looks like:
[array([ 7.66126609e-01, 4.50000000e+01, 2.00000000e+00,
8.02982129e-01, 9.12000000e+03, 1.30000000e+01,
0.00000000e+00, 6.00000000e+00, 0.00000000e+00,
2.00000000e+00])]
[array([ 0., 1.])]
OK, let's continue:
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: x_test, y_: y_test})
0.929209
This percentage does not change. No matter which one-hot encoding I create for the 2 classes (1 or 0), it is guessing all zeros.
Here's a look at the data:
print x_train[:10]
[[ 7.66126609e-01 4.50000000e+01 2.00000000e+00 8.02982129e-01
9.12000000e+03 1.30000000e+01 0.00000000e+00 6.00000000e+00
0.00000000e+00 2.00000000e+00]
[ 9.57151019e-01 4.00000000e+01 0.00000000e+00 1.21876201e-01
2.60000000e+03 4.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 1.00000000e+00]
[ 6.58180140e-01 3.80000000e+01 1.00000000e+00 8.51133750e-02
3.04200000e+03 2.00000000e+00 1.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 2.33809776e-01 3.00000000e+01 0.00000000e+00 3.60496820e-02
3.30000000e+03 5.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 9.07239400e-01 4.90000000e+01 1.00000000e+00 2.49256950e-02
6.35880000e+04 7.00000000e+00 0.00000000e+00 1.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 2.13178682e-01 7.40000000e+01 0.00000000e+00 3.75606969e-01
3.50000000e+03 3.00000000e+00 0.00000000e+00 1.00000000e+00
0.00000000e+00 1.00000000e+00]
[ 3.05682465e-01 5.70000000e+01 0.00000000e+00 5.71000000e+03
0.00000000e+00 8.00000000e+00 0.00000000e+00 3.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 7.54463648e-01 3.90000000e+01 0.00000000e+00 2.09940017e-01
3.50000000e+03 8.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 1.16950644e-01 2.70000000e+01 0.00000000e+00 4.60000000e+01
0.00000000e+00 2.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 1.89169052e-01 5.70000000e+01 0.00000000e+00 6.06290901e-01
2.36840000e+04 9.00000000e+00 0.00000000e+00 4.00000000e+00
0.00000000e+00 2.00000000e+00]]
print y_train[:10]
[[ 0. 1.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]]
print x_test[:20]
[[ 4.83539240e-02 4.40000000e+01 0.00000000e+00 3.02297622e-01
7.48500000e+03 1.10000000e+01 0.00000000e+00 1.00000000e+00
0.00000000e+00 2.00000000e+00]
[ 9.10224439e-01 4.20000000e+01 5.00000000e+00 1.72900000e+03
0.00000000e+00 5.00000000e+00 2.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 2.92682927e-01 5.80000000e+01 0.00000000e+00 3.66480079e-01
3.03600000e+03 7.00000000e+00 0.00000000e+00 1.00000000e+00
0.00000000e+00 1.00000000e+00]
[ 3.11547538e-01 3.30000000e+01 1.00000000e+00 3.55431993e-01
4.67500000e+03 1.10000000e+01 0.00000000e+00 1.00000000e+00
0.00000000e+00 1.00000000e+00]
[ 0.00000000e+00 7.20000000e+01 0.00000000e+00 2.16630600e-03
6.00000000e+03 9.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 2.79217052e-01 4.50000000e+01 1.00000000e+00 4.89921122e-01
6.84500000e+03 8.00000000e+00 0.00000000e+00 2.00000000e+00
0.00000000e+00 2.00000000e+00]
[ 0.00000000e+00 7.80000000e+01 0.00000000e+00 0.00000000e+00
0.00000000e+00 1.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 9.10363487e-01 2.80000000e+01 0.00000000e+00 4.99451497e-01
6.38000000e+03 8.00000000e+00 0.00000000e+00 2.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 6.36595797e-01 4.40000000e+01 0.00000000e+00 7.85457163e-01
4.16600000e+03 6.00000000e+00 0.00000000e+00 1.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 1.41549211e-01 2.60000000e+01 0.00000000e+00 2.68407434e-01
4.25000000e+03 4.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 4.14101100e-03 7.80000000e+01 0.00000000e+00 2.26362500e-03
5.74200000e+03 7.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 9.99999900e-01 6.00000000e+01 0.00000000e+00 1.20000000e+02
0.00000000e+00 2.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 6.28525944e-01 4.70000000e+01 0.00000000e+00 1.13100000e+03
0.00000000e+00 5.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 2.00000000e+00]
[ 4.02283095e-01 6.00000000e+01 0.00000000e+00 3.79442065e-01
8.63800000e+03 1.00000000e+01 0.00000000e+00 1.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 5.70997900e-03 8.10000000e+01 0.00000000e+00 2.17382000e-04
2.30000000e+04 4.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 4.71171849e-01 5.10000000e+01 0.00000000e+00 1.53700000e+03
0.00000000e+00 1.40000000e+01 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 1.42395210e-02 8.20000000e+01 0.00000000e+00 7.40466500e-03
2.70000000e+03 1.00000000e+01 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 4.67455800e-02 3.70000000e+01 0.00000000e+00 1.48010090e-02
9.12000000e+03 8.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 4.00000000e+00]
[ 9.99999900e-01 4.70000000e+01 0.00000000e+00 3.54604127e-01
1.10000000e+04 1.10000000e+01 0.00000000e+00 2.00000000e+00
0.00000000e+00 3.00000000e+00]
[ 8.96417860e-02 2.70000000e+01 0.00000000e+00 8.14664000e-03
5.40000000e+03 6.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]]
print y_test[:20]
[[ 1. 0.]
[ 0. 1.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 0. 1.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]]
tl;dr: The way the sample code posted above computes cross-entropy is not numerically robust. Use tf.nn.softmax_cross_entropy_with_logits instead.
(Response to v1 of the question, which has since changed): I worry that your training isn't actually running to completion or working, based on the NaNs in the x_train data you showed. I'd suggest fixing that first: figure out why they are appearing and fix that bug, then check whether your test set also contains NaNs. It may also help to show x_test and y_test.
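As a quick first diagnostic, it can help to count the NaNs before zero-filling them, so you know which columns genfromtxt failed to parse. A minimal sketch, assuming x_train is the array built in the question's code:

import numpy as np

# Count NaNs per feature column and per row before replacing them with 0.
nan_per_column = np.isnan(x_train).sum(axis=0)
rows_with_nan = np.isnan(x_train).any(axis=1).sum()
print "NaNs per column:", nan_per_column
print "Rows containing a NaN:", rows_with_nan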
Finally, I think there's a bug in how y_ is handled relative to x. The code is written as if y_ were a one-hot matrix, but when you show y_train[:10], it has only 10 elements, not 10*num_classes. I suspect a bug there. When you argmax over it on axis 1, you always get a vector of all zeros (because there is only one element on that axis, so of course it's the max element). Combine that with a bug in your estimate that always produces zero output, and you always produce the "correct" answer. :)
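To see why a shape bug like that silently yields a "perfect" score, here is a tiny standalone demonstration (hypothetical data, not taken from the question):

import numpy as np

# Labels that are accidentally shape (N, 1) instead of (N, num_classes):
labels = np.array([[1.], [0.], [1.]])
# argmax over axis 1 can only ever return index 0, so every comparison
# against an all-zero prediction looks "correct".
print np.argmax(labels, axis=1)   # prints [0 0 0]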
Update for the revised question
In the changed version, if you run it and print out W at the end of every step by changing the code to look like this:
_, w_out, b_out = sess.run([train_step, W, b], feed_dict={x: [x_train[g]], y_: [y_train[g]]})
you'll find that W is full of NaNs. To debug this, you can either stare carefully at your code to see if there's a math problem you can spot, or you can instrument the pipeline to see where the NaNs first appear. Let's try that. First, what is cross_entropy? (Add cross_entropy to the list of things in the run statement and print it out.)
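Concretely, the instrumented training step might look something like this (a sketch of the change, not verbatim code from the question):

# Fetch cross_entropy alongside the training op so we can watch it per step.
_, ce_val = sess.run([train_step, cross_entropy],
                     feed_dict={x: [x_train[g]], y_: [y_train[g]]})
print "Cross entropy:", ce_val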
Cross entropy: inf
Great! So why? Well, one answer is that when:
y = [0, 1]
tf.log(y) = [-inf, 0]
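You can verify that edge case directly in a throwaway session (a standalone snippet, independent of the question's code):

import tensorflow as tf

sess = tf.Session()
# log(0) is -inf; multiplying it into the cross-entropy sum poisons the
# loss and, through the gradients, the weights.
print sess.run(tf.log(tf.constant([0., 1.])))   # [-inf   0.]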
An exact zero like that is a valid possible output of y, but your computation of cross-entropy is not robust to it. You can either manually add some epsilons to avoid the edge cases, or use tf.nn.softmax_cross_entropy_with_logits to do it for you. I recommend the latter:
yprime = tf.matmul(x,W)+b
y = tf.nn.softmax(yprime)
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(yprime, y_)
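For completeness, the manual-epsilon alternative mentioned above would look roughly like this (a sketch; the clipping bounds are an arbitrary choice):

# Keep the softmax output strictly inside (0, 1) so tf.log never sees 0.
y_clipped = tf.clip_by_value(y, 1e-10, 1.0)
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clipped))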
I can't guarantee that your model will work, but this should fix the NaN problem you're currently seeing.