PyTorch CycleGAN gives a Missing key error when testing

I trained a model using the pix2pix PyTorch implementation and want to test it.

But when I test it, I get this error:

model [CycleGANModel] was created
loading the model from ./checkpoints/cycbw50/latest_net_G_A.pth
Traceback (most recent call last):
  File "test.py", line 47, in <module>
    model.setup(opt)               # regular setup: load and print networks; create schedulers
  File "/media/bitlockermount/SmartImageToDigitalTwin/SmartImageToDigitalTwin/bin/python/cyclegann/pytorch-CycleGAN-and-pix2pix/models/base_model.py", line 88, in setup
    self.load_networks(load_suffix)
  File "/media/bitlockermount/SmartImageToDigitalTwin/SmartImageToDigitalTwin/bin/python/cyclegann/pytorch-CycleGAN-and-pix2pix/models/base_model.py", line 199, in load_networks
    net.load_state_dict(state_dict)
  File "/home/bst/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 846, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ResnetGenerator:
    Missing key(s) in state_dict: "model.1.bias", "model.4.bias", "model.7.bias", "model.10.conv_block.1.bias", "model.10.conv_block.5.bias", "model.11.conv_block.1.bias", "model.11.conv_block.5.bias", "model.12.conv_block.1.bias", "model.12.conv_block.5.bias", "model.13.conv_block.1.bias", "model.13.conv_block.5.bias", "model.14.conv_block.1.bias", "model.14.conv_block.5.bias", "model.15.conv_block.1.bias", "model.15.conv_block.5.bias", "model.16.conv_block.1.bias", "model.16.conv_block.5.bias", "model.17.conv_block.1.bias", "model.17.conv_block.5.bias", "model.18.conv_block.1.bias", "model.18.conv_block.5.bias", "model.19.bias", "model.22.bias". 
    Unexpected key(s) in state_dict: "model.2.weight", "model.2.bias", "model.5.weight", "model.5.bias", "model.8.weight", "model.8.bias", "model.10.conv_block.2.weight", "model.10.conv_block.2.bias", "model.10.conv_block.6.weight", "model.10.conv_block.6.bias", "model.11.conv_block.2.weight", "model.11.conv_block.2.bias", "model.11.conv_block.6.weight", "model.11.conv_block.6.bias", "model.12.conv_block.2.weight", "model.12.conv_block.2.bias", "model.12.conv_block.6.weight", "model.12.conv_block.6.bias", "model.13.conv_block.2.weight", "model.13.conv_block.2.bias", "model.13.conv_block.6.weight", "model.13.conv_block.6.bias", "model.14.conv_block.2.weight", "model.14.conv_block.2.bias", "model.14.conv_block.6.weight", "model.14.conv_block.6.bias", "model.15.conv_block.2.weight", "model.15.conv_block.2.bias", "model.15.conv_block.6.weight", "model.15.conv_block.6.bias", "model.16.conv_block.2.weight", "model.16.conv_block.2.bias", "model.16.conv_block.6.weight", "model.16.conv_block.6.bias", "model.17.conv_block.2.weight", "model.17.conv_block.2.bias", "model.17.conv_block.6.weight", "model.17.conv_block.6.bias", "model.18.conv_block.2.weight", "model.18.conv_block.2.bias", "model.18.conv_block.6.weight", "model.18.conv_block.6.bias", "model.20.weight", "model.20.bias", "model.23.weight", "model.23.bias". 

Options file from training:

----------------- Options ---------------
               batch_size: 1                             
                    beta1: 0.5                           
          checkpoints_dir: ./checkpoints                 
           continue_train: False                         
                crop_size: 256                           
                 dataroot: ./datasets/datasets/boundedwalls_50_0.1/ [default: None]
             dataset_mode: aligned                          [default: unaligned]
                direction: AtoB                          
              display_env: main                          
             display_freq: 400                           
               display_id: 1                             
            display_ncols: 4                             
             display_port: 8097                          
           display_server: http://localhost              
          display_winsize: 256                           
                    epoch: latest                        
              epoch_count: 1                             
                 gan_mode: lsgan                         
                  gpu_ids: 0                             
                init_gain: 0.02                          
                init_type: normal                        
                 input_nc: 3                             
                  isTrain: True                             [default: None]
                 lambda_A: 10.0                          
                 lambda_B: 10.0                          
          lambda_identity: 0.5                           
                load_iter: 0                                [default: 0]
                load_size: 286                           
                       lr: 0.0002                        
           lr_decay_iters: 50                            
                lr_policy: linear                        
         max_dataset_size: inf                           
                    model: cycle_gan                     
                 n_epochs: 100                           
           n_epochs_decay: 100                           
               n_layers_D: 3                             
                     name: cycbw50                          [default: experiment_name]
                      ndf: 64                            
                     netD: basic                         
                     netG: resnet_9blocks                
                      ngf: 64                            
               no_dropout: True                          
                  no_flip: False                         
                  no_html: False                         
                     norm: batch                            [default: instance]
              num_threads: 4                             
                output_nc: 3                             
                    phase: train                         
                pool_size: 50                            
               preprocess: resize_and_crop               
               print_freq: 100                           
             save_by_iter: False                         
          save_epoch_freq: 5                             
         save_latest_freq: 5000                          
           serial_batches: False                         
                   suffix:                               
         update_html_freq: 1000                          
                  verbose: False                         
----------------- End -------------------

For testing I used the following settings:

python3 test.py --dataroot ./datasets/datasets/boundedwalls_50_0.1 --name cycbw50 --model pix2pix --netG resnet_9blocks --direction BtoA --dataset_mode aligned --norm batch --load_size 286

----------------- Options ---------------
             aspect_ratio: 1.0                           
               batch_size: 1                             
          checkpoints_dir: ./checkpoints                 
                crop_size: 256                           
                 dataroot: ./datasets/datasets/boundedwalls_50_0.1/ [default: None]
             dataset_mode: unaligned                     
                direction: AtoB                          
          display_winsize: 256                           
                    epoch: latest                        
                     eval: False                         
                  gpu_ids: 0                             
                init_gain: 0.02                          
                init_type: normal                        
                 input_nc: 3                             
                  isTrain: False                            [default: None]
                load_iter: 0                                [default: 0]
                load_size: 256                           
         max_dataset_size: inf                           
                    model: cycle_gan                        [default: test]
               n_layers_D: 3                             
                     name: cycbw50                          [default: experiment_name]
                      ndf: 64                            
                     netD: basic                         
                     netG: resnet_9blocks                
                      ngf: 64                            
               no_dropout: True                          
                  no_flip: False                         
                     norm: instance                      
                 num_test: 50                            
              num_threads: 4                             
                output_nc: 3                             
                    phase: test                          
               preprocess: resize_and_crop               
              results_dir: ./results/                    
           serial_batches: False                         
                   suffix:                               
                  verbose: False                         
----------------- End -------------------

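Does anyone see what I am doing wrong here? I want to run this network so I can get results for single images. So far the test function seems the most promising, but it just crashes on this neural network.

I think the problem here is a layer that was created with bias=None, while the model built at test time expects that bias. You should check the code for details.

After checking your training and testing configurations, I see that norm is different: batch in training, instance in testing. In the code on GitHub, the choice of norm layer determines whether the bias term of the convolutions is set to True or False.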

# Convolutions get a bias only when the norm layer is InstanceNorm2d
if type(norm_layer) == functools.partial:
    use_bias = norm_layer.func == nn.InstanceNorm2d
else:
    use_bias = norm_layer == nn.InstanceNorm2d

model = [nn.ReflectionPad2d(3),
         nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
         norm_layer(ngf),
         nn.ReLU(True)]

You can take a look at the code here.
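
To see why a different norm setting produces exactly this kind of error, here is a minimal sketch that rebuilds only the first block of the generator with each norm layer and loads one state dict into the other. The helper name `make_stem` is hypothetical, and the two `functools.partial` definitions are an assumption about what `--norm batch` and `--norm instance` select; the real generator has many more layers.

```python
import functools

import torch.nn as nn


def make_stem(norm_layer, input_nc=3, ngf=64):
    """Hypothetical helper: first generator block, mirroring the snippet above."""
    # The conv bias is enabled only when the norm layer is InstanceNorm2d.
    if type(norm_layer) == functools.partial:
        use_bias = norm_layer.func == nn.InstanceNorm2d
    else:
        use_bias = norm_layer == nn.InstanceNorm2d
    return nn.Sequential(
        nn.ReflectionPad2d(3),
        nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
        norm_layer(ngf),
        nn.ReLU(True),
    )


# Assumption: these approximate the layers chosen by --norm batch / --norm instance.
batch_norm = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
instance_norm = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)

trained = make_stem(batch_norm)    # stands in for the network saved during training
tested = make_stem(instance_norm)  # stands in for the network built by the test run

# strict=False reports the key mismatch instead of raising RuntimeError.
result = tested.load_state_dict(trained.state_dict(), strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

The instance-norm network expects a conv bias that the batch-norm checkpoint never stored, and the checkpoint carries norm-layer parameters the instance-norm network has no slot for. That matches the pattern of missing and unexpected keys in the traceback. If this is what happened here, building the test network with the same `--norm` value that was used in training should make the keys line up.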