diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index 9d201eb93a2c0e34dd8e6869e97b43c4e278596e..1eb795faa858796f7a34aa495b43d043fdb5dd43 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -37,10 +37,8 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
 
 class SigmoidOpGrad : public OperatorWithKernel {
  protected:
-  void InferShape(const InferShapeContext &ctx) const override {}
-  std::string DebugString() const override {
-    LOG(INFO) << "SigmoidGrad";
-    return "";
+  void InferShape(const InferShapeContext &ctx) const override {
+    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
   }
 };
 
@@ -51,3 +49,5 @@ REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker);
 REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad);
 
 REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(sigmoid_grad,
+                       ops::SigmoidGradKernel<ops::CPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/sigmoid_op.cu
index 2123b17e4b5e90c22c2d6e9177f2a8956f8a4ac9..e80ba081f2ff805664cf92f3cb47e9ad51889058 100644
--- a/paddle/operators/sigmoid_op.cu
+++ b/paddle/operators/sigmoid_op.cu
@@ -16,3 +16,5 @@
 #include "paddle/operators/sigmoid_op.h"
 
 REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(sigmoid_grad,
+                       ops::SigmoidGradKernel<ops::GPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h
index eb473920a5f866825b52ecb946653ccead7000ea..d513261e74423ce93a50eaaaec1c7d5fadb8f4a8 100644
--- a/paddle/operators/sigmoid_op.h
+++ b/paddle/operators/sigmoid_op.h
@@ -27,6 +27,7 @@ class SigmoidKernel : public OpKernel {
     auto output = context.Output<Tensor>(0);
     output->mutable_data<T>(context.GetPlace());
 
+    // The clipping is used in Paddle's original implementation
     auto X = EigenVector<T>::Flatten(*input);
     auto Y = EigenVector<T>::Flatten(*output);
     auto place = context.GetEigenDevice<Place>();
@@ -34,5 +35,23 @@
     Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
   }
 };
+
+template <typename Place, typename T>
+class SigmoidGradKernel : public OpKernel {
+ public:
+  void Compute(const ExecutionContext& context) const override {
+    auto Y_t = context.Input<Tensor>("Y");
+    auto dY_t = context.Input<Tensor>(framework::GradVarName("Y"));
+    auto dX_t = context.Output<Tensor>(framework::GradVarName("X"));
+
+    dX_t->mutable_data<T>(context.GetPlace());
+
+    auto dX = EigenVector<T>::Flatten(*dX_t);
+    auto Y = EigenVector<T>::Flatten(*Y_t);
+    auto dY = EigenVector<T>::Flatten(*dY_t);
+    dX.device(context.GetEigenDevice<Place>()) = dY * Y * (1. - Y);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py
index 2610bcf16303d492dce3ce63c93b54b0c88f6bba..2a57a41ed8b718fd420062ba68e853a4861b7359 100644
--- a/python/paddle/v2/framework/tests/test_sigmoid_op.py
+++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py
@@ -12,5 +12,8 @@ class TestSigmoidOp(unittest.TestCase):
         self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
 
 
+# class TestSigmoidGradOp(unittest.TestCase):
+#     TODO(qingqing): add unit test
+
 if __name__ == '__main__':
     unittest.main()
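
Note on the gradient formula: for y = sigmoid(x) = 1 / (1 + exp(-x)), the derivative is
dy/dx = y * (1 - y), so SigmoidGradKernel computes dX = dY * Y * (1 - Y) by the chain rule.
Until the framework's Python gradient checker is available, the TODO(qingqing) test could be
approximated by a framework-free numerical check. The sketch below is illustrative only:
the class name, shapes, and tolerance are assumptions, and it validates the formula itself
rather than running the registered sigmoid_grad operator.

# Hypothetical standalone check for dX = dY * Y * (1 - Y); it does NOT
# exercise the registered sigmoid_grad op, only the math it implements.
import unittest

import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class TestSigmoidGradFormula(unittest.TestCase):
    def test_matches_numeric_gradient(self):
        x = np.random.uniform(-5, 5, (32, 100)).astype("float64")
        dy = np.random.random((32, 100)).astype("float64")

        # Analytic gradient, as computed by SigmoidGradKernel.
        y = sigmoid(x)
        dx_analytic = dy * y * (1.0 - y)

        # Central-difference approximation of dsigmoid/dx, scaled by dy
        # elementwise (chain rule for an elementwise op).
        eps = 1e-6
        dx_numeric = dy * (sigmoid(x + eps) - sigmoid(x - eps)) / (2.0 * eps)

        np.testing.assert_allclose(dx_analytic, dx_numeric, atol=1e-8)


if __name__ == '__main__':
    unittest.main()

Once the gradient checker lands in paddle.v2.framework, the same assertion should run
through the registered sigmoid_grad operator, as the TODO suggests, instead of the
standalone formula above.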