Getting rid of the maxpooling layer causes a CUDA out of memory error in PyTorch
GPU: GTX 1070 Ti, 8 GB; batch size 64; input image size 128*128.
I have the following UNet with resnet152 as the encoder, which works very well:
import torch
import torch.nn.functional as F
import torchvision
from torch import nn


class UNetResNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)  # this is the pool layer I would like to get rid of
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)
        self.dec5 = DecoderBlockV(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                  is_deconv)
        self.dec1 = DecoderBlockV(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        center = self.center(conv5)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        return self.final(F.dropout2d(dec0, p=self.dropout_2d))


# blocks
class DecoderBlockV(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super(DecoderBlockV, self).__init__()
        self.in_channels = in_channels

        if is_deconv:
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                   padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        else:
            self.block = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            )

    def forward(self, x):
        return self.block(x)


class DecoderCenter(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super(DecoderCenter, self).__init__()
        self.in_channels = in_channels

        if is_deconv:
            """
            Parameters for Deconvolution were chosen to avoid artifacts, following
            link https://distill.pub/2016/deconv-checkerboard/
            """
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                   padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        else:
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels)
            )

    def forward(self, x):
        return self.block(x)
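ConvRelu is referenced above but not defined in the snippet. A minimal sketch of what it usually looks like in this family of UNet implementations (my assumption, not code from the question):

class ConvRelu(nn.Module):
    # Assumed helper: 3x3 convolution followed by ReLU.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.conv(x))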
Then I edited my class so that it works without the pooling layer:
class UNetResNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')

        self.relu = nn.ReLU(inplace=True)
        self.input_adjust = nn.Sequential(self.encoder.conv1,
                                          self.encoder.bn1,
                                          self.encoder.relu)
        self.conv1 = self.encoder.layer1
        self.conv2 = self.encoder.layer2
        self.conv3 = self.encoder.layer3
        self.conv4 = self.encoder.layer4

        self.dec4 = DecoderBlockV(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec2 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec1 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, is_deconv)
        self.final = nn.Conv2d(num_filters * 2 * 2, num_classes, kernel_size=1)

    def forward(self, x):
        input_adjust = self.input_adjust(x)
        conv1 = self.conv1(input_adjust)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        center = self.conv4(conv3)

        dec4 = self.dec4(center)  # now without the center block
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = F.dropout2d(self.dec1(torch.cat([dec2, conv1], 1)), p=self.dropout_2d)

        return self.final(dec1)
is_deconv is True in both cases. After the change it no longer runs with batch size 64: it only works with batch size 16, or with batch size 64 but only with resnet16. Otherwise CUDA runs out of memory. What am I doing wrong?
Full stack trace:
~/Desktop/ml/salt/open-solution-salt-identification-master/common_blocks/unet_models.py in forward(self, x)
418 conv1 = self.conv1(input_adjust)
419 conv2 = self.conv2(conv1)
--> 420 conv3 = self.conv3(conv2)
421 center = self.conv4(conv3)
422 dec4 = self.dec4(center)
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
355 result = self._slow_forward(*input, **kwargs)
356 else:
--> 357 result = self.forward(*input, **kwargs)
358 for hook in self._forward_hooks.values():
359 hook_result = hook(self, input, result)
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
65 def forward(self, input):
66 for module in self._modules.values():
---> 67 input = module(input)
68 return input
69
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
355 result = self._slow_forward(*input, **kwargs)
356 else:
--> 357 result = self.forward(*input, **kwargs)
358 for hook in self._forward_hooks.values():
359 hook_result = hook(self, input, result)
~/anaconda3/lib/python3.6/site-packages/torchvision-0.2.0-py3.6.egg/torchvision/models/resnet.py in forward(self, x)
79
80 out = self.conv2(out)
---> 81 out = self.bn2(out)
82 out = self.relu(out)
The problem is that you do not have enough memory, as already mentioned in the comments.
More specifically, the problem is the increase in activation size caused by removing the max pooling, as you have already correctly narrowed down. The point of max pooling, apart from adding some invariance, is to reduce the spatial size of the feature maps so that they consume less memory, since fewer activations have to be stored for the backpropagation pass.
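To make that concrete, here is a small sketch I added (not part of the original question or answer), using the torchvision ResNet-152 stem and a 128*128 input as in the question. It compares the feature map entering layer1 with and without the 2x2 pool:

import torch
import torchvision
from torch import nn

# 128x128 input, standard ResNet-152 stem (conv1 already has stride 2).
encoder = torchvision.models.resnet152(pretrained=False)
stem = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)

x = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    feat = stem(x)                        # (1, 64, 64, 64)
    pooled = nn.MaxPool2d(2, 2)(feat)     # (1, 64, 32, 32)
    with_pool = encoder.layer1(pooled)    # (1, 256, 32, 32)
    without_pool = encoder.layer1(feat)   # (1, 256, 64, 64)

print(without_pool.numel() / with_pool.numel())  # 4.0

Every deeper encoder block inherits that doubled resolution, so roughly all encoder and decoder activations become about four times larger, which is why batch size 64 no longer fits while 16 still does.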
A good explanation of what max pooling contributes might also help; my favorite is the one on Quora.
As you have already correctly figured out, the batch size also plays a huge role in memory consumption. In general it is preferred to use smaller batch sizes anyway, as long as your processing time does not explode afterwards.
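If you want to check how far the batch size can go on your card, one option is to measure the peak GPU memory of a single forward and backward pass for a few batch sizes. This is a sketch I added, assuming the model has already been moved to the GPU and a reasonably recent PyTorch is installed:

import torch

def peak_memory_mb(model, batch_size, image_size=128):
    # Peak GPU memory for one forward + backward pass, in MB.
    torch.cuda.reset_max_memory_allocated()
    x = torch.randn(batch_size, 3, image_size, image_size, device='cuda')
    out = model(x)
    out.sum().backward()   # dummy loss, only to trigger backpropagation
    model.zero_grad()
    return torch.cuda.max_memory_allocated() / 1024 ** 2

# Usage (hypothetical):
# model = UNetResNet(encoder_depth=152, num_classes=1, is_deconv=True).cuda()
# for bs in (8, 16, 32):
#     print(bs, peak_memory_mb(model, bs))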