diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc
index d15352110f1dca966213aeca1daadda94cd42c31..fd6269a4699765f7333ffa9f1037b41fc961a8e9 100644
--- a/paddle/operators/prelu_op.cc
+++ b/paddle/operators/prelu_op.cc
@@ -29,6 +29,8 @@ class PReluOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
     auto *in = ctx.Input("X");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
+                            "Output(Out) should not be null");
     auto *out = ctx.Output("Out");
     out->Resize(in->dims());
   }
@@ -41,6 +43,8 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input tensor of prelu operator.");
     AddOutput("Out", "The output tensor of prelu operator.");
+    AddAttr("alpha", "The scaling factor alpha of prelu.")
+        .SetDefault(0.0);
     AddComment(R"DOC(PRelu operator
 
 The equation is:
@@ -49,8 +53,6 @@ The equation is:
 
        f(x) = x , for x >= 0
 
 )DOC");
-    AddAttr("alpha", "The scaling factor alpha of prelu.")
-        .SetDefault(0.0);
   }
 };
diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h
index d3d8f76e5a7f1c8e086e0e3c48a8f0766b310f4c..31ae54d5bc17c5cb2b524b401796230afe9120af 100644
--- a/paddle/operators/prelu_op.h
+++ b/paddle/operators/prelu_op.h
@@ -24,9 +24,9 @@ using Tensor = framework::Tensor;
 using platform::Transform;
 
 template <typename T>
-class Prelu_functor {
+class PReluFunctor {
  public:
-  explicit Prelu_functor(const T& alpha) : alpha_(alpha) {}
+  explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
 
   HOSTDEVICE T operator()(const T& X) const {
     if (X > 0)
@@ -54,14 +54,14 @@ class PReluKernel : public framework::OpKernel {
     int numel = X->numel();
     auto place = context.GetPlace();
 
-    Transform(place, X_ptr, X_ptr + numel, O_ptr, Prelu_functor<T>(alpha));
+    Transform(place, X_ptr, X_ptr + numel, O_ptr, PReluFunctor<T>(alpha));
   }
 };
 
 template <typename T>
-class Prelu_Grad_functor {
+class PReluGradFunctor {
  public:
-  explicit Prelu_Grad_functor(const T& alpha) : alpha_(alpha) {}
+  explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}
 
   HOSTDEVICE T operator()(const T& Out, const T& dOut) const {
     if (Out > 0)
@@ -92,7 +92,7 @@ class PReluGradKernel : public framework::OpKernel {
 
     auto place = context.GetPlace();
 
     Transform(place, O_ptr, O_ptr + numel, dO_ptr, dX_ptr,
-              Prelu_Grad_functor<T>(alpha));
+              PReluGradFunctor<T>(alpha));
   }
 };