From 63d88b522f3b04b31710d1b85eb707e1aef33a5b Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Fri, 11 Oct 2019 10:43:26 +0800 Subject: [PATCH] refine sqrt api check (#20254) test=develop --- .../fluid/layers/layer_function_generator.py | 13 +++++++++++++ .../fluid/tests/unittests/test_activation_op.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/python/paddle/fluid/layers/layer_function_generator.py b/python/paddle/fluid/layers/layer_function_generator.py index 27643143f8..dc3a42f1e9 100755 --- a/python/paddle/fluid/layers/layer_function_generator.py +++ b/python/paddle/fluid/layers/layer_function_generator.py @@ -22,6 +22,7 @@ from six.moves import cStringIO from ..proto import framework_pb2 from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_ from ..layer_helper import LayerHelper +from ..data_feeder import convert_dtype __all__ = [ 'deprecated', 'generate_layer_fn', 'generate_activation_fn', 'autodoc', @@ -250,6 +251,18 @@ def generate_activation_fn(op_type): def func(x, name=None): helper = LayerHelper(op_type, **locals()) + if not isinstance(x, Variable): + raise TypeError( + "The type of 'x' in %s must be Variable, but received %s" % + (op_type, type(x))) + if convert_dtype(x.dtype) in ['float16']: + warnings.warn( + "The data type of 'x' in %s only support float16 in GPU now." % + (op_type)) + if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']: + raise TypeError( + "The data type of 'x' in %s must be float16 (only support on GPU), float32 or float64, but received %s." + % (op_type, convert_dtype(x.dtype))) output = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output}) return output diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index c8e6cbc314..e325d4afb8 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -23,6 +23,22 @@ import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard +class TestSqrtOpError(OpTest): + def test_errors(self): + with program_guard(Program(), Program()): + # The input type of sqrt op must be Variable or numpy.ndarray. + in1 = 1 + self.assertRaises(TypeError, fluid.layers.sqrt, in1) + # The input dtype of sqrt op must be float16, float32, float64. + in2 = fluid.layers.data( + name='input2', shape=[12, 10], dtype="int32") + self.assertRaises(TypeError, fluid.layers.sqrt, in2) + + in3 = fluid.layers.data( + name='input3', shape=[12, 10], dtype="float16") + fluid.layers.sqrt(x=in3) + + class TestActivation(OpTest): def setUp(self): self.op_type = "exp" -- GitLab