通过顺序容器构建的 PyTorch 扁平化层
Flatten layer of PyTorch build by sequential container
我正在尝试通过 PyTorch 的顺序容器构建一个 cnn,我的问题是我不知道如何展平该层。
main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2))
main.add_module('flatten', make_it_flatten)
我应该在 "make_it_flatten" 中输入什么?
我试图压平主要但它不起作用,主要不存在调用视图
main = main.view(-1, 16*3*3)
这可能不是您要找的东西,但您可以简单地创建自己的 nn.Module
来展平任何输入,然后您可以将其添加到 nn.Sequential()
对象:
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size()[0], -1)
x.size()[0]
将 select 批量暗淡,-1
将计算所有剩余的暗淡以适应元素的数量,从而展平任何 tensor/Variable。
并在 nn.Sequential
中使用它:
main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2))
main.add_module('flatten', Flatten())
扁平化层的最快方法不是创建新模块而是通过 main.add_module('flatten', Flatten())
.
将该模块添加到主模块
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
相反,正如我在 .
中展示的那样,您模型 forward
中的简单 out = inp.reshape(inp.size(0), -1)
速度更快
我正在尝试通过 PyTorch 的顺序容器构建一个 cnn,我的问题是我不知道如何展平该层。
main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2))
main.add_module('flatten', make_it_flatten)
我应该在 "make_it_flatten" 中输入什么? 我试图压平主要但它不起作用,主要不存在调用视图
main = main.view(-1, 16*3*3)
这可能不是您要找的东西,但您可以简单地创建自己的 nn.Module
来展平任何输入,然后您可以将其添加到 nn.Sequential()
对象:
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size()[0], -1)
x.size()[0]
将 select 批量暗淡,-1
将计算所有剩余的暗淡以适应元素的数量,从而展平任何 tensor/Variable。
并在 nn.Sequential
中使用它:
main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2))
main.add_module('flatten', Flatten())
扁平化层的最快方法不是创建新模块而是通过 main.add_module('flatten', Flatten())
.
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
相反,正如我在
forward
中的简单 out = inp.reshape(inp.size(0), -1)
速度更快