在 tensorflow.js 中,如何计算模型输入的梯度?
In tensorflow.js, how do I compute the gradient wrt a model input?
我想计算关于 TensorFlow.js 中输入向量的损失梯度。
这是我尝试过的:
function f(img) {
return tf.metrics.categoricalCrossentropy(model.predict(img), lbl);
// (Typo: the order of arguments should be flipped, but it does not affect the question here)
}
var g = tf.grad(f);
g(img).print();
img
是形状为 [1, 784] 的张量。 lbl
是形状为 [1, 10] 的张量。 model
是经过 tf.Sequential
.
训练的普通 MNIST DNN
调用 g(img)
失败,堆栈跟踪:
Uncaught TypeError: Cannot read property 'shape' of undefined
at gradFunc (Concat_grad.js:29)
at Object.s.gradient (engine.js:931)
at a (tape.js:158)
at tape.js:136
at engine.js:1038
at engine.js:433
at e.t.scopedRun (engine.js:444)
at e.t.tidy (engine.js:431)
at e.t.gradients (engine.js:1033)
at gradients.js:69
我错过了什么?
通过删除 f
范围之外的 model.predict
,tf.grad
将起作用。
function f(img) {
return tf.metrics.categoricalCrossentropy(img, lbl);
}
var g = tf.grad(f);
const output = model.predict(img);
g(output).print();
顺序模型似乎有一个错误,该错误已从发布 2.7 +
中修复
我原来的代码片段是正确的;在 TensorFlow.js 的 2.6.0 和 2.5.0 版本中有一个 tf.grad
bug 导致了这个错误。
代码在 2.4.0 或新版本 2.7.0 中按预期工作。
我想计算关于 TensorFlow.js 中输入向量的损失梯度。
这是我尝试过的:
function f(img) {
return tf.metrics.categoricalCrossentropy(model.predict(img), lbl);
// (Typo: the order of arguments should be flipped, but it does not affect the question here)
}
var g = tf.grad(f);
g(img).print();
img
是形状为 [1, 784] 的张量。 lbl
是形状为 [1, 10] 的张量。 model
是经过 tf.Sequential
.
调用 g(img)
失败,堆栈跟踪:
Uncaught TypeError: Cannot read property 'shape' of undefined
at gradFunc (Concat_grad.js:29)
at Object.s.gradient (engine.js:931)
at a (tape.js:158)
at tape.js:136
at engine.js:1038
at engine.js:433
at e.t.scopedRun (engine.js:444)
at e.t.tidy (engine.js:431)
at e.t.gradients (engine.js:1033)
at gradients.js:69
我错过了什么?
通过删除 f
范围之外的 model.predict
,tf.grad
将起作用。
function f(img) {
return tf.metrics.categoricalCrossentropy(img, lbl);
}
var g = tf.grad(f);
const output = model.predict(img);
g(output).print();
顺序模型似乎有一个错误,该错误已从发布 2.7 +
我原来的代码片段是正确的;在 TensorFlow.js 的 2.6.0 和 2.5.0 版本中有一个 tf.grad
bug 导致了这个错误。
代码在 2.4.0 或新版本 2.7.0 中按预期工作。