class CategoricalCrossentropy versus function categorical_crossentropy
In TensorFlow 2, I can use either the class tf.keras.losses.CategoricalCrossentropy (defined here) or the function categorical_crossentropy (defined here) to compute the cross-entropy loss between labels and predictions.
The first version is:
import tensorflow as tf

# Note: the Sparse variant expects integer class labels, not one-hot vectors.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
# ...
@tf.function
def train_step(samples, labels):
    with tf.GradientTape() as tape:
        predictions = model(samples)
        loss = loss_object(labels, predictions)
    #...
The second is more direct:
@tf.function
def forward(features, training=False):
    predictions = model.call(...)
    loss = tf.losses.categorical_crossentropy(
        y_true=features['label'],
        y_pred=predictions)
    return loss, predictions
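The numerical results are the same. However, I wonder whether one of the two is more efficient. Or, more generally, which one should be used in which situation?
Note that the same question probably applies to any class/function pair defined by the API.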
Note that all losses are available both via a class handle and via a function handle. The class handles enable you to pass configuration arguments to the constructor (e.g. loss_fn = CategoricalCrossentropy(from_logits=True)), and they perform reduction by default when used in a standalone way (see details below).
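To make the two differences named in that note concrete, here is a minimal sketch. It assumes eager execution and uses made-up one-hot labels, probabilities and logits (y_true, y_pred and logits are illustrative values, not taken from the question). It suggests that the class handle returns a single reduced scalar by default, while the function handle returns one loss value per sample, so the two agree once the per-sample values are averaged.

```python
import tensorflow as tf

# Hypothetical toy batch: 2 samples, 3 classes, one-hot labels.
y_true = tf.constant([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = tf.constant([[0.05, 0.90, 0.05], [0.10, 0.20, 0.70]])

# Class handle: configured once, reduces over the batch by default.
loss_object = tf.keras.losses.CategoricalCrossentropy()
class_loss = loss_object(y_true, y_pred)  # scalar tensor

# Function handle: no reduction, one loss value per sample.
per_sample = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
function_loss = tf.reduce_mean(per_sample)  # reduce manually to compare

print("class handle:   ", class_loss.numpy())
print("function handle:", per_sample.numpy(), "-> mean", function_loss.numpy())

# Configuration such as from_logits is set in the constructor for the class,
# and passed as a keyword argument on each call for the function.
logits = tf.constant([[1.0, 3.0, 0.5], [0.2, 0.1, 2.0]])
logit_loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
class_logit_loss = logit_loss_object(y_true, logits)
function_logit_loss = tf.reduce_mean(
    tf.keras.losses.categorical_crossentropy(y_true, logits, from_logits=True))
```

In a custom training loop the unreduced function output can be handy when per-sample weighting or a custom reduction is needed, whereas the class handle is convenient when the same configuration is reused across steps or passed to model.compile.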