From 03073937215aeb3b70c0b6242ad162e2bd1de71f Mon Sep 17 00:00:00 2001 From: songyouwei Date: Fri, 10 Apr 2020 21:33:40 +0800 Subject: [PATCH] API(PRelu) error message enhancement (#23539) * err msg enhance for PRelu * add ut test=develop --- python/paddle/fluid/dygraph/nn.py | 1 + .../fluid/tests/unittests/test_prelu_op.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index bbe9d9d319..6b617f0468 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -2274,6 +2274,7 @@ class PRelu(layers.Layer): default_initializer=Constant(1.0)) def forward(self, input): + check_variable_and_dtype(input, 'input', ['float32'], 'PRelu') out = self._helper.create_variable_for_type_inference(self._dtype) self._helper.append_op( type="prelu", diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index 2676e036a2..a2ee49e594 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -16,12 +16,30 @@ from __future__ import print_function import unittest import numpy as np +import paddle.fluid as fluid import six import paddle.fluid as fluid from paddle.fluid import Program, program_guard from op_test import OpTest, skip_check_grad_ci +class TestPReluAPIError(unittest.TestCase): + def test_errors(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + layer = fluid.PRelu( + mode='all', + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(1.0))) + # the input must be Variable. + x0 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + self.assertRaises(TypeError, layer, x0) + # the input dtype must be float32 + data_t = fluid.data( + name="input", shape=[5, 200, 100, 100], dtype="float64") + self.assertRaises(TypeError, layer, data_t) + + class PReluTest(OpTest): def setUp(self): self.init_input_shape() -- GitLab