From ec7d11a4922f2013304a2203b07db727100e3816 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Fri, 9 Oct 2020 12:47:30 +0800 Subject: [PATCH] refine fused_elemwise_activation error message (#27734) --- .../fused/fused_elemwise_activation_op.h | 52 ++++++++++++------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.h b/paddle/fluid/operators/fused/fused_elemwise_activation_op.h index 2c0c5f9ec0..c61b9a9e48 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.h +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.h @@ -276,7 +276,8 @@ static void RunFunctors(const framework::ExecutionContext &ctx, ctx, paddle::operators::math::MulFunctor(), paddle::operators::math::SigmoidFunctor(), in_x, in_y, outputs); } else { - PADDLE_THROW("%s has not been implemented.", funcs_str); + PADDLE_THROW(platform::errors::InvalidArgument( + "%s has not been implemented.", funcs_str)); } } @@ -374,7 +375,8 @@ static void RunGradFunctors( paddle::operators::math::SigmoidGradFunctor(), in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); } else { - PADDLE_THROW("%s has not been implemented.", funcs_str); + PADDLE_THROW(platform::errors::InvalidArgument( + "%s has not been implemented.", funcs_str)); } } @@ -386,16 +388,21 @@ class FusedElemwiseActivationKernel : public framework::OpKernel { "X", "FusedElemwiseActivation"); auto &in_y = GET_DATA_SAFELY(ctx.Input("Y"), "Input", "Y", "FusedElemwiseActivation"); - PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty"); + + PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true, + platform::errors::InvalidArgument( + "The output(Out) should not be empty")); auto output = ctx.Output("Out"); std::vector outputs; outputs.emplace_back(output); if (ctx.Attr("save_intermediate_out")) { - PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"), - "The save_intermediate_out is enable, so the " - "IntermediateOut should not be empty."); + PADDLE_ENFORCE_EQ(ctx.HasOutput("IntermediateOut"), true, + platform::errors::InvalidArgument( + "The save_intermediate_out is enable, so the " + "IntermediateOut should not be empty.")); + auto intermediate_out = ctx.Output("IntermediateOut"); outputs.emplace_back(intermediate_out); } else { @@ -411,13 +418,18 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { auto in_y = ctx.Input("Y"); - PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr."); + PADDLE_ENFORCE_NE(in_y, nullptr, platform::errors::InvalidArgument( + "Input(Y) should not be nullptr.")); auto in_out = ctx.Input("Out"); - PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr."); + PADDLE_ENFORCE_NE( + in_out, nullptr, + platform::errors::InvalidArgument("Input(Out) should not be nullptr.")); auto in_out_grad = ctx.Input(framework::GradVarName("Out")); - PADDLE_ENFORCE(in_out_grad != nullptr, - "Input(Out@Grad) should not be nullptr."); + PADDLE_ENFORCE_NE(in_out_grad, nullptr, + platform::errors::InvalidArgument( + "Input(Out@Grad) should not be nullptr.")); + framework::Tensor *in_x = const_cast(ctx.Input("X")); framework::Tensor *x_grad = @@ -437,24 +449,28 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel { // recompute. in_intermediate_out = const_cast( ctx.Input("IntermediateOut")); - PADDLE_ENFORCE(in_intermediate_out != nullptr, - "The option of 'save_intermediate_out' is opened, " - "so the number of 'Out' should be two."); + PADDLE_ENFORCE_NE(in_intermediate_out, nullptr, + platform::errors::InvalidArgument( + "The option of 'save_intermediate_out' is opened," + " so the number of 'Out' should be two.")); } else { if (!InputXCanBeAbsent(functor_list)) { - PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null."); + PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument( + "Input(X) should not be null.")); } } // Get in_x if (ctx.HasInput("X")) { - PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr."); + PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument( + "Input(X) should not be null.")); } else { // If functor_list contains elementwise_add, the backward doesn't use // in_x, in_y and in_out. - PADDLE_ENFORCE(InputXCanBeAbsent(functor_list), - "Only when the compoundfunctor contains " - "elementwise_add_grad, the 'X' could be absent."); + PADDLE_ENFORCE_EQ(InputXCanBeAbsent(functor_list), true, + platform::errors::InvalidArgument( + "Only when the compoundfunctor contains " + "elementwise_add_grad, the 'X' could be absent.")); in_x = const_cast(in_out_grad); } -- GitLab