From 0560733c2e4492db5ae0af2553e7fd7b6d883007 Mon Sep 17 00:00:00 2001
From: dangqingqing
Date: Wed, 2 Aug 2017 18:16:17 +0800
Subject: [PATCH] Add sigmoid backward implementation.

---
 paddle/operators/sigmoid_op.cc                     | 12 ++++++++----
 paddle/operators/sigmoid_op.cu                     |  1 +
 paddle/operators/sigmoid_op.h                      | 19 +++++++++++++++++++
 .../v2/framework/tests/test_sigmoid_op.py          | 11 +++++++++++
 4 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index a81ab262cc..9e565bb23f 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -37,10 +37,12 @@ public:
 
 class SigmoidOpGrad : public OperatorWithKernel {
 protected:
-  void InferShape(const InferShapeContext &ctx) const override {}
-  std::string DebugString() const override {
-    LOG(INFO) << "SigmoidGrad";
-    return "";
+  void InferShape(const InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 1,
+                   "Sigmoid Gradient Op only has one input");
+    PADDLE_ENFORCE(ctx.OutputSize() == 1,
+                   "Sigmoid Gradient Op only has one output");
+    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
   }
 };
 
@@ -51,3 +53,5 @@ REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker);
 REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad);
 
 REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(sigmoid_grad,
+                       ops::SigmoidGradKernel<ops::CPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/sigmoid_op.cu
index f679b20418..f83483131c 100644
--- a/paddle/operators/sigmoid_op.cu
+++ b/paddle/operators/sigmoid_op.cu
@@ -1,3 +1,4 @@
 #include "paddle/operators/sigmoid_op.h"
 
 REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(sigmoid_grad, ops::SigmoidGradKernel<ops::GPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h
index 3dd23a9ebc..2ea75b4885 100644
--- a/paddle/operators/sigmoid_op.h
+++ b/paddle/operators/sigmoid_op.h
@@ -32,5 +32,24 @@ public:
         1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
   }
 };
+
+template <typename Place, typename T>
+class SigmoidGradKernel : public OpKernel {
+public:
+  void Compute(const ExecutionContext& context) const override {
+    // TODO(qingqing) maybe a helper function is needed for the name x@GRAD
+    auto y_t = context.Input<Tensor>("Y");
+    auto dy_t = context.Input<Tensor>("Y@GRAD");
+    auto dx_t = context.Output<Tensor>("X@GRAD");
+
+    dx_t->mutable_data<T>(context.GetPlace());
+
+    auto dx = EigenVector<T>::Flatten(*dx_t);
+    auto y = EigenVector<T>::Flatten(*y_t);
+    auto dy = EigenVector<T>::Flatten(*dy_t);
+    dx.device(*(context.GetEigenDevice<Place>())) = dy * y * (1. - y);
+  }
+};
+
 } // namespace operators
 } // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py
index 50044a122f..4b0acd3294 100644
--- a/python/paddle/v2/framework/tests/test_sigmoid_op.py
+++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py
@@ -12,5 +12,16 @@ class TestSigmoidOp(unittest.TestCase):
         self.Y = 1 / (1 + np.exp(-self.X))
 
 
+#class TestSigmoidGradOp(unittest.TestCase):
+#    __metaclass__ = OpTestMeta
+#
+#    def setUp(self):
+#        self.type = "sigmoid_grad"
+#        self.Y = np.random.random((32, 100)).astype("float32")
+#        self.dY = np.random.random((32, 100)).astype("float32")
+#        self.dX = self.dY * self.Y * (1 - self.Y)
+#        print self.dX
+#
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab
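
For reference, the backward pass added in sigmoid_op.h above follows from
y = sigmoid(x) = 1 / (1 + exp(-x)), whose derivative is dy/dx = y * (1 - y),
so the gradient flowing back to x is dx = dy * y * (1 - y). Below is a minimal
NumPy sketch of the same computation with a finite-difference check. It is not
part of the patch: the sigmoid/sigmoid_grad helpers are local to the sketch,
and the (32, 100) shape (borrowed from the commented-out test case) and the
1e-5 step size are illustrative choices only.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_grad(y, dy):
    # Same formula as SigmoidGradKernel: dx = dy * y * (1 - y).
    return dy * y * (1.0 - y)

x = np.random.random((32, 100))   # float64 inputs
y = sigmoid(x)
dy = np.ones_like(y)              # upstream gradient of ones
dx = sigmoid_grad(y, dy)

# Finite-difference check on a single element.
eps = 1e-5
x_pert = x.copy()
x_pert[0, 0] += eps
numeric = (sigmoid(x_pert)[0, 0] - y[0, 0]) / eps
assert abs(numeric - dx[0, 0]) < 1e-4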