From 0760aaf4401b2e87684a9ae8e7931cf9e51a74b8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 21 Mar 2018 19:20:49 +0800 Subject: [PATCH] Shrink batch_norm_grad's inputs --- paddle/fluid/operators/batch_norm_op.cc | 31 +++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 5d27f5b60c..36049ee6a4 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -457,12 +457,39 @@ class BatchNormGradKernel } }; +class BatchNormGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *op = new framework::OpDesc(); + op->SetType("batch_norm_grad"); + op->SetInput("X", Input("X")); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + + op->SetInput("Scale", Input("Scale")); + op->SetInput("SavedMean", Output("SavedMean")); + op->SetInput("SavedVariance", Output("SavedVariance")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale")); + op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias")); + + return std::unique_ptr(op); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, - batch_norm_grad, ops::BatchNormGradOp); +REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, + ops::BatchNormGradMaker); +REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp); + REGISTER_OP_CPU_KERNEL( batch_norm, ops::BatchNormKernel); -- GitLab