未验证 提交 c1c8c7e4 编写于 作者: Z zhongpu 提交者: GitHub

error message enhancement for Conv2D, test=develop (#23561)

上级 f3456071
......@@ -243,6 +243,9 @@ class Conv2D(layers.Layer):
'use_cudnn': self._use_cudnn,
'use_mkldnn': False,
}
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], 'Conv2D')
pre_bias = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
......
......@@ -286,6 +286,27 @@ class TestLayer(LayerTest):
dy_ret = conv2d(base.to_variable(images))
self.assertTrue(conv2d.bias is None)
with self.static_graph():
# the input of Conv2D must be Variable.
def test_Variable():
images = np.ones([2, 3, 5, 5], dtype='float32')
conv2d = nn.Conv2D(
num_channels=3, num_filters=3, filter_size=[2, 2])
conv2d_ret1 = conv2d(images)
self.assertRaises(TypeError, test_Variable)
# the input dtype of Conv2D must be float16 or float32 or float64
# float16 only can be set on GPU place
def test_type():
images = layers.data(
name='pixel', shape=[3, 5, 5], dtype='int32')
conv2d = nn.Conv2D(
num_channels=3, num_filters=3, filter_size=[2, 2])
conv2d_ret2 = conv2d(images)
self.assertRaises(TypeError, test_type)
self.assertTrue(np.allclose(static_ret, dy_ret_value))
self.assertTrue(np.allclose(static_ret, static_ret2))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册