From 1b797468899097487c210b1ed761ae91beefcb11 Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Mon, 18 Sep 2017 15:34:51 -0700
Subject: [PATCH] prelu

---
 paddle/operators/prelu_op.cc | 23 +++++++++-----
 paddle/operators/prelu_op.h  | 58 +++++++++++++++++++-----------------
 2 files changed, 45 insertions(+), 36 deletions(-)

diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc
index fd6269a4699..911df8ba672 100644
--- a/paddle/operators/prelu_op.cc
+++ b/paddle/operators/prelu_op.cc
@@ -29,6 +29,11 @@ class PReluOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
     auto *in = ctx.Input<framework::Tensor>("X");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Alpha"),
+                            "Input(Alpha) should not be null");
+    auto *alpha = ctx.Input<framework::Tensor>("Alpha");
+    PADDLE_ENFORCE(alpha->numel() == 1, "Size of weight Alpha must be one.");
+
     PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                             "Output(Out) should not be null");
     auto *out = ctx.Output<framework::LoDTensor>("Out");
@@ -36,15 +41,13 @@
   }
 };

-template <typename AttrType>
 class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input tensor of prelu operator.");
+    AddInput("Alpha", "The alpha weight of prelu operator.");
     AddOutput("Out", "The output tensor of prelu operator.");
-    AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
-        .SetDefault(0.0);
     AddComment(R"DOC(PRelu operator

The equation is:
@@ -66,11 +69,15 @@ class PReluGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) should not be null");
-    auto *X_grad =
-        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    auto *X = ctx.Input<framework::Tensor>("X");
+    auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *x = ctx.Input<framework::Tensor>("X");
+
+    auto *dalpha =
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("Alpha"));
+    auto *alpha = ctx.Input<framework::Tensor>("Alpha");

-    X_grad->Resize(X->dims());
+    dx->Resize(x->dims());
+    dalpha->Resize(alpha->dims());
   }
 };

@@ -79,7 +86,7 @@

 namespace ops = paddle::operators;

-REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker<float>, prelu_grad,
+REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker, prelu_grad,
             ops::PReluGradOp);
 REGISTER_OP_CPU_KERNEL(prelu,
                        ops::PReluKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h
index 31ae54d5bc1..f88ce94dc86 100644
--- a/paddle/operators/prelu_op.h
+++ b/paddle/operators/prelu_op.h
@@ -28,33 +28,35 @@ class PReluFunctor {
  public:
   explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}

-  HOSTDEVICE T operator()(const T& X) const {
-    if (X > 0)
-      return X;
+  HOSTDEVICE T operator()(const T& x) const {
+    if (x > 0)
+      return x;
     else
-      return X * alpha_;
+      return x * alpha_;
   }

  private:
   T alpha_;
 };

-template <typename Place, typename T, typename AttrType = T>
+template <typename Place, typename T>
 class PReluKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* X = context.Input<Tensor>("X");
-    auto* Out = context.Output<Tensor>("Out");
+    auto* x = context.Input<Tensor>("X");
+    auto* alpha = context.Input<Tensor>("Alpha");
+    auto* out = context.Output<Tensor>("Out");

-    const T* X_ptr = X->data<T>();
-    T* O_ptr = Out->mutable_data<T>(context.GetPlace());
+    const T* x_ptr = x->data<T>();
+    T* o_ptr = out->mutable_data<T>(context.GetPlace());

-    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
+    auto alpha_val = alpha->data<T>()[0];
+    // auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));

-    int numel = X->numel();
+    int numel = x->numel();

     auto place = context.GetPlace();

-    Transform(place, X_ptr, X_ptr + numel, O_ptr, PReluFunctor<T>(alpha));
+    Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_val));
   }
 };
@@ -63,36 +65,36 @@ class PReluGradFunctor {
  public:
   explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}

-  HOSTDEVICE T operator()(const T& Out, const T& dOut) const {
-    if (Out > 0)
-      return dOut;
+  HOSTDEVICE T operator()(const T& out, const T& dout) const {
+    if (out > 0)
+      return dout;
     else
-      return dOut * alpha_;
+      return dout * alpha_;
   }

  private:
   T alpha_;
 };

-template <typename Place, typename T, typename AttrType = T>
+template <typename Place, typename T>
 class PReluGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
-    auto* dO = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));

-    auto* Out = context.Input<Tensor>("Out");
+    auto* out = context.Input<Tensor>("Out");
+    auto* alpha = context.Input<Tensor>("Alpha");
+    auto alpha_val = alpha->data<T>()[0];

-    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
-
-    T* dX_ptr = dX->mutable_data<T>(context.GetPlace());
-    const T* dO_ptr = dO->data<T>();
-    const T* O_ptr = Out->data<T>();
-    int numel = dX->numel();
+    T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
+    const T* dout_ptr = dout->data<T>();
+    const T* out_ptr = out->data<T>();
+    int numel = dx->numel();

     auto place = context.GetPlace();

-    Transform(place, O_ptr, O_ptr + numel, dO_ptr, dX_ptr,
-              PReluGradFunctor<T>(alpha));
+    Transform(place, out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
+              PReluGradFunctor<T>(alpha_val));
   }
 };
-- 
GitLab
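
Taken together, the patch turns PReLU's alpha from a fixed scalar attribute into a learnable one-element "Alpha" input tensor, read on each kernel launch via alpha->data<T>()[0]. The element-wise rules the two functors encode are f(x) = x for x > 0 and f(x) = alpha * x otherwise, with gradient dX = dOut where Out > 0 and alpha * dOut elsewhere. The following is a minimal standalone sketch of that math, using plain std::transform in place of paddle::platform::Transform; the alpha value and inputs are made up for illustration and are not taken from the patch:

// Standalone illustration (not Paddle's API) of the math that
// PReluFunctor and PReluGradFunctor above implement.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const float alpha = 0.25f;  // single shared weight, as the op enforces
  std::vector<float> x = {-2.f, -0.5f, 1.f, 3.f};
  std::vector<float> out(x.size());
  std::vector<float> dout(x.size(), 1.f);  // upstream gradient of ones
  std::vector<float> dx(x.size());

  // Forward: f(v) = v if v > 0, else alpha * v.
  std::transform(x.begin(), x.end(), out.begin(),
                 [alpha](float v) { return v > 0 ? v : v * alpha; });

  // Backward w.r.t. X, branching on Out exactly as PReluGradFunctor does:
  // pass dOut through where Out > 0, scale it by alpha elsewhere.
  std::transform(out.begin(), out.end(), dout.begin(), dx.begin(),
                 [alpha](float o, float g) { return o > 0 ? g : g * alpha; });

  for (std::size_t i = 0; i < x.size(); ++i)
    std::printf("x=%g out=%g dx=%g\n", x[i], out[i], dx[i]);
  // e.g. x=-2 gives out=-0.5, dx=0.25; x=1 gives out=1, dx=1.
  return 0;
}

Note that the grad functor branches on Out rather than X; for a positive alpha the two tests agree, since out and x share a sign. Also, as diffed here, PReluGradOp resizes Alpha@GRAD but the shown grad kernel fills only X@GRAD.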