(tf2.keras) InternalError: Recorded operation 'GradientReversalOperator' returned too few gradients. Expected 3 but received 2
(tf2.keras) InternalError: Recorded operation 'GradientReversalOperator' returned too few gradients. Expected 3 but received 2
我的代码在 github 上可用。
我写了一个自定义梯度层,如下:
@tf.custom_gradient
def GradientReversalOperator(x, lambdal):
    """Return ``x`` unchanged together with a custom gradient function.

    The inner ``grad`` negates the upstream gradient ``dy`` and multiplies it
    by ``lambdal``.
    """
    def grad(dy):
        # NOTE(review): a single value is returned here although the operator
        # takes two arguments (x, lambdal). tf.custom_gradient presumably
        # expects one gradient per input, which looks related to the reported
        # "too few gradients" error -- confirm against the TensorFlow docs.
        return lambdal * tf.negative(dy)
    return x, grad
class GradientReversalLayer(tf.keras.layers.Layer):
    """Keras layer that passes its inputs through GradientReversalOperator."""

    def __init__(self, lambdal):
        """Store ``lambdal`` so that ``call`` can hand it to the operator."""
        super(GradientReversalLayer, self).__init__()
        self.lambdal = lambdal

    def call(self, inputs):
        """Apply GradientReversalOperator to ``inputs`` with the stored ``lambdal``."""
        return GradientReversalOperator(inputs, self.lambdal)
如果我删除 lambdal 参数,一切正常。但是当我把它加回去时,就会出现如下错误:
InternalError: Recorded operation 'GradientReversalOperator' returned too few gradients. Expected 3 but received 2
有些回答建议我再多返回一个占位的(假的)返回值,但这样错误又变成了“梯度太多”。回溯信息如下:
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\contextlib.py", line 130, in exit
self.gen.throw(type, value, traceback)
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2804, in variable_creator_scope
yield
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1695, in train_on_batch
logs = train_function(iterator)
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\def_function.py", line 780, in call
result = self._call(*args, **kwds)
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\def_function.py", line 823, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\def_function.py", line 697, in _initialize
*args, **kwds))
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\function.py", line 2855, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\function.py", line 3075, in _create_graph_function
capture_by_value=self._capture_by_value),
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().wrapped(*args, **kwds)
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.InternalError: in user code:
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py:806 train_function *
return step_function(self, iterator)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py:796 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1211 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2585 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2945 _call_for_each_replica
return fn(*args, **kwargs)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py:789 run_step **
outputs = model.train_step(data)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py:757 train_step
self.trainable_variables)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py:2722 _minimize
gradients = tape.gradient(loss, trainable_variables)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\backprop.py:1073 gradient
unconnected_gradients=unconnected_gradients)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\imperative_grad.py:77 imperative_grad
compat.as_str(unconnected_gradients.value))
InternalError: Recorded operation 'GradientReversalOperator' returned too few gradients. Expected 3 but received 2
我遇到过同样的问题,可以试试下面的写法:
class GradientReversal(Layer):
    """Flip the sign of the gradient during training.

    The forward pass is an identity mapping. On the backward pass the
    incoming gradient is multiplied by ``-hp_lambda``. ``hp_lambda`` is kept
    in a non-trainable variable so it can be changed between training steps
    through the setter/increment helpers below.
    """

    @tf.custom_gradient
    def grad_reverse(self, x):
        """Return ``x`` unchanged and a gradient function that reverses ``dy``."""
        y = tf.identity(x)

        def custom_grad(dy):
            # hp_lambda is read from the layer instead of being passed as an
            # argument of the decorated function, so only the gradient with
            # respect to ``x`` has to be returned.
            return -self.hp_lambda * dy

        return y, custom_grad

    def __init__(self, hp_lambda, **kwargs):
        """Create the layer.

        Args:
            hp_lambda: Initial scale applied to the reversed gradient.
            **kwargs: Forwarded to the base ``Layer`` constructor.
        """
        super(GradientReversal, self).__init__(**kwargs)
        # trainable=False keeps this hyper-parameter out of
        # ``trainable_weights``: it never receives a gradient, so the
        # optimizer should not be asked to update it.
        self.hp_lambda = tf.Variable(
            hp_lambda, trainable=False, dtype=tf.float32, name='hp_lambda')

    def call(self, x, mask=None):
        """Apply the gradient-reversing identity to ``x``; ``mask`` is unused."""
        return self.grad_reverse(x)

    def set_hp_lambda(self, hp_lambda):
        """Overwrite the current value of ``hp_lambda``."""
        K.set_value(self.hp_lambda, hp_lambda)

    def increment_hp_lambda_by(self, increment):
        """Add ``increment`` to the current value of ``hp_lambda``."""
        new_value = float(K.get_value(self.hp_lambda)) + increment
        K.set_value(self.hp_lambda, new_value)

    def get_hp_lambda(self):
        """Return the current value of ``hp_lambda`` as a Python float."""
        return float(K.get_value(self.hp_lambda))
我的代码在 github 上可用。
我写了一个自定义梯度层,如下:
@tf.custom_gradient
def GradientReversalOperator(x, lambdal):
    """Return ``x`` unchanged together with a custom gradient function.

    The inner ``grad`` negates the upstream gradient ``dy`` and multiplies it
    by ``lambdal``.
    """
    def grad(dy):
        # NOTE(review): a single value is returned here although the operator
        # takes two arguments (x, lambdal). tf.custom_gradient presumably
        # expects one gradient per input, which looks related to the reported
        # "too few gradients" error -- confirm against the TensorFlow docs.
        return lambdal * tf.negative(dy)
    return x, grad
class GradientReversalLayer(tf.keras.layers.Layer):
    """Keras layer that passes its inputs through GradientReversalOperator."""

    def __init__(self, lambdal):
        """Store ``lambdal`` so that ``call`` can hand it to the operator."""
        super(GradientReversalLayer, self).__init__()
        self.lambdal = lambdal

    def call(self, inputs):
        """Apply GradientReversalOperator to ``inputs`` with the stored ``lambdal``."""
        return GradientReversalOperator(inputs, self.lambdal)
如果我删除 lambdal 参数,一切正常。但是当我把它加回去时,就会出现如下错误:
InternalError: Recorded operation 'GradientReversalOperator' returned too few gradients. Expected 3 but received 2
有些回答建议我再多返回一个占位的(假的)返回值,但这样错误又变成了“梯度太多”。回溯信息如下:
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\contextlib.py", line 130, in exit self.gen.throw(type, value, traceback) File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2804, in variable_creator_scope yield File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1695, in train_on_batch logs = train_function(iterator) File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\def_function.py", line 780, in call result = self._call(*args, **kwds) File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\def_function.py", line 823, in _call self._initialize(args, kwds, add_initializers_to=initializers) File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\def_function.py", line 697, in _initialize *args, **kwds)) File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\function.py", line 2855, in _get_concrete_function_internal_garbage_collected graph_function, _, _ = self._maybe_define_function(args, kwargs) File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\function.py", line 3213, in _maybe_define_function graph_function = self._create_graph_function(args, kwargs) File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\function.py", line 3075, in _create_graph_function capture_by_value=self._capture_by_value), File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\func_graph.py", line 986, in func_graph_from_py_func func_outputs = python_func(*func_args, **func_kwargs) File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\def_function.py", line 600, in wrapped_fn return weak_wrapped_fn().wrapped(*args, **kwds) File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\func_graph.py", line 973, in wrapper raise 
e.ag_error_metadata.to_exception(e) tensorflow.python.framework.errors_impl.InternalError: in user code:
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py:806 train_function * return step_function(self, iterator) D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py:796 step_function ** outputs = model.distribute_strategy.run(run_step, args=(data,)) D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1211 run return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2585 call_for_each_replica return self._call_for_each_replica(fn, args, kwargs) D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2945 _call_for_each_replica return fn(*args, **kwargs) D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py:789 run_step ** outputs = model.train_step(data) D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py:757 train_step self.trainable_variables) D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\training.py:2722 _minimize gradients = tape.gradient(loss, trainable_variables) D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\backprop.py:1073 gradient unconnected_gradients=unconnected_gradients) D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\eager\imperative_grad.py:77 imperative_grad compat.as_str(unconnected_gradients.value)) InternalError: Recorded operation 'GradientReversalOperator' returned too few gradients. Expected 3 but received 2
我遇到过同样的问题,可以试试下面的写法:
class GradientReversal(Layer):
    """Flip the sign of the gradient during training.

    The forward pass is an identity mapping. On the backward pass the
    incoming gradient is multiplied by ``-hp_lambda``. ``hp_lambda`` is kept
    in a non-trainable variable so it can be changed between training steps
    through the setter/increment helpers below.
    """

    @tf.custom_gradient
    def grad_reverse(self, x):
        """Return ``x`` unchanged and a gradient function that reverses ``dy``."""
        y = tf.identity(x)

        def custom_grad(dy):
            # hp_lambda is read from the layer instead of being passed as an
            # argument of the decorated function, so only the gradient with
            # respect to ``x`` has to be returned.
            return -self.hp_lambda * dy

        return y, custom_grad

    def __init__(self, hp_lambda, **kwargs):
        """Create the layer.

        Args:
            hp_lambda: Initial scale applied to the reversed gradient.
            **kwargs: Forwarded to the base ``Layer`` constructor.
        """
        super(GradientReversal, self).__init__(**kwargs)
        # trainable=False keeps this hyper-parameter out of
        # ``trainable_weights``: it never receives a gradient, so the
        # optimizer should not be asked to update it.
        self.hp_lambda = tf.Variable(
            hp_lambda, trainable=False, dtype=tf.float32, name='hp_lambda')

    def call(self, x, mask=None):
        """Apply the gradient-reversing identity to ``x``; ``mask`` is unused."""
        return self.grad_reverse(x)

    def set_hp_lambda(self, hp_lambda):
        """Overwrite the current value of ``hp_lambda``."""
        K.set_value(self.hp_lambda, hp_lambda)

    def increment_hp_lambda_by(self, increment):
        """Add ``increment`` to the current value of ``hp_lambda``."""
        new_value = float(K.get_value(self.hp_lambda)) + increment
        K.set_value(self.hp_lambda, new_value)

    def get_hp_lambda(self):
        """Return the current value of ``hp_lambda`` as a Python float."""
        return float(K.get_value(self.hp_lambda))