TFAGENTS:Valid/Invalid REINFORCE 代理的操作

TFAGENTS: Valid/Invalid actions for REINFORCE agents

使用 TFAGENTS 创建 DQN 代理时,可以指定 屏蔽 valid/invalid 操作的函数。

这是通过指定 observation_and_action_constraint_splitter 函数完成的。

显然无法对 REINFORCE 代理执行相同的操作。

如何在使用 REINFORCE 代理时屏蔽 valid/invalid 操作?

编辑:

似乎有一个开箱即用的方法可以通过实施 MaskSplitterNetwork:

假设过滤函数的形式为:

def filter_fun(observation):
    return observation['observation'], observation['legal_moves']

创建演员网络(如果需要,价值网络)并在 MaskSplitterNetwork 构造函数中将其包装:

masked_actor_network = mask_splitter_network.MaskSplitterNetwork(
    splitter_fn=filter_fun,
    wrapped_network=actor_distribution_network.ActorDistributionNetwork(
        train_env.observation_spec()['observation'],
        train_env.action_spec(),
        fc_layer_params=fc_layer_params
    ),
    passthrough_mask=True
)

并将蒙面演员网络输入强化代理

agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=masked_actor_network,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter,
)