AttributeError: module 'tensorflow.keras.mixed_precision' has no attribute 'set_global_policy'


I need to add mixed precision to my code to save some memory. Specifically, I tried adding a mixed precision policy near line 27 of https://github.com/nimRobotics/google-research/blob/master/ravens/train.py. Here is the relevant excerpt:

import argparse
import datetime
import os

import numpy as np
from ravens import agents
from ravens import Dataset
import tensorflow as tf

# tf.keras.mixed_precision.set_global_policy('mixed_float16')

# OR

# policy = tf.keras.mixed_precision.Policy('mixed_float16')
# tf.keras.mixed_precision.set_global_policy(policy)

Both approaches fail with the attribute errors shown below. I am using Google Colab with TF 2.3.0.

Using tf.keras.mixed_precision.set_global_policy('mixed_float16') results in:

Traceback (most recent call last):
  File "train.py", line 28, in <module>
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
AttributeError: module 'tensorflow.keras.mixed_precision' has no attribute 'set_global_policy'

Using

policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

results in:

Traceback (most recent call last):
  File "train.py", line 29, in <module>
    policy = tf.keras.mixed_precision.Policy('mixed_float16')
AttributeError: module 'tensorflow.keras.mixed_precision' has no attribute 'Policy'

Any help or hints would be greatly appreciated!

For TF < 2.4, you should use the experimental mixed precision package:

tf.keras.mixed_precision.experimental.Policy(
    name, loss_scale='auto'
)

For example, in TF 2.3:

policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)

And in TF 2.4:

tf.keras.mixed_precision.set_global_policy('mixed_float16')

Starting with TF 2.4, this feature is no longer experimental.
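
If the same script has to run on both sides of the 2.4 boundary, one option is to check which API is present rather than hard-coding it. The following is only a sketch: enable_mixed_precision is a hypothetical helper name (not part of TensorFlow), and it assumes the module passed in exposes either set_global_policy or the experimental Policy / set_policy pair shown above.

```python
def enable_mixed_precision(tf_module, policy_name='mixed_float16'):
    """Set the global Keras dtype policy using whichever API is available.

    Hypothetical helper, not part of TensorFlow. Returns a string naming
    the API path that was used, which can be handy for logging.
    """
    mp = tf_module.keras.mixed_precision
    if hasattr(mp, 'set_global_policy'):
        # TF >= 2.4: the stable API accepts the policy name directly.
        mp.set_global_policy(policy_name)
        return 'stable'
    # TF < 2.4 (e.g. 2.3): only the experimental package exists.
    policy = mp.experimental.Policy(policy_name)
    mp.experimental.set_policy(policy)
    return 'experimental'
```

In train.py this would be called once, right after import tensorflow as tf and before any model or layer is built, as enable_mixed_precision(tf). Passing the module in as an argument is just a design choice here so the helper itself has no import-time dependency on a particular TF version.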