未验证 提交 60d1f98a 编写于 作者: Z zhongpu 提交者: GitHub

error message enhancement for Pool2D, test=develop (#23607)

上级 f9c9d50e
...@@ -838,6 +838,10 @@ class Pool2D(layers.Layer): ...@@ -838,6 +838,10 @@ class Pool2D(layers.Layer):
'use_mkldnn', False, 'exclusive', self._exclusive) 'use_mkldnn', False, 'exclusive', self._exclusive)
return core.ops.pool2d(input, *attrs) return core.ops.pool2d(input, *attrs)
check_variable_and_dtype(
input, 'input', ['int8', 'uint8', 'float16', 'float32', 'float64'],
'Pool2D')
attrs = { attrs = {
"pooling_type": self._pool_type, "pooling_type": self._pool_type,
"ksize": self._pool_size, "ksize": self._pool_size,
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def adaptive_start_index(index, input_size, output_size): def adaptive_start_index(index, input_size, output_size):
...@@ -1275,5 +1276,25 @@ class TestPool2dAPI_Error(unittest.TestCase): ...@@ -1275,5 +1276,25 @@ class TestPool2dAPI_Error(unittest.TestCase):
self.assertRaises(ValueError, run_5) self.assertRaises(ValueError, run_5)
class TestDygraphPool2DAPIError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of Pool2D must be Variable.
data1 = np.random.random((3, 32, 32, 5)).astype('float32')
pool2d = fluid.dygraph.Pool2D(
pool_size=2,
pool_type='max',
pool_stride=1,
global_pooling=False)
self.assertRaises(TypeError, pool2d, data1)
# the input dtype of Pool2D must be uint8 or int8 or float16 or float32 or float64
# uint8 and int8 only can be set on mkldnn
# float16 only can be set on GPU place
data2 = fluid.layers.data(
name='x1', shape=[3, 32, 32, 5], dtype="int32")
self.assertRaises(TypeError, pool2d, data2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册