未验证 提交 63d88b52 编写于 作者: Z Zhaolong Xing 提交者: GitHub

refine sqrt api check (#20254)

test=develop
上级 a69baf63
...@@ -22,6 +22,7 @@ from six.moves import cStringIO ...@@ -22,6 +22,7 @@ from six.moves import cStringIO
from ..proto import framework_pb2 from ..proto import framework_pb2
from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_ from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..data_feeder import convert_dtype
__all__ = [ __all__ = [
'deprecated', 'generate_layer_fn', 'generate_activation_fn', 'autodoc', 'deprecated', 'generate_layer_fn', 'generate_activation_fn', 'autodoc',
...@@ -250,6 +251,18 @@ def generate_activation_fn(op_type): ...@@ -250,6 +251,18 @@ def generate_activation_fn(op_type):
def func(x, name=None): def func(x, name=None):
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
if not isinstance(x, Variable):
raise TypeError(
"The type of 'x' in %s must be Variable, but received %s" %
(op_type, type(x)))
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in %s only support float16 in GPU now." %
(op_type))
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'x' in %s must be float16 (only support on GPU), float32 or float64, but received %s."
% (op_type, convert_dtype(x.dtype)))
output = helper.create_variable_for_type_inference(dtype=x.dtype) output = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output}) helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output})
return output return output
......
...@@ -23,6 +23,22 @@ import paddle.fluid as fluid ...@@ -23,6 +23,22 @@ import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard from paddle.fluid import compiler, Program, program_guard
class TestSqrtOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of sqrt op must be Variable or numpy.ndarray.
in1 = 1
self.assertRaises(TypeError, fluid.layers.sqrt, in1)
# The input dtype of sqrt op must be float16, float32, float64.
in2 = fluid.layers.data(
name='input2', shape=[12, 10], dtype="int32")
self.assertRaises(TypeError, fluid.layers.sqrt, in2)
in3 = fluid.layers.data(
name='input3', shape=[12, 10], dtype="float16")
fluid.layers.sqrt(x=in3)
class TestActivation(OpTest): class TestActivation(OpTest):
def setUp(self): def setUp(self):
self.op_type = "exp" self.op_type = "exp"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册