Tensorflow 2:序列化和解码时形状不匹配
Tensorflow 2: shape mismatch when serialize and decode it back
我有一个形状为 (300,256,256) 的张量 A。
我想序列化 A 以保存为 tfrecord 格式。
但我无法将它转换回具有相同形状的张量。
A = tf.convert_to_tensor( *a numpy array with float32 type* )
B = tf.io.serialize_tensor(A)
C = tf.reshape(tf.io.decode_raw(B, out_type=tf.float32),[300,256,256])
如果我运行上面的代码,我得到一个形状错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 19660806 values, but the requested shape has 19660800 [Op:Reshape]
好像我序列化或者解码的时候,加了6个浮点数。 (很奇怪)
尝试使用:tf.io.parse_tensor()
,而不是 tf.io.decode_raw()
。
https://www.tensorflow.org/api_docs/python/tf/io/parse_tensor
我有一个形状为 (300,256,256) 的张量 A。 我想序列化 A 以保存为 tfrecord 格式。 但我无法将它转换回具有相同形状的张量。
A = tf.convert_to_tensor( *a numpy array with float32 type* )
B = tf.io.serialize_tensor(A)
C = tf.reshape(tf.io.decode_raw(B, out_type=tf.float32),[300,256,256])
如果我运行上面的代码,我得到一个形状错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 19660806 values, but the requested shape has 19660800 [Op:Reshape]
好像我序列化或者解码的时候,加了6个浮点数。 (很奇怪)
尝试使用:tf.io.parse_tensor()
,而不是 tf.io.decode_raw()
。
https://www.tensorflow.org/api_docs/python/tf/io/parse_tensor