Tensorflow: Invalid Argument Error in Convolution
I am trying to run this Python code, but I can't seem to resolve the error:
tf.nn.conv2d(tf.reshape(x, [5, 5]), tf.reshape(wt, [3, 3]), strides=[1, 1], padding='SAME')
Here, x is a tf.Variable created from a (5,5) numpy array, and wt is a tf.Variable created from a (3,3) numpy array.
The error I get is:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
C:\Anaconda3\lib\site-packages\tensorflow\python\framework\common_shapes.py in _call_cpp_shape_fn_impl(op, input_tensors_needed, input_tensors_as_shapes_needed, debug_python_shape_fn, require_shape_fn)
669 node_def_str, input_shapes, input_tensors, input_tensors_as_shapes,
--> 670 status)
671 except errors.InvalidArgumentError as err:
C:\Anaconda3\lib\contextlib.py in __exit__(self, type, value, traceback)
65 try:
---> 66 next(self.gen)
67 except StopIteration:
C:\Anaconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py in raise_exception_on_not_ok_status()
468 compat.as_text(pywrap_tensorflow.TF_Message(status)),
--> 469 pywrap_tensorflow.TF_GetCode(status))
470 finally:
InvalidArgumentError: Shape must be rank 4 but is rank 2 for 'Conv2D_19' (op: 'Conv2D') with input shapes: [5,5], [3,3].
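To use tf.nn.conv2d, both your input and your filter must be reshaped to 4-D. In addition, strides should be a 1-D list of length 4 (the stride of the sliding window for each dimension of the input). The following is from the documentation: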
Given an input tensor of shape [batch, in_height, in_width, in_channels] and a filter / kernel tensor of shape [filter_height, filter_width, in_channels, out_channels], this op performs the following:
1. Flattens the filter to a 2-D matrix with shape [filter_height * filter_width * in_channels, output_channels].
2. Extracts image patches from the input tensor to form a virtual tensor of shape [batch, out_height, out_width, filter_height * filter_width * in_channels].
3. For each patch, right-multiplies the filter matrix and the image patch vector.
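To make the quoted steps concrete, here is a minimal pure-Python sketch (no TensorFlow involved) of the flatten / extract-patches / right-multiply procedure, assuming stride 1 and padding='SAME'. The names conv2d_same_im2col and shape4 are hypothetical, and nested lists stand in for tensors; it is meant to show what the 4-D shapes mean, not to reproduce TensorFlow's actual kernel.

```python
# Minimal pure-Python sketch of the three documented steps (not TensorFlow code).
# Assumptions: input is [batch, in_height, in_width, in_channels], the filter is
# [filter_height, filter_width, in_channels, out_channels], stride 1 and
# padding='SAME'. Nested lists stand in for tensors.


def shape4(t):
    """Return the 4 dimension sizes of a rank-4 nested list."""
    return len(t), len(t[0]), len(t[0][0]), len(t[0][0][0])


def conv2d_same_im2col(inp, filt):
    batch, in_h, in_w, in_c = shape4(inp)
    f_h, f_w, f_in_c, out_c = shape4(filt)
    if in_c != f_in_c:
        raise ValueError("input and filter in_channels must match")

    # Step 1: flatten the filter to [f_h * f_w * in_c, out_c] in row-major order.
    flat_filter = [filt[i][j][c]
                   for i in range(f_h) for j in range(f_w) for c in range(in_c)]

    # 'SAME' with stride 1 keeps the spatial size; the window is centred and,
    # for even filter sizes, the extra padding goes at the bottom/right.
    pad_top, pad_left = (f_h - 1) // 2, (f_w - 1) // 2

    out = [[[None] * in_w for _ in range(in_h)] for _ in range(batch)]
    for b in range(batch):
        for y in range(in_h):
            for x in range(in_w):
                # Step 2: extract one patch as a vector of length f_h * f_w * in_c,
                # reading zeros where the window hangs over the border.
                patch = []
                for i in range(f_h):
                    for j in range(f_w):
                        yy, xx = y + i - pad_top, x + j - pad_left
                        inside = 0 <= yy < in_h and 0 <= xx < in_w
                        for c in range(in_c):
                            patch.append(inp[b][yy][xx][c] if inside else 0.0)
                # Step 3: right-multiply the patch vector by the filter matrix.
                out[b][y][x] = [
                    sum(p * flat_filter[k][o] for k, p in enumerate(patch))
                    for o in range(out_c)
                ]
    return out


# Usage mirroring the answer: a 5x5 input as [1, 5, 5, 1], a 3x3 filter as [3, 3, 1, 1].
x2d = [[1.0] * 5 for _ in range(5)]
w2d = [[1.0] * 3 for _ in range(3)]
inp = [[[[v] for v in row] for row in x2d]]
filt = [[[[v]] for v in row] for row in w2d]
result = conv2d_same_im2col(inp, filt)
print(shape4(result))  # (1, 5, 5, 1)
print(result[0][2][2], result[0][0][0], result[0][0][2])  # [9.0] [4.0] [6.0]
```

With an all-ones filter, each output value simply counts how many input cells the 3x3 window covers, which makes the zero padding at the borders visible.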
You can use tf.reshape(x, [1, 5, 5, 1]) for the data, tf.reshape(wt, [3, 3, 1, 1]) for the filter, and strides=[1, 1, 1, 1]. This gives:
tf.nn.conv2d(tf.reshape(x, [1, 5, 5, 1]), tf.reshape(wt, [3, 3, 1, 1]), strides=[1, 1, 1, 1], padding='SAME')
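The rank check that produced the error can be mimicked with a small helper; conv2d_output_shape below is hypothetical and not part of TensorFlow. It rejects the original rank-2 shapes and returns the output shape for the corrected call, using the usual output-size rules for 'SAME' (ceil(in / stride)) and 'VALID' (ceil((in - filter + 1) / stride)). It models the 4-element strides form used in this answer; newer TensorFlow releases are more lenient about strides, so treat it as a sketch rather than a specification.

```python
import math


def conv2d_output_shape(input_shape, filter_shape, strides, padding):
    """Hypothetical helper, not a TensorFlow API: applies the shape rules from
    the answer (rank-4 input and filter, 4-element strides) and returns the
    output shape [batch, out_height, out_width, out_channels]."""
    if len(input_shape) != 4 or len(filter_shape) != 4:
        raise ValueError(
            f"Shape must be rank 4 but got input rank {len(input_shape)} "
            f"and filter rank {len(filter_shape)}")
    # One stride per input dimension; batch and channel strides stay 1.
    if len(strides) != 4 or strides[0] != 1 or strides[3] != 1:
        raise ValueError("strides must look like [1, stride_h, stride_w, 1]")
    batch, in_h, in_w, in_c = input_shape
    f_h, f_w, f_in_c, out_c = filter_shape
    if in_c != f_in_c:
        raise ValueError("input in_channels must equal filter in_channels")
    s_h, s_w = strides[1], strides[2]
    if padding == "SAME":
        # Output size depends only on the input size and the stride.
        out_h, out_w = math.ceil(in_h / s_h), math.ceil(in_w / s_w)
    elif padding == "VALID":
        # Only windows that fit entirely inside the input are kept.
        out_h = math.ceil((in_h - f_h + 1) / s_h)
        out_w = math.ceil((in_w - f_w + 1) / s_w)
    else:
        raise ValueError("padding must be 'SAME' or 'VALID'")
    return [batch, out_h, out_w, out_c]


# The reshape only adds size-1 axes, so the element count is unchanged.
print(math.prod([5, 5]) == math.prod([1, 5, 5, 1]))  # True

try:
    conv2d_output_shape([5, 5], [3, 3], [1, 1], "SAME")  # the original call
except ValueError as err:
    print("error:", err)

print(conv2d_output_shape([1, 5, 5, 1], [3, 3, 1, 1], [1, 1, 1, 1], "SAME"))   # [1, 5, 5, 1]
print(conv2d_output_shape([1, 5, 5, 1], [3, 3, 1, 1], [1, 1, 1, 1], "VALID"))  # [1, 3, 3, 1]
```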