Can't load saved policy (TF-agents)
I saved a trained policy with the policy saver as follows:
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
tf_policy_saver.save(policy_dir)
I want to continue training with the saved policy, so I tried to initialize the training with it, which resulted in an error.
agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()
agent.policy=tf.compat.v2.saved_model.load(policy_dir)
Error:
File "C:/Users/Rohit/PycharmProjects/pythonProject/waypoint.py", line 172, in <module>
agent.policy=tf.compat.v2.saved_model.load('waypoints\Two_rewards')
File "C:\Users\Rohit\anaconda3\envs\btp36\lib\site-packages\tensorflow\python\training\tracking\tracking.py", line 92, in __setattr__
super(AutoTrackable, self).__setattr__(name, value)
AttributeError: can't set attribute
I just want to avoid retraining from scratch every time. How can I load the saved policy and continue training with it?
Thanks in advance
For that you should take a look at the Checkpointer. Note that agent.policy is a read-only property (which is why the assignment raises AttributeError: can't set attribute), and the SavedModel produced by PolicySaver is intended for deployment/inference rather than for resuming training.
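To make this concrete, below is a minimal sketch of how a Checkpointer could be wired into the training script from the question. The directory name is just a placeholder, and it assumes the agent and train_step_counter defined above.
from tf_agents.utils import common

checkpoint_dir = 'checkpoints/dqn'  # placeholder path; any writable directory works

train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=3,
    agent=agent,                            # the DqnAgent from the question
    policy=agent.policy,
    train_step_counter=train_step_counter)  # tf.Variable counting train steps

# Restores the tracked variables from the latest checkpoint if one exists,
# otherwise leaves them with their fresh initialization.
train_checkpointer.initialize_or_restore()

# ... training loop calling agent.train(...) ...

# Persist the current state so a later run can pick up where this one stopped.
train_checkpointer.save(global_step=train_step_counter.numpy())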
Yes, as already mentioned, you should use a Checkpointer for this; take a look at the example code below.
agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
                                          policy=policy)
... # Train the agent
# Policy --> X
policy_checkpointer.save(global_step=epoch_counter.numpy())
When you later want to reload the policy, you simply run the same initialization code again.
agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y1, possibly Y1==Y depending on agent class you are using, if it's DQN
# then they are different because of random initialization of network weights
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
                                          policy=policy)
# Policy --> X
Upon creation, the policy_checkpointer automatically recognizes whether any pre-existing checkpoint exists. If it does, it updates the value of the variables it is tracking at creation time.
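If you want to verify that a restore actually happened, the Checkpointer keeps track of whether it found a pre-existing checkpoint; a small sketch, assuming the policy_checkpointer created above:
# checkpoint_exists reports whether a pre-existing checkpoint was found in ckpt_dir.
if policy_checkpointer.checkpoint_exists:
    print('Policy variables restored from checkpoint.')
else:
    print('No checkpoint found; policy keeps its random initialization.')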
A couple of notes:
- With a Checkpointer you can save much more than just the policy, and in fact I recommend doing so. TF-Agents' Checkpointer object is very flexible, for example:
train_checkpointer = common.Checkpointer(ckpt_dir='first/dir',
                                         agent=tf_agent,               # tf_agent.TFAgent
                                         train_step=train_step,        # tf.Variable
                                         epoch_counter=epoch_counter,  # tf.Variable
                                         metrics=metric_utils.MetricsGroup(
                                             train_metrics, 'train_metrics'))
policy_checkpointer = common.Checkpointer(ckpt_dir='second/dir',
                                          policy=agent.policy)
rb_checkpointer = common.Checkpointer(ckpt_dir='third/dir',
                                      max_to_keep=1,
                                      replay_buffer=replay_buffer)     # TFUniformReplayBuffer
- Note that in the case of a DqnAgent, agent.policy and agent.collect_policy are essentially wrappers around a QNetwork. The implication of this is shown in the code below (look at the comments on the state of the policy variable):
agent = DqnAgent(...)
policy = agent.policy  # Random initial policy ---> X
dataset = replay_buffer.as_dataset(...)
for data in dataset:
    experience, _ = data
    loss_agent_info = agent.train(experience=experience)
# policy variable stores a trained Policy object ---> Y
This happens because tensors in TF are shared within your runtime. So when you update the weights of the agent's QNetwork via agent.train, those same weights are implicitly updated in the QNetwork of your policy variable as well. In fact, it is not that the tensors of policy get updated; they simply are the same tensors as the ones in agent, as the short check below illustrates.
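A quick way to convince yourself of this sharing, assuming the q_net used to build the agent is still in scope (a rough sketch; it compares object identity rather than values):
# The policy's variables should be the very same tf.Variable objects
# that live in the agent's QNetwork, not copies of them.
policy_var_ids = {id(v) for v in agent.policy.variables()}
qnet_var_ids = {id(v) for v in q_net.variables}
print(qnet_var_ids.issubset(policy_var_ids))  # expected: True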