Import ResNeXt into Keras

This question may look difficult, but I need to know how to import a ResNeXt model into Keras with TensorFlow. I tried the following, but it does not work:

from keras.applications.resnext import ResNeXt50

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-1-ca380748170a> in <module>
----> 1 from keras.applications.resnext import ResNeXt50

~/opt/anaconda3/lib/python3.8/site-packages/keras/__init__.py in <module>
  1 from __future__ import absolute_import
  ----> 2 from . import backend
  3 from . import datasets
  4 from . import engine
  5 from . import layers

 ~/opt/anaconda3/lib/python3.8/site-packages/keras/backend/__init__.py in <module>
 65 elif _BACKEND == 'tensorflow':
 66     sys.stderr.write('Using TensorFlow backend.\n')
 ---> 67     from .tensorflow_backend import *
 68 else:
 69     raise ValueError('Unknown backend: ' + str(_BACKEND))

 ~/opt/anaconda3/lib/python3.8/site-packages/keras/backend/tensorflow_backend.py in <module>
 ----> 1 import tensorflow as tf
  2 
  3 from tensorflow.python.training import moving_averages
  4 from tensorflow.python.ops import tensor_array_ops
  5 from tensorflow.python.ops import control_flow_ops

ModuleNotFoundError: No module named 'keras.applications.resnext'

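I have never understood why some commonly used model architectures, such as SE-Net and ResNeXt, are not part of the Keras applications. However, there is a well-known Keras model zoo repository where you can get what you need: Classification models Zoo - Keras (and TensorFlow Keras).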

Installing

!pip install git+https://github.com/qubvel/classification_models.git

Importing

# Pick ONE of the two imports below, matching the Keras you use.

# for standalone Keras
from classification_models.keras import Classifiers

# for TensorFlow Keras (tf.keras)
from classification_models.tfkeras import Classifiers
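
If you are not sure which of the two imports applies to your environment, a check along these lines may help. It is only a sketch: `module_available` and `pick_classifiers_module` are hypothetical helper names, not part of classification_models, and preferring tf.keras when both packages are installed is my assumption.

```python
import importlib.util


def module_available(name):
    """Return True if the module `name` can be located in this environment."""
    try:
        # For dotted names, find_spec imports the parent package first and
        # raises ModuleNotFoundError when that parent is missing.
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


def pick_classifiers_module(is_available=module_available):
    """Return the classification_models submodule that matches the installed Keras.

    Assumption: tf.keras is preferred when both tensorflow and keras are present.
    """
    if not is_available("classification_models"):
        raise ImportError("classification_models is not installed")
    if is_available("tensorflow"):
        return "classification_models.tfkeras"
    if is_available("keras"):
        return "classification_models.keras"
    raise ImportError("neither tensorflow nor keras is installed")
```

With the package installed, `importlib.import_module(pick_classifiers_module()).Classifiers` should then give the same `Classifiers` object as the matching import line.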

Classifiers.models_names()
['resnet18',
 'resnet34',
 'resnet50',
 'resnet101',
 'resnet152',
 'seresnet18',
 'seresnet34',
 'seresnet50',
 'seresnet101',
 'seresnet152',
 'seresnext50',
 'seresnext101',
 'senet154',
 'resnet50v2',
 'resnet101v2',
 'resnet152v2',
 'resnext50',
 'resnext101',
 'vgg16',
 'vgg19',
 'densenet121',
 'densenet169',
 'densenet201',
 'inceptionresnetv2',
 'inceptionv3',
 'xception',
 'nasnetlarge',
 'nasnetmobile',
 'mobilenet',
 'mobilenetv2']
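
Note that every key in this list is lowercase with no separators, unlike the `ResNeXt50` spelling used in the question. A small, hypothetical helper (not part of the library) could map a loosely spelled name onto one of these keys. Here `available` is assumed to be the list returned by `Classifiers.models_names()`.

```python
import difflib


def resolve_model_name(requested, available):
    """Map a loosely spelled model name (e.g. 'ResNeXt50') to a zoo key."""
    # Normalise to the lowercase, separator-free style used by the keys.
    key = requested.strip().lower().replace("-", "").replace("_", "")
    if key in available:
        return key
    # Offer the closest keys instead of failing with no hint.
    suggestions = difflib.get_close_matches(key, available, n=3)
    raise ValueError(f"unknown model {requested!r}; close matches: {suggestions}")
```

For example, `resolve_model_name('SE-ResNeXt50', Classifiers.models_names())` would be expected to yield the `'seresnext50'` key.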

How to use

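# SE-ResNeXt50 backbone (no classification head) with ImageNet weights
SeResNeXt50, preprocess_input = Classifiers.get('seresnext50')
model = SeResNeXt50(include_top=False, input_shape=(224, 224, 3), weights='imagenet')

# ResNeXt50 backbone; note that this reassigns `model` and `preprocess_input`
ResNeXt50, preprocess_input = Classifiers.get('resnext50')
model = ResNeXt50(include_top=False, input_shape=(224, 224, 3), weights='imagenet')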
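
Because `include_top=False` returns only the convolutional backbone, you usually still need to attach your own head. The following is a minimal sketch for the tf.keras flavour. `build_classifier` is a hypothetical helper, and both the default of 10 classes and the pooling-plus-softmax head are assumptions to adapt to your task.

```python
def build_classifier(backbone_name="resnext50", num_classes=10, input_shape=(224, 224, 3)):
    """Hypothetical helper: attach a small classification head to a zoo backbone."""
    # Imported inside the function so that defining it does not require
    # TensorFlow; calling it does.
    import tensorflow as tf
    from classification_models.tfkeras import Classifiers

    backbone_cls, preprocess_input = Classifiers.get(backbone_name)
    backbone = backbone_cls(include_top=False, input_shape=input_shape, weights="imagenet")

    # Global average pooling followed by a softmax layer (assumed head).
    x = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    model = tf.keras.Model(inputs=backbone.input, outputs=outputs)
    return model, preprocess_input
```

Calling `model, preprocess_input = build_classifier(num_classes=5)` downloads the ImageNet weights on first use. Apply the returned `preprocess_input` to your image batches before passing them to the model.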