未验证 提交 ec7d11a4 编写于 作者: C cc 提交者: GitHub

refine fused_elemwise_activation error message (#27734)

上级 994438b1
......@@ -276,7 +276,8 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
} else {
PADDLE_THROW("%s has not been implemented.", funcs_str);
PADDLE_THROW(platform::errors::InvalidArgument(
"%s has not been implemented.", funcs_str));
}
}
......@@ -374,7 +375,8 @@ static void RunGradFunctors(
paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else {
PADDLE_THROW("%s has not been implemented.", funcs_str);
PADDLE_THROW(platform::errors::InvalidArgument(
"%s has not been implemented.", funcs_str));
}
}
......@@ -386,16 +388,21 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
"X", "FusedElemwiseActivation");
auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
"Y", "FusedElemwiseActivation");
PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");
PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
platform::errors::InvalidArgument(
"The output(Out) should not be empty"));
auto output = ctx.Output<framework::Tensor>("Out");
std::vector<framework::Tensor *> outputs;
outputs.emplace_back(output);
if (ctx.Attr<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
"The save_intermediate_out is enable, so the "
"IntermediateOut should not be empty.");
PADDLE_ENFORCE_EQ(ctx.HasOutput("IntermediateOut"), true,
platform::errors::InvalidArgument(
"The save_intermediate_out is enable, so the "
"IntermediateOut should not be empty."));
auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
outputs.emplace_back(intermediate_out);
} else {
......@@ -411,13 +418,18 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto in_y = ctx.Input<framework::Tensor>("Y");
PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr.");
PADDLE_ENFORCE_NE(in_y, nullptr, platform::errors::InvalidArgument(
"Input(Y) should not be nullptr."));
auto in_out = ctx.Input<framework::Tensor>("Out");
PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr.");
PADDLE_ENFORCE_NE(
in_out, nullptr,
platform::errors::InvalidArgument("Input(Out) should not be nullptr."));
auto in_out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(in_out_grad != nullptr,
"Input(Out@Grad) should not be nullptr.");
PADDLE_ENFORCE_NE(in_out_grad, nullptr,
platform::errors::InvalidArgument(
"Input(Out@Grad) should not be nullptr."));
framework::Tensor *in_x =
const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
framework::Tensor *x_grad =
......@@ -437,24 +449,28 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
// recompute.
in_intermediate_out = const_cast<framework::Tensor *>(
ctx.Input<framework::Tensor>("IntermediateOut"));
PADDLE_ENFORCE(in_intermediate_out != nullptr,
"The option of 'save_intermediate_out' is opened, "
"so the number of 'Out' should be two.");
PADDLE_ENFORCE_NE(in_intermediate_out, nullptr,
platform::errors::InvalidArgument(
"The option of 'save_intermediate_out' is opened,"
" so the number of 'Out' should be two."));
} else {
if (!InputXCanBeAbsent(functor_list)) {
PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null.");
PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument(
"Input(X) should not be null."));
}
}
// Get in_x
if (ctx.HasInput("X")) {
PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr.");
PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument(
"Input(X) should not be null."));
} else {
// If functor_list contains elementwise_add, the backward doesn't use
// in_x, in_y and in_out.
PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
"Only when the compoundfunctor contains "
"elementwise_add_grad, the 'X' could be absent.");
PADDLE_ENFORCE_EQ(InputXCanBeAbsent(functor_list), true,
platform::errors::InvalidArgument(
"Only when the compoundfunctor contains "
"elementwise_add_grad, the 'X' could be absent."));
in_x = const_cast<framework::Tensor *>(in_out_grad);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册