神经网络之后的[np.arange(0, self.batch_size), action]的目的是什么?
What is the purpose of [np.arange(0, self.batch_size), action] after the neural network?
我按照 PyTorch 教程学习了强化学习(TRAIN A MARIO-PLAYING RL AGENT),但我对以下代码感到困惑:
current_Q = self.net(state, model="online")[np.arange(0, self.batch_size), action] # Q_online(s,a)
神经网络后的[np.arange(0, self.batch_size),action]的作用是什么?(我知道TD_estimate是取state和action的,只是一头雾水这在编程方面)这是什么用法(在self.net之后放一个列表)?
教程中引用的更多相关代码:
class MarioNet(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
c, h, w = input_dim
if h != 84:
raise ValueError(f"Expecting input height: 84, got: {h}")
if w != 84:
raise ValueError(f"Expecting input width: 84, got: {w}")
self.online = nn.Sequential(
nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, 512),
nn.ReLU(),
nn.Linear(512, output_dim),
)
self.target = copy.deepcopy(self.online)
# Q_target parameters are frozen.
for p in self.target.parameters():
p.requires_grad = False
def forward(self, input, model):
if model == "online":
return self.online(input)
elif model == "target":
return self.target(input)
self.net:
self.net = MarioNet(self.state_dim, self.action_dim).float()
感谢您的帮助!
本质上,这里发生的是网络的输出被切片以获得 Q 的所需部分 table。
[np.arange(0, self.batch_size), action]
的(有点令人困惑的)索引索引每个轴。因此,对于索引为 1 的轴,我们选择 action
指示的项目。对于索引 0,我们选择 0 到 self.batch_size
.
之间的所有项目
如果self.batch_size
与这个数组的0维长度相同,那么这个切片就可以简化为[:, action]
,这可能是大多数用户比较熟悉的
我按照 PyTorch 教程学习了强化学习(TRAIN A MARIO-PLAYING RL AGENT),但我对以下代码感到困惑:
current_Q = self.net(state, model="online")[np.arange(0, self.batch_size), action] # Q_online(s,a)
神经网络后的[np.arange(0, self.batch_size),action]的作用是什么?(我知道TD_estimate是取state和action的,只是一头雾水这在编程方面)这是什么用法(在self.net之后放一个列表)?
教程中引用的更多相关代码:
class MarioNet(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
c, h, w = input_dim
if h != 84:
raise ValueError(f"Expecting input height: 84, got: {h}")
if w != 84:
raise ValueError(f"Expecting input width: 84, got: {w}")
self.online = nn.Sequential(
nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, 512),
nn.ReLU(),
nn.Linear(512, output_dim),
)
self.target = copy.deepcopy(self.online)
# Q_target parameters are frozen.
for p in self.target.parameters():
p.requires_grad = False
def forward(self, input, model):
if model == "online":
return self.online(input)
elif model == "target":
return self.target(input)
self.net:
self.net = MarioNet(self.state_dim, self.action_dim).float()
感谢您的帮助!
本质上,这里发生的是网络的输出被切片以获得 Q 的所需部分 table。
[np.arange(0, self.batch_size), action]
的(有点令人困惑的)索引索引每个轴。因此,对于索引为 1 的轴,我们选择 action
指示的项目。对于索引 0,我们选择 0 到 self.batch_size
.
如果self.batch_size
与这个数组的0维长度相同,那么这个切片就可以简化为[:, action]
,这可能是大多数用户比较熟悉的