From 18d924173f2b85cc8defd88958bc448077caf1e5 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 2 Aug 2017 19:32:45 +0800 Subject: [PATCH] Add Gradient Operator for mean --- paddle/operators/mean_op.cc | 12 +++++++++++- paddle/operators/mean_op.cu | 1 + paddle/operators/mean_op.h | 17 +++++++++++++++++ paddle/operators/type_alias.h | 1 + 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index fe34d6ad401..78131b26808 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -33,13 +33,23 @@ public: MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op"); + AddOutput("Out", "The output of mean op").IgnoreGradient(); AddComment("Mean Operator"); } }; +class MeanGradOp : public OperatorWithKernel { +protected: + void InferShape(const InferShapeContext &ctx) const override { + ctx.Output("X" + GRAD_VAR_SUFFIX()) + ->Resize(ctx.Input("X")->dims()); + } +}; + } // namespace operators } // namespace paddle REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel); +REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp); +REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel); diff --git a/paddle/operators/mean_op.cu b/paddle/operators/mean_op.cu index 740157cbc57..e15de2fd0dd 100644 --- a/paddle/operators/mean_op.cu +++ b/paddle/operators/mean_op.cu @@ -3,3 +3,4 @@ #include "paddle/operators/mean_op.h" REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel); +REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel); \ No newline at end of file diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index 5f7d443751d..555b45b0705 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -32,5 +32,22 @@ public: } }; +template +class MeanGradKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + auto OG = context.Input("Out" + OperatorBase::GRAD_VAR_SUFFIX()); + PADDLE_ENFORCE(framework::product(OG->dims()) == 1, + "Mean Gradient should be scalar"); + auto IG = context.Output("X" + OperatorBase::GRAD_VAR_SUFFIX()); + IG->mutable_data(context.GetPlace()); + + T ig_size = (T)framework::product(IG->dims()); + + EigenVector::Flatten(*IG).device(*(context.GetEigenDevice())) = + EigenScalar::From(*OG) / ig_size; + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h index 93b62cddc81..9049ffda1da 100644 --- a/paddle/operators/type_alias.h +++ b/paddle/operators/type_alias.h @@ -51,6 +51,7 @@ using CPUPlace = platform::CPUPlace; using GPUPlace = platform::GPUPlace; using NetOp = framework::NetOp; using OpRegistry = framework::OpRegistry; +using OperatorBase = framework::OperatorBase; } // namespace operators } // namespace paddle -- GitLab