tf.keras.backend.clip 没有给出正确的结果
tf.keras.backend.clip not giving correct results
tf.keras.backend.clip
没有剪裁张量
当我在这个函数中使用 tf.keras.backend.clip
时
def grads_ds(model_ds, ds_inputs,y_true,cw):
print(y_true)
with tf.GradientTape() as ds_tape:
y_pred = model_ds(ds_inputs)
print(y_pred.numpy())
logits_1 = -1*y_true*K.log(y_pred)*cw[:,0]
logits_0 = -1*(1-y_true)*K.log(1-y_pred)*cw[:,1]
loss = logits_1 + logits_0
loss_value_ds = K.sum(loss)
ds_grads = ds_tape.gradient(loss_value_ds,model_ds.trainable_variables,unconnected_gradients=tf.UnconnectedGradients.NONE)
for g in ds_grads:
g = tf.keras.backend.clip(g,min_grad,max_grad)
return loss_value_ds, ds_grads
渐变值保持不变(未剪切)。
当我在自定义训练循环中使用 tf.keras.backend.clip
时,方法相同
for g in ds_grads:
g = tf.keras.backend.clip(g,min_grad,max_grad)
没用。应用于变量的梯度未被剪裁。
但是,如果我在循环中打印 g
,那么它会显示裁剪后的值。
无法理解问题出在哪里。
这是因为您示例中的 g
是对列表中值的引用。当您分配给它时,您只是在更改它指向的值(即您没有修改它指向的当前值)。
考虑这个例子,我想将 lst
中的所有值设置为 5。猜猜当你 运行 这个代码示例时会发生什么?
lst = [1,2,3,4]
for ele in lst:
ele = 5
print(lst)
没有!您会得到完全相同的列表。但是在循环中你会看到 ele 现在是 5,正如你已经在你的案例中发现的那样。这是列表中的值不可变(张量不可变)的情况。
但是,您可以就地修改可变对象:
lst = [[2], [2], [2]]
for ele in lst:
ele.append(3)
print(lst)
以上代码将使每个元素 [2, 3]
符合预期。
解决您的问题的一种方法是:
lst = [1,2,3,4]
for itr in range(len(lst)):
lst[itr] = 5
print(lst)
tf.keras.backend.clip
没有剪裁张量
当我在这个函数中使用 tf.keras.backend.clip
时
def grads_ds(model_ds, ds_inputs,y_true,cw):
print(y_true)
with tf.GradientTape() as ds_tape:
y_pred = model_ds(ds_inputs)
print(y_pred.numpy())
logits_1 = -1*y_true*K.log(y_pred)*cw[:,0]
logits_0 = -1*(1-y_true)*K.log(1-y_pred)*cw[:,1]
loss = logits_1 + logits_0
loss_value_ds = K.sum(loss)
ds_grads = ds_tape.gradient(loss_value_ds,model_ds.trainable_variables,unconnected_gradients=tf.UnconnectedGradients.NONE)
for g in ds_grads:
g = tf.keras.backend.clip(g,min_grad,max_grad)
return loss_value_ds, ds_grads
渐变值保持不变(未剪切)。
当我在自定义训练循环中使用 tf.keras.backend.clip
时,方法相同
for g in ds_grads:
g = tf.keras.backend.clip(g,min_grad,max_grad)
没用。应用于变量的梯度未被剪裁。
但是,如果我在循环中打印 g
,那么它会显示裁剪后的值。
无法理解问题出在哪里。
这是因为您示例中的 g
是对列表中值的引用。当您分配给它时,您只是在更改它指向的值(即您没有修改它指向的当前值)。
考虑这个例子,我想将 lst
中的所有值设置为 5。猜猜当你 运行 这个代码示例时会发生什么?
lst = [1,2,3,4]
for ele in lst:
ele = 5
print(lst)
没有!您会得到完全相同的列表。但是在循环中你会看到 ele 现在是 5,正如你已经在你的案例中发现的那样。这是列表中的值不可变(张量不可变)的情况。
但是,您可以就地修改可变对象:
lst = [[2], [2], [2]]
for ele in lst:
ele.append(3)
print(lst)
以上代码将使每个元素 [2, 3]
符合预期。
解决您的问题的一种方法是:
lst = [1,2,3,4]
for itr in range(len(lst)):
lst[itr] = 5
print(lst)