未验证 提交 03073937 编写于 作者: S songyouwei 提交者: GitHub

API(PRelu) error message enhancement (#23539)

* err msg enhance for PRelu

* add ut
test=develop
上级 cae9340e
......@@ -2274,6 +2274,7 @@ class PRelu(layers.Layer):
default_initializer=Constant(1.0))
def forward(self, input):
check_variable_and_dtype(input, 'input', ['float32'], 'PRelu')
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="prelu",
......
......@@ -16,12 +16,30 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import six
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from op_test import OpTest, skip_check_grad_ci
class TestPReluAPIError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
layer = fluid.PRelu(
mode='all',
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.0)))
# the input must be Variable.
x0 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
self.assertRaises(TypeError, layer, x0)
# the input dtype must be float32
data_t = fluid.data(
name="input", shape=[5, 200, 100, 100], dtype="float64")
self.assertRaises(TypeError, layer, data_t)
class PReluTest(OpTest):
def setUp(self):
self.init_input_shape()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册