“5c18decf3c680f44121c77682f884be04791f30b”上不存在“develop/doc_cn/design/executor.html”
提交 3064ae33 编写于 作者: G gangliao 提交者: GitHub

Merge pull request #3180 from reyoung/feature/MeanOpGrad

Add Gradient Operator for mean
...@@ -33,13 +33,23 @@ public: ...@@ -33,13 +33,23 @@ public:
MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker) MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op"); AddOutput("Out", "The output of mean op").IgnoreGradient();
AddComment("Mean Operator"); AddComment("Mean Operator");
} }
}; };
class MeanGradOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + GRAD_VAR_SUFFIX())
->Resize(ctx.Input<Tensor>("X")->dims());
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>);
REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::CPUPlace, float>);
...@@ -3,3 +3,4 @@ ...@@ -3,3 +3,4 @@
#include "paddle/operators/mean_op.h" #include "paddle/operators/mean_op.h"
REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::GPUPlace, float>);
\ No newline at end of file
...@@ -35,5 +35,22 @@ public: ...@@ -35,5 +35,22 @@ public:
} }
}; };
template <typename Place, typename T>
class MeanGradKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + OperatorBase::GRAD_VAR_SUFFIX());
PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
"Mean Gradient should be scalar");
auto IG = context.Output<Tensor>("X" + OperatorBase::GRAD_VAR_SUFFIX());
IG->mutable_data<T>(context.GetPlace());
T ig_size = (T)framework::product(IG->dims());
EigenVector<T>::Flatten(*IG).device(*(context.GetEigenDevice<Place>())) =
EigenScalar<T>::From(*OG) / ig_size;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -51,6 +51,7 @@ using CPUPlace = platform::CPUPlace; ...@@ -51,6 +51,7 @@ using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace; using GPUPlace = platform::GPUPlace;
using NetOp = framework::NetOp; using NetOp = framework::NetOp;
using OpRegistry = framework::OpRegistry; using OpRegistry = framework::OpRegistry;
using OperatorBase = framework::OperatorBase;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册