在CNN中正确使用tf.cond
Proper use of tf.cond in CNN
我对 tf.cond 的使用有疑问。我正在尝试通过 CNN 运行 2 张图像,并且仅使用较低的交叉熵损失值进行反向传播。代码如下:
train_cross_entropy = tf.cond(train_cross_entropy1 < train_cross_entropy2,
lambda: train_cross_entropy1,
lambda: train_cross_entropy2)
用这个train_cross_entropy跟写字一样慢
train_cross_entropy = train_cross_entropy1 + train_cross_entropy2
这对我来说表明它正在通过图表的两个部分进行反向传播,而不仅仅是一个。
我希望它几乎和写作一样快
train_cross_entropy = train_cross_entropy1
如果有人对如何实现这一点有任何想法,我们将不胜感激!谢谢
假设您使用相同的 CNN 处理两张图像,这是有道理的。让我们分别考虑前向(输入 -> 成本)和后向(backprop/gradients)计算。
对于正向计算,两个输入都需要进行条件处理,因为需要比较两个交叉熵值。因此,tf.cond
的情况并不比将两个成本相加快。
对于反向计算,实际上没有区别:在任何一种情况下,误差都需要从输出层一直反向传播到网络的开头。请注意,我们正在计算关于 变量 (网络权重)的梯度;输入被认为是固定的。因此,添加多少输入并不重要:这只是改变了反向传播在输出层开始的标量成本值。实际传播保持不变(只是具有不同的值)。
我只需要像这样在 tf.cond 中移动梯度计算:
def f1():
grads = tf.gradients(train_cross_entropy1, var_list,
stop_gradients=[train_cross_entropy2])
return grads
def f2():
grads = tf.gradients(train_cross_entropy2, var_list,
stop_gradients=[train_cross_entropy1])
return grads
gradients = tf.cond(train_cross_entropy1 < train_cross_entropy2, f1, f2)
然后我可以稍后应用渐变。
我对 tf.cond 的使用有疑问。我正在尝试通过 CNN 运行 2 张图像,并且仅使用较低的交叉熵损失值进行反向传播。代码如下:
train_cross_entropy = tf.cond(train_cross_entropy1 < train_cross_entropy2,
lambda: train_cross_entropy1,
lambda: train_cross_entropy2)
用这个train_cross_entropy跟写字一样慢
train_cross_entropy = train_cross_entropy1 + train_cross_entropy2
这对我来说表明它正在通过图表的两个部分进行反向传播,而不仅仅是一个。
我希望它几乎和写作一样快
train_cross_entropy = train_cross_entropy1
如果有人对如何实现这一点有任何想法,我们将不胜感激!谢谢
假设您使用相同的 CNN 处理两张图像,这是有道理的。让我们分别考虑前向(输入 -> 成本)和后向(backprop/gradients)计算。
对于正向计算,两个输入都需要进行条件处理,因为需要比较两个交叉熵值。因此,tf.cond
的情况并不比将两个成本相加快。
对于反向计算,实际上没有区别:在任何一种情况下,误差都需要从输出层一直反向传播到网络的开头。请注意,我们正在计算关于 变量 (网络权重)的梯度;输入被认为是固定的。因此,添加多少输入并不重要:这只是改变了反向传播在输出层开始的标量成本值。实际传播保持不变(只是具有不同的值)。
我只需要像这样在 tf.cond 中移动梯度计算:
def f1():
grads = tf.gradients(train_cross_entropy1, var_list,
stop_gradients=[train_cross_entropy2])
return grads
def f2():
grads = tf.gradients(train_cross_entropy2, var_list,
stop_gradients=[train_cross_entropy1])
return grads
gradients = tf.cond(train_cross_entropy1 < train_cross_entropy2, f1, f2)
然后我可以稍后应用渐变。