TFAGENTS:Valid/Invalid REINFORCE 代理的操作
TFAGENTS: Valid/Invalid actions for REINFORCE agents
使用 TFAGENTS 创建 DQN 代理时,可以指定
屏蔽 valid/invalid 操作的函数。
这是通过指定 observation_and_action_constraint_splitter 函数完成的。
显然无法对 REINFORCE 代理执行相同的操作。
如何在使用 REINFORCE 代理时屏蔽 valid/invalid 操作?
编辑:
似乎有一个开箱即用的方法可以通过实施 MaskSplitterNetwork
:
假设过滤函数的形式为:
def filter_fun(observation):
return observation['observation'], observation['legal_moves']
创建演员网络(如果需要,价值网络)并在 MaskSplitterNetwork 构造函数中将其包装:
masked_actor_network = mask_splitter_network.MaskSplitterNetwork(
splitter_fn=filter_fun,
wrapped_network=actor_distribution_network.ActorDistributionNetwork(
train_env.observation_spec()['observation'],
train_env.action_spec(),
fc_layer_params=fc_layer_params
),
passthrough_mask=True
)
并将蒙面演员网络输入强化代理
agent = reinforce_agent.ReinforceAgent(
train_env.time_step_spec(),
train_env.action_spec(),
actor_network=masked_actor_network,
optimizer=optimizer,
normalize_returns=True,
train_step_counter=train_step_counter,
)
使用 TFAGENTS 创建 DQN 代理时,可以指定 屏蔽 valid/invalid 操作的函数。
这是通过指定 observation_and_action_constraint_splitter 函数完成的。
显然无法对 REINFORCE 代理执行相同的操作。
如何在使用 REINFORCE 代理时屏蔽 valid/invalid 操作?
编辑:
似乎有一个开箱即用的方法可以通过实施 MaskSplitterNetwork
:
假设过滤函数的形式为:
def filter_fun(observation):
return observation['observation'], observation['legal_moves']
创建演员网络(如果需要,价值网络)并在 MaskSplitterNetwork 构造函数中将其包装:
masked_actor_network = mask_splitter_network.MaskSplitterNetwork(
splitter_fn=filter_fun,
wrapped_network=actor_distribution_network.ActorDistributionNetwork(
train_env.observation_spec()['observation'],
train_env.action_spec(),
fc_layer_params=fc_layer_params
),
passthrough_mask=True
)
并将蒙面演员网络输入强化代理
agent = reinforce_agent.ReinforceAgent(
train_env.time_step_spec(),
train_env.action_spec(),
actor_network=masked_actor_network,
optimizer=optimizer,
normalize_returns=True,
train_step_counter=train_step_counter,
)