Keras concatenate operation fails on Anaconda but works on Google Colab
I am using the following code to create a Siamese network with triplet loss. Locally I am using Python 3.6 with Anaconda.
import tensorflow as tf
from keras import backend as K
from keras import Model
from keras.datasets import mnist
from keras.layers import Dense, Input, Lambda, concatenate
from keras.optimizers import SGD

def triplet_loss_wrapper(margin=1.0, l2_norm=True):
    # margin and l2_norm are not used in this minimal reproduction
    def triplet_loss(y_true, y_pred):
        return tf.subtract(y_pred[:, :, 0], y_pred[:, :, 1])
    return triplet_loss

input_layer = Input(shape=(784,))
a = Dense(100, activation="relu")(input_layer)
o = Dense(40, activation="relu")(a)
# Add a trailing axis: (batch, 40) -> (batch, 40, 1)
layer1 = Lambda(lambda x: K.expand_dims(x, axis=-1))(o)
layer2 = Lambda(lambda x: K.expand_dims(x, axis=-1))(o)
# Stack along the new axis: output shape is (batch, 40, 2)
concat_layer = concatenate([layer1, layer1], axis=2)
model = Model(input_layer, concat_layer)
model.compile(optimizer=SGD(), loss=triplet_loss_wrapper())

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_test = x_test.reshape(x_test.shape[0], 784)
model.fit(x_test, [1] * len(x_test), batch_size=1)
When I run this on Google Colab, it runs and trains.
But when I run it locally with Anaconda, it fails with the following error:
(np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (32, 1) for Tensor
'concatenate_1_target:0', which has shape '(?, ?, ?)'
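For reference, the check behind this message can be approximated in plain Python. This is a simplified sketch, not TensorFlow's actual code: the helper name `is_feed_compatible` is hypothetical, `None` stands for an unknown dimension (printed as `?`), and the unknown-rank case is ignored. It shows that a rank-2 value such as `(32, 1)` can never be fed to a rank-3 placeholder, whatever the individual dimensions are.

```python
def is_feed_compatible(value_shape, placeholder_shape):
    """Approximate the feed check: ranks must match, and every known
    placeholder dimension must equal the corresponding fed dimension."""
    if len(value_shape) != len(placeholder_shape):
        return False
    return all(p is None or p == v
               for v, p in zip(value_shape, placeholder_shape))


target_placeholder = (None, None, None)  # shape '(?, ?, ?)' from the error

print(is_feed_compatible((32, 1), target_placeholder))      # False
print(is_feed_compatible((32, 40, 2), target_placeholder))  # True
```

The second call uses `(32, 40, 2)` only because that matches the model output shape `(batch, 40, 2)` built by the code above.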
My Anaconda environment has the following packages; I just ran an update of everything:
_license 1.1 py36_1
alabaster 0.7.10 py36_0
anaconda custom py36_0
anaconda-client 1.6.3 py36_0
anaconda-navigator 1.6.4 py36_0
anaconda-project 0.6.0 py36_0
asn1crypto 0.22.0 py36_0
astroid 1.5.3 py36_0
astropy 2.0.1 np113py36_0
babel 2.5.0 py36_0
backports 1.0 py36_0
backports.weakref 1.0rc1 py36_0
beautifulsoup4 4.6.0 py36_0
bitarray 0.8.1 py36_0
bkcharts 0.2 py36_0
blas 1.0 mkl
blaze 0.10.1 py36_0
bleach 1.5.0 py36_0
bokeh 0.12.7 py36_0
boto 2.48.0 py36_0
bottleneck 1.2.1 np113py36_0
cairo 1.14.8 0
certifi 2016.2.28 py36_0
cffi 1.10.0 py36_0
chardet 3.0.4 py36_0
click 6.7 py36_0
cloudpickle 0.4.0 py36_0
clyent 1.2.2 py36_0
colorama 0.3.9 py36_0
contextlib2 0.5.5 py36_0
cryptography 1.8.1 py36_0
cudatoolkit 8.0 3
cudnn 6.0.21 cuda8.0_0
curl 7.52.1 0
cycler 0.10.0 py36_0
cython 0.26 py36_0
cytoolz 0.8.2 py36_0
dask 0.15.2 py36_0
datashape 0.5.4 py36_0
dbus 1.10.20 0
decorator 4.1.2 py36_0
distributed 1.18.1 py36_0
docutils 0.14 py36_0
entrypoints 0.2.3 py36_0
et_xmlfile 1.0.1 py36_0
expat 2.1.0 0
fastcache 1.0.2 py36_1
flask 0.12.2 py36_0
flask-cors 3.0.3 py36_0
fontconfig 2.12.1 3
freetype 2.5.5 2
get_terminal_size 1.0.0 py36_0
gevent 1.2.2 py36_0
glib 2.50.2 1
greenlet 0.4.12 py36_0
gst-plugins-base 1.8.0 0
gstreamer 1.8.0 0
h5py 2.7.0 np113py36_0
harfbuzz 0.9.39 2
hdf5 1.8.17 2
heapdict 1.0.0 py36_1
html5lib 0.9999999 py36_0
icu 54.1 0
idna 2.6 py36_0
imagesize 0.7.1 py36_0
ipykernel 4.6.1 py36_0
ipython 6.1.0 py36_0
ipython_genutils 0.2.0 py36_0
ipywidgets 6.0.0 py36_0
isort 4.2.15 py36_0
itsdangerous 0.24 py36_0
jbig 2.1 0
jdcal 1.3 py36_0
jedi 0.10.2 py36_2
jinja2 2.9.6 py36_0
jpeg 9b 0
jsonschema 2.6.0 py36_0
jupyter 1.0.0 py36_3
jupyter_client 5.1.0 py36_0
jupyter_console 5.2.0 py36_0
jupyter_core 4.3.0 py36_0
keras 2.0.5 py36_0
lazy-object-proxy 1.3.1 py36_0
libffi 3.2.1 1
libgcc 5.2.0 0
libgfortran 3.0.0 1
libgpuarray 0.6.9 0
libiconv 1.14 0
libpng 1.6.30 1
libprotobuf 3.4.0 0
libsodium 1.0.10 0
libtiff 4.0.6 3
libtool 2.4.2 0
libxcb 1.12 1
libxml2 2.9.4 0
libxslt 1.1.29 0
llvmlite 0.20.0 py36_0
locket 0.2.0 py36_1
lxml 3.8.0 py36_0
mako 1.0.6 py36_0
markdown 2.6.9 py36_0
markupsafe 1.0 py36_0
matplotlib 2.0.2 np113py36_0
mistune 0.7.4 py36_0
mkl 2017.0.3 0
mkl-service 1.1.2 py36_3
mpmath 0.19 py36_1
msgpack-python 0.4.8 py36_0
multipledispatch 0.4.9 py36_0
navigator-updater 0.1.0 py36_0
nbconvert 5.2.1 py36_0
nbformat 4.4.0 py36_0
nccl 1.3.4 cuda8.0_1
networkx 1.11 py36_0
nltk 3.2.4 py36_0
nose 1.3.7 py36_1
notebook 5.0.0 py36_0
numba 0.35.0 np113py36_0
numexpr 2.6.2 np113py36_0
numpy 1.13.1 py36_0
numpydoc 0.7.0 py36_0
odo 0.5.1 py36_0
olefile 0.44 py36_0
openpyxl 2.4.8 py36_0
openssl 1.0.2l 0
packaging 16.8 py36_0
pandas 0.20.3 py36_0
pandocfilters 1.4.2 py36_0
pango 1.40.3 1
partd 0.3.8 py36_0
path.py 10.3.1 py36_0
pathlib2 2.3.0 py36_0
patsy 0.4.1 py36_0
pcre 8.39 1
pep8 1.7.0 py36_0
pexpect 4.2.1 py36_0
pickleshare 0.7.4 py36_0
pillow 4.2.1 py36_0
pip 9.0.1 py36_1
pixman 0.34.0 0
ply 3.10 py36_0
prompt_toolkit 1.0.15 py36_0
protobuf 3.4.0 py36_0
psutil 5.2.2 py36_0
ptyprocess 0.5.2 py36_0
py 1.4.33 py36_0
pycodestyle 2.3.1 py36_0
pycosat 0.6.2 py36_0
pycparser 2.18 py36_0
pycrypto 2.6.1 py36_6
pycurl 7.43.0 py36_2
pyflakes 1.6.0 py36_0
pygments 2.2.0 py36_0
pygpu 0.6.9 py36_0
pylint 1.7.2 py36_0
pyodbc 4.0.17 py36_0
pyopenssl 17.0.0 py36_0
pyparsing 2.2.0 py36_0
pyqt 5.6.0 py36_2
pytables 3.4.2 np113py36_0
pytest 3.2.1 py36_0
python 3.6.2 0
python-dateutil 2.6.1 py36_0
pytorch 0.1.12 py36cuda8.0cudnn6.0_1
pytz 2017.2 py36_0
pywavelets 0.5.2 np113py36_0
pyyaml 3.12 py36_0
pyzmq 16.0.2 py36_0
qt 5.6.2 4
qtawesome 0.4.4 py36_0
qtconsole 4.3.1 py36_0
qtpy 1.3.1 py36_0
readline 6.2 2
requests 2.14.2 py36_0
rope 0.9.4 py36_1
ruamel_yaml 0.11.14 py36_1
scikit-image 0.13.0 np113py36_0
scikit-learn 0.19.0 np113py36_0
scipy 0.19.1 np113py36_0
seaborn 0.8 py36_0
setuptools 36.4.0 py36_1
simplegeneric 0.8.1 py36_1
singledispatch 3.4.0.3 py36_0
sip 4.18 py36_0
six 1.10.0 py36_0
snowballstemmer 1.2.1 py36_0
sortedcollections 0.5.3 py36_0
sortedcontainers 1.5.7 py36_0
sphinx 1.6.3 py36_0
sphinxcontrib 1.0 py36_0
sphinxcontrib-websupport 1.0.1 py36_0
spyder 3.2.3 py36_0
sqlalchemy 1.1.13 py36_0
sqlite 3.13.0 0
statsmodels 0.8.0 np113py36_0
sympy 1.1.1 py36_0
tblib 1.3.2 py36_0
tensorflow 1.3.0 0
tensorflow-base 1.3.0 py36h5293eaa_1
tensorflow-tensorboard 0.1.5 py36_0
terminado 0.6 py36_0
testpath 0.3.1 py36_0
theano 0.9.0 py36_0
tk 8.5.18 0
toolz 0.8.2 py36_0
torchvision 0.1.8 py36_0
tornado 4.5.2 py36_0
traitlets 4.3.2 py36_0
unicodecsv 0.14.1 py36_0
unixodbc 2.3.4 0
wcwidth 0.1.7 py36_0
werkzeug 0.12.2 py36_0
wheel 0.29.0 py36_0
widgetsnbextension 3.0.2 py36_0
wrapt 1.10.11 py36_0
xlrd 1.1.0 py36_0
xlsxwriter 0.9.8 py36_0
xlwt 1.3.0 py36_0
xz 5.2.3 0
yaml 0.1.6 0
zeromq 4.1.5 0
zict 0.1.2 py36_0
zlib 1.2.11 0
Any idea why it works on Colab but crashes locally?
Updating Keras did not solve the problem; updating TensorFlow did.
I had TF 1.3 and upgraded to 1.10 using
conda install -c conda-forge tensorflow
That fixed the problem for me. It runs fine now.
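Since the fix came down to a version difference between the two environments, it can help to print the installed versions in both places before comparing behaviour. The following is a minimal sketch; the helper name `installed_version` is hypothetical, and it assumes Python 3.8 or newer, where `importlib.metadata` is part of the standard library. On Python 3.6, as used in the question, printing `tensorflow.__version__` and `keras.__version__` serves the same purpose.

```python
from importlib import metadata  # standard library in Python 3.8+


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Run this in both environments (Colab and the local Anaconda install).
for name in ("tensorflow", "keras"):
    print(name, installed_version(name))
```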