diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index a81ab262cc6fe7bdff0045259e0030f3d46f503f..9e565bb23f29786fcfbb6569f019bb3ab7b20e8b 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -37,10 +37,12 @@ public:
 
 class SigmoidOpGrad : public OperatorWithKernel {
 protected:
-  void InferShape(const InferShapeContext &ctx) const override {}
-  std::string DebugString() const override {
-    LOG(INFO) << "SigmoidGrad";
-    return "";
+  void InferShape(const InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 1,
+                   "Sigmoid Gradient Op only has one input");
+    PADDLE_ENFORCE(ctx.OutputSize() == 1,
+                   "Sigmoid Gradient Op only has one output");
+    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
   }
 };
 
@@ -51,3 +53,5 @@ REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker);
 REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad);
 
 REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(sigmoid_grad,
+                       ops::SigmoidGradKernel<ops::CPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/sigmoid_op.cu
index f679b20418f04eff4310efe4e121963ce5a235e0..f83483131cc0dd5141078c76a10b09cf7fa041c6 100644
--- a/paddle/operators/sigmoid_op.cu
+++ b/paddle/operators/sigmoid_op.cu
@@ -1,3 +1,4 @@
 #include "paddle/operators/sigmoid_op.h"
 
 REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(sigmoid_grad, ops::SigmoidGradKernel<ops::GPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h
index 3dd23a9ebc7ac0972d6ee07b9ac051d59e66f62f..2ea75b4885e19c536068f301a9f21af993068a8e 100644
--- a/paddle/operators/sigmoid_op.h
+++ b/paddle/operators/sigmoid_op.h
@@ -32,5 +32,24 @@ public:
         1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
   }
 };
+
+template <typename Place, typename T>
+class SigmoidGradKernel : public OpKernel {
+public:
+  void Compute(const ExecutionContext& context) const override {
+    // TODO(qingqing) maybe a helper function is needed for the name X@GRAD
+    auto y_t = context.Input<Tensor>("Y");
+    auto dy_t = context.Input<Tensor>("Y@GRAD");
+    auto dx_t = context.Output<Tensor>("X@GRAD");
+
+    dx_t->mutable_data<T>(context.GetPlace());
+
+    auto dx = EigenVector<T>::Flatten(*dx_t);
+    auto y = EigenVector<T>::Flatten(*y_t);
+    auto dy = EigenVector<T>::Flatten(*dy_t);
+    dx.device(*(context.GetEigenDevice<Place>())) = dy * y * (1. - y);
+  }
+};
+
 } // namespace operators
 } // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py
index 50044a122f1d66dd54a24f6cce76074a60ee2262..4b0acd3294f193743bc84ad8f990f24a3bc093d1 100644
--- a/python/paddle/v2/framework/tests/test_sigmoid_op.py
+++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py
@@ -12,5 +12,16 @@ class TestSigmoidOp(unittest.TestCase):
         self.Y = 1 / (1 + np.exp(-self.X))
 
 
+#class TestSigmoidGradOp(unittest.TestCase):
+#    __metaclass__ = OpTestMeta
+#
+#    def setUp(self):
+#        self.type = "sigmoid_grad"
+#        self.Y = np.random.random((32, 100)).astype("float32")
+#        self.dY = np.random.random((32, 100)).astype("float32")
+#        self.dX = self.dY * self.Y * (1 - self.Y)
+#        print self.dX
+#
+
 if __name__ == '__main__':
     unittest.main()
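
Note on the backward rule used in SigmoidGradKernel: with the forward output y = \sigma(x) = \frac{1}{1 + e^{-x}}, the sigmoid derivative is

    \frac{dy}{dx} = \sigma(x)\,(1 - \sigma(x)) = y\,(1 - y),

so by the chain rule the input gradient is dX = dY \cdot y \cdot (1 - y), which is exactly the Eigen expression dy * y * (1. - y) in the kernel.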
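
The gradient unit test is still commented out above (the Python-side grad check isn't wired up here). As a standalone sanity check of the formula, the sketch below compares the analytic rule dX = dY * Y * (1 - Y) against a central finite difference in NumPy; the helper name check_sigmoid_grad is illustrative only and is not part of the Paddle test harness.

import numpy as np


def check_sigmoid_grad(x, dy, eps=1e-4):
    # Analytic gradient, matching SigmoidGradKernel: dX = dY * Y * (1 - Y).
    y = 1.0 / (1.0 + np.exp(-x))
    dx_analytic = dy * y * (1.0 - y)
    # Central finite difference of sigmoid, scaled by the upstream gradient dY.
    y_plus = 1.0 / (1.0 + np.exp(-(x + eps)))
    y_minus = 1.0 / (1.0 + np.exp(-(x - eps)))
    dx_numeric = dy * (y_plus - y_minus) / (2.0 * eps)
    # Return the largest element-wise discrepancy.
    return np.max(np.abs(dx_analytic - dx_numeric))


if __name__ == "__main__":
    x = np.random.random((32, 100)).astype("float64")
    dy = np.random.random((32, 100)).astype("float64")
    print(check_sigmoid_grad(x, dy))  # expected to be tiny, roughly 1e-9 or below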