From 1c25c88aba21d7a65e0bb4503dfe0a2ea12cf84f Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 10 Sep 2019 12:30:07 +0800 Subject: [PATCH] refine memory usage of some operators, test=develop (#19700) --- paddle/fluid/operators/expand_op.cc | 10 +++++++--- paddle/fluid/operators/expand_op.h | 5 +++-- paddle/fluid/operators/roi_align_op.cc | 7 +++++-- .../sigmoid_cross_entropy_with_logits_op.cc | 12 ++++++++++-- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index e15f848c23d..2ca27c0cbf5 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -163,8 +163,9 @@ class ExpandGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Out"))->type(), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -195,13 +196,16 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker { } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X"); + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker, ops::ExpandGradOpDescMaker); -REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp); +REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp, + ops::ExpandGradNoNeedBufVarsInferer); REGISTER_OP_CPU_KERNEL( expand, ops::ExpandKernel, ops::ExpandKernel, diff --git a/paddle/fluid/operators/expand_op.h b/paddle/fluid/operators/expand_op.h index 8153987d6c7..32e6332823e 100644 --- a/paddle/fluid/operators/expand_op.h +++ b/paddle/fluid/operators/expand_op.h @@ -186,7 +186,6 @@ class ExpandGradKernel : public framework::OpKernel { "reduce dimensions."); auto* in0 = context.Input(framework::GradVarName("Out")); auto* out0 = context.Output(framework::GradVarName("X")); - auto x = EigenVector::Flatten(*(context.Input("X"))); out0->mutable_data(context.GetPlace()); auto x_grad = EigenVector::Flatten(*out0); Eigen::DSizes reshape_dims; @@ -200,7 +199,9 @@ class ExpandGradKernel : public framework::OpKernel { auto out_grad = EigenVector::Flatten(*in0); x_grad.device( *context.template device_context().eigen_device()) = - out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions()); + out_grad.reshape(reshape_dims) + .sum(reduce_dims) + .reshape(x_grad.dimensions()); } }; diff --git a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc index d0dd861af7b..21c3dd27f02 100644 --- a/paddle/fluid/operators/roi_align_op.cc +++ b/paddle/fluid/operators/roi_align_op.cc @@ -85,7 +85,7 @@ class ROIAlignGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), + return framework::OpKernelType(ctx.Input("ROIs")->type(), ctx.device_context()); } }; @@ -167,13 +167,16 @@ class ROIAlignGradDescMaker : public framework::SingleGradOpDescMaker { } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(RoiAlignGradNoNeedBufVarsInferer, "X"); + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker, ops::ROIAlignGradDescMaker); -REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp); +REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp, + ops::RoiAlignGradNoNeedBufVarsInferer); REGISTER_OP_CPU_KERNEL( roi_align, ops::CPUROIAlignOpKernel, diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc index 1c2726454f3..c453b03dddf 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc @@ -168,6 +168,12 @@ class SigmoidCrossEntropyWithLogitsGradOpDescMaker } }; +DECLARE_INPLACE_OP_INFERER(SigmoidCrossEntropyWithLogitsInplaceInferer, + {"X", "Out"}); +DECLARE_INPLACE_OP_INFERER(SigmoidCrossEntropyWithLogitsGradInplaceInferer, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); + } // namespace operators } // namespace paddle @@ -175,9 +181,11 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits, ops::SigmoidCrossEntropyWithLogitsOp, ops::SigmoidCrossEntropyWithLogitsOpMaker, - ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker); + ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker, + ops::SigmoidCrossEntropyWithLogitsInplaceInferer); REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad, - ops::SigmoidCrossEntropyWithLogitsGradOp); + ops::SigmoidCrossEntropyWithLogitsGradOp, + ops::SigmoidCrossEntropyWithLogitsGradInplaceInferer); REGISTER_OP_CPU_KERNEL( sigmoid_cross_entropy_with_logits, ops::SigmoidCrossEntropyWithLogitsKernel