如何在自定义 Tensorflow 层中支持混合精度?

How to support mixed precision in custom Tensorflow layers?

在为 tf.keras 开发自己的自定义层时:我应该如何支持混合精度?

documentation of mixed precision - 目前在 Tensorflow 2.2 中标记为实验性的功能 - 仅解释了如何从消费者的角度使用预定义层(例如 tf.keras.layers.Dense 层)。

我已经尝试自己猜测并找到了两个 - 可能相关的 - 详细信息:

我是否应该使用上面的 get_layer_policy 方法并将变量转换为层的 call(...) 方法中的 compute_dtype? (并在创建变量时将我的图层 build(...) 方法中的 variable_dtype 传递给 add_weight(...)?)

例如,这里是标准密集神经元层的简单示例实现:

  def call(self, input):
    policy = mixed_precision.get_layer_policy(self)
    bias = tf.cast(self._bias, policy.compute_dtype)
    weights = tf.cast(self._weights, policy.compute_dtype)
    y = tf.nn.bias_add(tf.matmul(input, weights), bias)
    outputs = self._activation(y)
    return outputs

当然,没有人会自己实现这么基本的东西,那只是为了演示。但是,这会是 Tensorflow 团队期望我们实现自定义层的 call(...) 方法的方式吗?

This 来自 nvidia 的指南(幻灯片 13-14)提到了用于混合精度训练的自定义层。

您必须实施方法 cast_input()。在此示例中,当启用混合精度时,图层被转换为 float16:

class CustomBiasLayer(tf.keras.layers.Layer):

 def build(self, _):
   self.v = self.add_weight('v', ())
   self.built = True
  
 def call(self, inputs):
   return inputs + self.v

 def cast_inputs(self, inputs):
   # Casts to float16, the policy's lowest-precision dtype
   return self._mixed_precision_policy.cast_to_lowest(inputs)

我自己还没有尝试过,所以请告诉我这是否适合你。