From 1b2374ad3b2831229d7db5e8cf38c81706fd65ce Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Fri, 15 Sep 2017 22:30:21 -0700
Subject: [PATCH] new prelu with functor

---
 paddle/operators/prelu_op.cc                  | 15 ++--
 paddle/operators/prelu_op.h                   | 69 ++++++++++++++-----
 .../v2/framework/tests/test_prelu_op.py       |  2 +-
 3 files changed, 62 insertions(+), 24 deletions(-)

diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc
index eafd66579f1..d15352110f1 100644
--- a/paddle/operators/prelu_op.cc
+++ b/paddle/operators/prelu_op.cc
@@ -27,13 +27,14 @@ class PReluOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
     auto *in = ctx.Input<framework::Tensor>("X");
     auto *out = ctx.Output<framework::Tensor>("Out");
     out->Resize(in->dims());
   }
 };
 
-// template <typename AttrType>
+template <typename AttrType>
 class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -43,10 +44,12 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(PRelu operator
 
 The equation is:
-f(x) = alpha * x , for x < 0
-f(x) = x , for x >= 0
+
+  f(x) = alpha * x , for x < 0
+  f(x) = x , for x >= 0
+
 )DOC");
-    AddAttr<float>("alpha", "The scaling factor alpha of prelu.")
+    AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
         .SetDefault(0.0);
   }
 };
@@ -59,6 +62,8 @@ class PReluGradOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) should not be null");
     auto *X_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto *X = ctx.Input<framework::Tensor>("X");
 
@@ -72,7 +77,7 @@
 
 namespace ops = paddle::operators;
 
-REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker, prelu_grad,
+REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker<float>, prelu_grad,
             ops::PReluGradOp);
 REGISTER_OP_CPU_KERNEL(prelu,
                        ops::PReluKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h
index a7e34744ba8..a98d4898396 100644
--- a/paddle/operators/prelu_op.h
+++ b/paddle/operators/prelu_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/platform/transform.h"
 
 namespace paddle {
 namespace operators {
@@ -23,28 +24,60 @@ using Tensor = framework::Tensor;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+using platform::Transform;
 
-template <typename Place, typename T>
+template <typename T>
+class Prelu_functor {
+ public:
+  explicit Prelu_functor(const T& alpha) : alpha_(alpha) {}
+
+  HOSTDEVICE T operator()(const T& X) const {
+    if (X > 0)
+      return X;
+    else
+      return X * alpha_;
+  }
+
+ private:
+  T alpha_;
+};
+
+template <typename Place, typename T, typename AttrType = T>
 class PReluKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* X = context.Input<Tensor>("X");
     auto* Out = context.Output<Tensor>("Out");
 
-    Out->mutable_data<T>(context.GetPlace());
+    const T* X_ptr = X->data<T>();
+    T* O_ptr = Out->mutable_data<T>(context.GetPlace());
 
-    auto alpha = static_cast<T>(context.Attr<float>("alpha"));
+    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
 
-    auto X_vec = EigenVector<T>::Flatten(*X);
-    auto Out_vec = EigenVector<T>::Flatten(*Out);
+    int numel = X->numel();
 
-    // auto place = context.GetEigenDevice<Place>();
-    // Out_vec.device(place)
-    Out_vec = X_vec.cwiseMax(0.f) + X_vec.cwiseMin(0.f) * alpha;
+    auto place = context.GetPlace();
+    Transform(place, X_ptr, X_ptr + numel, O_ptr, Prelu_functor<T>(alpha));
   }
 };
 
-template <typename Place, typename T>
+template <typename T>
+class Prelu_Grad_functor {
+ public:
+  explicit Prelu_Grad_functor(const T& alpha) : alpha_(alpha) {}
+
+  HOSTDEVICE T operator()(const T& Out, const T& dOut) const {
+    if (Out > 0)
+      return dOut;
+    else
+      return dOut * alpha_;
+  }
+
+ private:
+  T alpha_;
+};
+
+template <typename Place, typename T, typename AttrType = T>
 class PReluGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -53,16 +86,16 @@ class PReluGradKernel : public framework::OpKernel {
 
     auto* Out = context.Input<Tensor>("Out");
 
-    auto alpha = static_cast<T>(context.Attr<float>("alpha"));
+    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
+
+    T* dX_ptr = dX->mutable_data<T>(context.GetPlace());
+    const T* dO_ptr = dO->data<T>();
+    const T* O_ptr = Out->data<T>();
+    int numel = dX->numel();
 
-    dX->mutable_data<T>(context.GetPlace());
-    for (int i = 0; i < dX->numel(); ++i) {
-      if (Out->data<T>()[i] > 0) {
-        dX->data<T>()[i] = dO->data<T>()[i];
-      } else {
-        dX->data<T>()[i] = dO->data<T>()[i] * alpha;
-      }
-    }
+    auto place = context.GetPlace();
+    Transform(place, O_ptr, O_ptr + numel, dO_ptr, dX_ptr,
+              Prelu_Grad_functor<T>(alpha));
   }
 };
 
diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py
index 39b6f673fdb..cbf2e6b2a88 100644
--- a/python/paddle/v2/framework/tests/test_prelu_op.py
+++ b/python/paddle/v2/framework/tests/test_prelu_op.py
@@ -6,7 +6,7 @@ from op_test import OpTest
 class PreluTest(OpTest):
     def setUp(self):
         self.op_type = "prelu"
-        self.inputs = {'X': np.random.normal(size=(3, 5)).astype("float32")}
+        self.inputs = {'X': np.random.normal(size=(10, 10)).astype("float32")}
         self.attrs = {'alpha': 0.1}
         out_np = np.maximum(self.inputs['X'], 0.)
         out_np = out_np + np.minimum(self.inputs['X'], 0.) * self.attrs['alpha']
-- 
GitLab
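
Note: below is a minimal NumPy sketch, not part of the patch, of what the new Prelu_functor and Prelu_Grad_functor compute element-wise. The helper names prelu_forward and prelu_backward are illustrative only; the forward check mirrors the reference computation in test_prelu_op.py.

import numpy as np

def prelu_forward(x, alpha):
    # Prelu_functor: f(x) = x for x > 0, alpha * x otherwise
    return np.where(x > 0, x, alpha * x)

def prelu_backward(out, dout, alpha):
    # Prelu_Grad_functor: dX = dOut where Out > 0, else alpha * dOut
    return np.where(out > 0, dout, alpha * dout)

x = np.random.normal(size=(10, 10)).astype("float32")
out = prelu_forward(x, 0.1)
# same reference formula the Python unit test uses
assert np.allclose(out, np.maximum(x, 0.) + np.minimum(x, 0.) * 0.1)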