From ac4da77aa65ceac23cd1483772040a8f12a19b44 Mon Sep 17 00:00:00 2001 From: Steffy-zxf <48793257+Steffy-zxf@users.noreply.github.com> Date: Sun, 12 Apr 2020 19:32:54 +0800 Subject: [PATCH] =?UTF-8?q?update=20error=20info=20of=20ops=EF=BC=8Cadd=20?= =?UTF-8?q?some=20test=20cases=20for=20raise=20message=20(#23750)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. update error info of the ops (abs, acos, asin, atan, ceil, cos, exp, floor, log, pow, reciprocal, round, rsqrt, sin, sqrt, square, tanh) 2. add the unittests of the above refered ops (test error info) --- paddle/fluid/operators/activation_op.h | 156 +++++++++++------- paddle/fluid/operators/sign_op.cc | 7 +- .../fluid/layers/layer_function_generator.py | 11 +- python/paddle/fluid/layers/nn.py | 5 + .../tests/unittests/test_activation_op.py | 65 ++++++++ .../fluid/tests/unittests/test_sign_op.py | 9 +- 6 files changed, 181 insertions(+), 72 deletions(-) diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 30bb989bbf..cb210fca31 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -27,6 +27,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" #ifdef PADDLE_WITH_MKLDNN @@ -53,12 +54,14 @@ inline void ExtractActivationTensor(const framework::ExecutionContext& context, framework::Tensor** Out) { auto x_var = context.InputVar("X"); auto out_var = context.OutputVar("Out"); - PADDLE_ENFORCE(x_var != nullptr, - "Cannot get input Variable X, variable name = %s", - context.InputName("X")); - PADDLE_ENFORCE(out_var != nullptr, - "Cannot get output Variable Out, variable name = %s", - context.OutputName("Out")); + PADDLE_ENFORCE_NOT_NULL(x_var, + platform::errors::NotFound( + "Cannot get input Variable X, variable name = %s", + context.InputName("X"))); + PADDLE_ENFORCE_NOT_NULL( + out_var, platform::errors::NotFound( + "Cannot get output Variable Out, variable name = %s", + context.OutputName("Out"))); if (CanBeUsedBySelectedRows.count(context.Type())) { *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); *Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( @@ -68,9 +71,10 @@ inline void ExtractActivationTensor(const framework::ExecutionContext& context, *Out = context.Output("Out"); } - PADDLE_ENFORCE(*Out != nullptr, - "Cannot get output tensor Out, variable name = %s", - context.OutputName("Out")); + PADDLE_ENFORCE_NOT_NULL(*Out, platform::errors::NotFound( + "Cannot get the tensor from the Variable " + "Output(Out), variable name = %s", + context.OutputName("Out"))); } template @@ -84,18 +88,22 @@ inline void ExtractActivationGradTensor( if (static_cast(kDepValue) & static_cast(kDepOut)) { out_var = context.InputVar("Out"); - PADDLE_ENFORCE(out_var != nullptr, - "Cannot get input Variable Out, variable name = %s", - context.InputName("Out")); - } - PADDLE_ENFORCE(out_grad_var != nullptr, - "Cannot get input Variable %s, variable name = %s", - framework::GradVarName("Out"), - context.InputName(framework::GradVarName("Out"))); - PADDLE_ENFORCE(x_grad_var != nullptr, - "Cannot get output Variable %s, variable name = %s", - framework::GradVarName("X"), - context.OutputName(framework::GradVarName("X"))); + PADDLE_ENFORCE_NOT_NULL( + out_var, platform::errors::NotFound( + "Cannot get input Variable Out, variable name = %s", + context.InputName("Out"))); + } + + PADDLE_ENFORCE_NOT_NULL( + out_grad_var, platform::errors::NotFound( + "Cannot get input Variable %s, variable name = %s", + framework::GradVarName("Out"), + context.InputName(framework::GradVarName("Out")))); + PADDLE_ENFORCE_NOT_NULL( + x_grad_var, platform::errors::NotFound( + "Cannot get output Variable %s, variable name = %s", + framework::GradVarName("X"), + context.OutputName(framework::GradVarName("X")))); if (CanBeUsedBySelectedRows.count(context.Type())) { *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar( @@ -122,16 +130,18 @@ inline void ExtractActivationGradTensor( } } - PADDLE_ENFORCE(*dX != nullptr, - "Cannot get output tensor %s, variable name = %s", - framework::GradVarName("X"), - context.OutputName(framework::GradVarName("X"))); + PADDLE_ENFORCE_NOT_NULL(*dX, + platform::errors::NotFound( + "Cannot get the tensor from the Variable " + "Output(Out), variable name = %s", + context.OutputName(framework::GradVarName("X")))); if (static_cast(kDepValue) & static_cast(kDepX)) { auto x_var = context.InputVar("X"); - PADDLE_ENFORCE(x_var != nullptr, - "Cannot get input tensor X, variable name = %s", - context.InputName("X")); + PADDLE_ENFORCE_NOT_NULL(x_var, platform::errors::NotFound( + "Cannot get the tensor from the " + "Variable Input(X), variable name = %s", + context.InputName("X"))); if (CanBeUsedBySelectedRows.count(context.Type())) { *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); } else { @@ -1186,9 +1196,10 @@ inline void ExtractActivationDoubleGradTensor( framework::Tensor** ddOut) { auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); - PADDLE_ENFORCE(ddx_var != nullptr, - "Cannot get input Variable Out, variable name = %s", - ctx.InputName("DDX")); + PADDLE_ENFORCE_NOT_NULL( + ddx_var, platform::errors::NotFound( + "Cannot get input Variable Out, variable name = %s", + ctx.InputName("DDX"))); if (CanBeUsedBySelectedRows.count(ctx.Type())) { *ddX = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*ddx_var); if (ddo_var) { @@ -1201,15 +1212,18 @@ inline void ExtractActivationDoubleGradTensor( *ddOut = ctx.Output("DDOut"); } } - PADDLE_ENFORCE(*ddX != nullptr, - "Cannot get output tensor DDX, variable name = %s", - ctx.OutputName("DDX")); + PADDLE_ENFORCE_NOT_NULL( + *ddX, + platform::errors::NotFound( + "Cannot get the tensor from the Variable Output, variable name = %s", + ctx.OutputName("DDX"))); if (static_cast(kDepValue) & static_cast(kDepX)) { auto x_var = ctx.InputVar("X"); - PADDLE_ENFORCE(x_var != nullptr, + PADDLE_ENFORCE_NOT_NULL( + x_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", - ctx.InputName("X")); + ctx.InputName("X"))); auto dx_var = ctx.OutputVar("DX"); if (CanBeUsedBySelectedRows.count(ctx.Type())) { *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); @@ -1229,9 +1243,11 @@ inline void ExtractActivationDoubleGradTensor( } if (static_cast(kDepValue) & static_cast(kDepOut)) { auto out_var = ctx.InputVar("Out"); - PADDLE_ENFORCE(out_var != nullptr, - "Cannot get input tensor Out, variable name = %s", - ctx.InputName("Out")); + PADDLE_ENFORCE_NOT_NULL( + out_var, + platform::errors::NotFound( + "Cannot get the tensor from the Variable Out, variable name = %s", + ctx.InputName("Out"))); auto dout_var = ctx.OutputVar("DOut"); if (CanBeUsedBySelectedRows.count(ctx.Type())) { *Out = @@ -1438,22 +1454,26 @@ inline void ExtractDoubleGradTensorWithInputDOut( // extract ddX(output), ddOut(input) auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); - PADDLE_ENFORCE(ddx_var != nullptr, - "Cannot get input Variable Out, variable name = %s", - ctx.InputName("DDX")); + PADDLE_ENFORCE_NOT_NULL( + ddx_var, platform::errors::NotFound( + "Cannot get input Variable Out, variable name = %s", + ctx.InputName("DDX"))); *ddX = ctx.Input("DDX"); if (ddo_var) { *ddOut = ctx.Output("DDOut"); } - PADDLE_ENFORCE(*ddX != nullptr, - "Cannot get output tensor DDX, variable name = %s", - ctx.OutputName("DDX")); + PADDLE_ENFORCE_NOT_NULL( + ddX, + platform::errors::NotFound( + "Cannot get the tensor from the Variable DDX, variable name = %s", + ctx.OutputName("DDX"))); // extract x(input), dx(output) auto x_var = ctx.InputVar("X"); - PADDLE_ENFORCE(x_var != nullptr, + PADDLE_ENFORCE_NOT_NULL( + x_var, platform::errors::NotFound( "Cannot get input Variable Out, variable name = %s", - ctx.InputName("X")); + ctx.InputName("X"))); auto dx_var = ctx.OutputVar("DX"); *X = ctx.Input("X"); if (dx_var) { @@ -1531,22 +1551,25 @@ class SqrtDoubleGradKernel // extract ddx(input), ddout(output) auto ddx_var = ctx.InputVar("DDX"); auto ddo_var = ctx.OutputVar("DDOut"); - PADDLE_ENFORCE(ddx_var != nullptr, - "Cannot get input Variable DDX, variable name = %s", - ctx.InputName("DDX")); + PADDLE_ENFORCE_NOT_NULL( + ddx_var, platform::errors::NotFound( + "Cannot get input Variable DDX, variable name = %s", + ctx.InputName("DDX"))); ddX = ctx.Input("DDX"); if (ddo_var) { ddOut = ctx.Output("DDOut"); } - PADDLE_ENFORCE(ddX != nullptr, - "Cannot get input Variable DDX, variable name = %s", - ctx.InputName("DDX")); + PADDLE_ENFORCE_NOT_NULL( + ddX, platform::errors::NotFound( + "Cannot get input Variable DDX, variable name = %s", + ctx.InputName("DDX"))); // extract out(input), dout(output) auto out_var = ctx.InputVar("Out"); - PADDLE_ENFORCE(out_var != nullptr, - "Cannot get input Variable Out, variable name = %s", - ctx.InputName("Out")); + PADDLE_ENFORCE_NOT_NULL( + out_var, platform::errors::NotFound( + "Cannot get input Variable Out, variable name = %s", + ctx.InputName("Out"))); auto dout_var = ctx.OutputVar("DOut"); Out = ctx.Input("Out"); if (dout_var) { @@ -1555,9 +1578,10 @@ class SqrtDoubleGradKernel // extract dx(input) auto dx_var = ctx.InputVar("DX"); - PADDLE_ENFORCE(dx_var != nullptr, - "Cannot get input Variable DX, variable name = %s", - ctx.InputName("DX")); + PADDLE_ENFORCE_NOT_NULL( + dx_var, platform::errors::NotFound( + "Cannot get input Variable DX, variable name = %s", + ctx.InputName("DX"))); if (dx_var) { dX = ctx.Input("DX"); } @@ -1608,8 +1632,11 @@ class PowKernel : public framework::OpKernel { } auto factor = std::vector(factor_data, factor_data + factor_tensor->numel()); - PADDLE_ENFORCE_EQ(factor.size(), 1, - "The shape of factor(tensor) MUST BE [1]."); + PADDLE_ENFORCE_EQ( + factor.size(), 1, + platform::errors::InvalidArgument( + "The shape of factor(tensor) must be [1] rather than %d", + factor.size())); for (auto& attr : attrs) { *attr.second = factor[0]; } @@ -1660,8 +1687,11 @@ class PowGradKernel } auto factor = std::vector(factor_data, factor_data + factor_tensor->numel()); - PADDLE_ENFORCE_EQ(factor.size(), 1, - "The shape of factor(tensor) MUST BE [1]."); + PADDLE_ENFORCE_EQ( + factor.size(), 1, + platform::errors::InvalidArgument( + "The shape of factor(tensor) must be [1] rather than %d", + factor.size())); for (auto& attr : attrs) { *attr.second = factor[0]; } diff --git a/paddle/fluid/operators/sign_op.cc b/paddle/fluid/operators/sign_op.cc index c71d36792e..3485b4e5c2 100644 --- a/paddle/fluid/operators/sign_op.cc +++ b/paddle/fluid/operators/sign_op.cc @@ -23,10 +23,9 @@ class SignOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of SignOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of SignOp should not be null."); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "sign"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sign"); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Out"); } diff --git a/python/paddle/fluid/layers/layer_function_generator.py b/python/paddle/fluid/layers/layer_function_generator.py index 511f1274db..1ee8ba12d8 100755 --- a/python/paddle/fluid/layers/layer_function_generator.py +++ b/python/paddle/fluid/layers/layer_function_generator.py @@ -256,8 +256,15 @@ def generate_activation_fn(op_type): op = getattr(core.ops, op_type) return op(x) - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], - op_type) + if op_type not in ["abs", "exp", "square"]: + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + op_type) + else: + # abs exp square ops support dtype(int32, int64, float16, float32, float64) + check_variable_and_dtype( + x, 'x', ['int32', 'int64', 'float16', 'float32', 'float64'], + op_type) + helper = LayerHelper(op_type, **locals()) output = helper.create_variable_for_type_inference(dtype=x.dtype) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1370aedbde..20008777e1 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8183,6 +8183,7 @@ def log(x, name=None): if in_dygraph_mode(): return core.ops.log(x) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], "log") inputs = {'X': [x]} helper = LayerHelper('log', **locals()) dtype = helper.input_dtype(input_param_name='x') @@ -8938,10 +8939,14 @@ def pow(x, factor=1.0, name=None): y_2 = fluid.layers.pow(x, factor=factor_tensor) # y_2 is x^{3.0} """ + check_variable_and_dtype(x, 'x', ['int32', 'int64', 'float32', 'float64'], + 'pow') + helper = LayerHelper('pow', **locals()) inputs = {'X': x} attrs = {} if isinstance(factor, Variable): + check_variable_and_dtype(factor, 'factor', ['float32'], 'pow') factor.stop_gradient = True inputs['FactorTensor'] = factor else: diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 6765a206f1..79ef30f85d 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -765,6 +765,15 @@ class TestLog(TestActivation): return self.check_grad(['X'], 'Out') + def test_error(self): + in1 = fluid.layers.data( + name="in1", shape=[11, 17], append_batch_size=False, dtype="int32") + in2 = fluid.layers.data( + name="in2", shape=[11, 17], append_batch_size=False, dtype="int64") + + self.assertRaises(TypeError, fluid.layers.log, in1) + self.assertRaises(TypeError, fluid.layers.log, in2) + class TestSquare(TestActivation): def setUp(self): @@ -856,6 +865,29 @@ class TestPow_factor_tensor(TestActivation): assert np.array_equal(res_3, res) assert np.array_equal(res_6, np.power(input, 3)) + def test_error(self): + in1 = fluid.layers.data( + name="in1", shape=[11, 17], append_batch_size=False, dtype="int32") + in2 = fluid.layers.data( + name="in2", shape=[11, 17], append_batch_size=False, dtype="int64") + in3 = fluid.layers.data( + name="in3", + shape=[11, 17], + append_batch_size=False, + dtype="float32") + in4 = fluid.layers.data( + name="in4", + shape=[11, 17], + append_batch_size=False, + dtype="float64") + + factor_1 = fluid.layers.fill_constant([1], "float64", 3.0) + + self.assertRaises(TypeError, fluid.layers.pow, x=in1, factor=factor_1) + self.assertRaises(TypeError, fluid.layers.pow, x=in2, factor=factor_1) + self.assertRaises(TypeError, fluid.layers.pow, x=in3, factor=factor_1) + self.assertRaises(TypeError, fluid.layers.pow, x=in4, factor=factor_1) + class TestSTanh(TestActivation): def setUp(self): @@ -1035,6 +1067,39 @@ class TestSwishOpError(unittest.TestCase): fluid.layers.swish(x_fp16) +#------------------ Test Error Activation---------------------- +def create_test_error_class(op_type): + class TestOpErrors(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + op = getattr(fluid.layers, op_type) + # The input dtype of op_type must be float32, float64. + in1 = fluid.layers.data( + name='input2', shape=[12, 10], dtype="int32") + in2 = fluid.layers.data( + name='input3', shape=[12, 10], dtype="int64") + self.assertRaises(TypeError, op, in1) + self.assertRaises(TypeError, op, in2) + + cls_name = "{0}_{1}".format(op_type, "test_errors") + TestOpErrors.__name__ = cls_name + globals()[cls_name] = TestOpErrors + + +create_test_error_class('acos') +create_test_error_class('asin') +create_test_error_class('atan') +create_test_error_class('ceil') +create_test_error_class('cos') +create_test_error_class('floor') +create_test_error_class('reciprocal') +create_test_error_class('round') +create_test_error_class('rsqrt') +create_test_error_class('sin') +create_test_error_class('sqrt') +create_test_error_class('tanh') + + #------------------ Test Cudnn Activation---------------------- def create_test_act_cudnn_class(parent, atol=1e-3, grad_atol=1e-3): @unittest.skipIf(not core.is_compiled_with_cuda(), diff --git a/python/paddle/fluid/tests/unittests/test_sign_op.py b/python/paddle/fluid/tests/unittests/test_sign_op.py index 96718ab458..b84e3b5377 100644 --- a/python/paddle/fluid/tests/unittests/test_sign_op.py +++ b/python/paddle/fluid/tests/unittests/test_sign_op.py @@ -45,10 +45,13 @@ class TestSignOpError(unittest.TestCase): # The input dtype of sign_op must be float16, float32, float64. input2 = fluid.layers.data( name='input2', shape=[12, 10], dtype="int32") - self.assertRaises(TypeError, fluid.layers.sign, input2) input3 = fluid.layers.data( - name='input3', shape=[4], dtype="float16") - fluid.layers.sign(input3) + name='input3', shape=[12, 10], dtype="int64") + self.assertRaises(TypeError, fluid.layers.sign, input2) + self.assertRaises(TypeError, fluid.layers.sign, input3) + input4 = fluid.layers.data( + name='input4', shape=[4], dtype="float16") + fluid.layers.sign(input4) if __name__ == "__main__": -- GitLab