diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc
index e15f848c23df7ca25dd15b9595b18b62cb7c2790..2ca27c0cbf58fd273a290881a0928ffc0f1c3ca1 100644
--- a/paddle/fluid/operators/expand_op.cc
+++ b/paddle/fluid/operators/expand_op.cc
@@ -163,8 +163,9 @@ class ExpandGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        ctx.Input(framework::GradVarName("Out"))->type(),
+        ctx.device_context());
   }
 
   framework::OpKernelType GetKernelTypeForVar(
@@ -195,13 +196,16 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker,
                   ops::ExpandGradOpDescMaker);
-REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp);
+REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp,
+                  ops::ExpandGradNoNeedBufVarsInferer);
 REGISTER_OP_CPU_KERNEL(
     expand, ops::ExpandKernel,
     ops::ExpandKernel,
diff --git a/paddle/fluid/operators/expand_op.h b/paddle/fluid/operators/expand_op.h
index 8153987d6c721c39544ff02a8adc925e3f01fd14..32e6332823e60fdf17ab6d5186b8b9cd22ade38a 100644
--- a/paddle/fluid/operators/expand_op.h
+++ b/paddle/fluid/operators/expand_op.h
@@ -186,7 +186,6 @@ class ExpandGradKernel : public framework::OpKernel {
                         "reduce dimensions.");
     auto* in0 = context.Input(framework::GradVarName("Out"));
     auto* out0 = context.Output(framework::GradVarName("X"));
-    auto x = EigenVector::Flatten(*(context.Input("X")));
     out0->mutable_data(context.GetPlace());
     auto x_grad = EigenVector::Flatten(*out0);
     Eigen::DSizes reshape_dims;
@@ -200,7 +199,9 @@ class ExpandGradKernel : public framework::OpKernel {
     auto out_grad = EigenVector::Flatten(*in0);
     x_grad.device(
         *context.template device_context().eigen_device()) =
-        out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
+        out_grad.reshape(reshape_dims)
+            .sum(reduce_dims)
+            .reshape(x_grad.dimensions());
   }
 };
 
diff --git a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc
index d0dd861af7be80ede75b9d14867087ec687fc1da..21c3dd27f02b18fa78108f6a291dbf4c12724786 100644
--- a/paddle/fluid/operators/roi_align_op.cc
+++ b/paddle/fluid/operators/roi_align_op.cc
@@ -85,7 +85,7 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
+    return framework::OpKernelType(ctx.Input("ROIs")->type(),
                                    ctx.device_context());
   }
 };
@@ -167,13 +167,16 @@ class ROIAlignGradDescMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(RoiAlignGradNoNeedBufVarsInferer, "X");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker,
                   ops::ROIAlignGradDescMaker);
-REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp);
+REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp,
+                  ops::RoiAlignGradNoNeedBufVarsInferer);
 REGISTER_OP_CPU_KERNEL(
     roi_align,
     ops::CPUROIAlignOpKernel,
diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
index 1c2726454f3d1fb8545e5d3260e59fcafbcb2aee..c453b03dddf68a7f4638aa0eceaa2aa70dc3d5f4 100644
--- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
+++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc
@@ -168,6 +168,12 @@ class SigmoidCrossEntropyWithLogitsGradOpDescMaker
   }
 };
 
+DECLARE_INPLACE_OP_INFERER(SigmoidCrossEntropyWithLogitsInplaceInferer,
+                           {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(SigmoidCrossEntropyWithLogitsGradInplaceInferer,
+                           {framework::GradVarName("Out"),
+                            framework::GradVarName("X")});
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -175,9 +181,11 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits,
                   ops::SigmoidCrossEntropyWithLogitsOp,
                   ops::SigmoidCrossEntropyWithLogitsOpMaker,
-                  ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker);
+                  ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker,
+                  ops::SigmoidCrossEntropyWithLogitsInplaceInferer);
 REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad,
-                  ops::SigmoidCrossEntropyWithLogitsGradOp);
+                  ops::SigmoidCrossEntropyWithLogitsGradOp,
+                  ops::SigmoidCrossEntropyWithLogitsGradInplaceInferer);
 REGISTER_OP_CPU_KERNEL(
     sigmoid_cross_entropy_with_logits,
     ops::SigmoidCrossEntropyWithLogitsKernel
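
The expand_grad changes above work because dX always has the same shape as X, so the backward reduction can be written purely in terms of the gradients: dOut is reshaped so every expanded axis becomes a (copies, original_size) pair, summed over the copy axes, and reshaped back. X is only consulted for its shape, which dX already carries; that is why "X" can be declared a no-need-buffer variable and the kernel dtype is now taken from GradVarName("Out"). Below is a minimal standalone Eigen sketch of that reshape-sum-reshape reduction; it is illustrative only, and the shapes, expand_times, and variable names are made up rather than taken from the Paddle sources.

// Standalone sketch, not Paddle code: the kind of reduction expand_grad performs.
// Assume the forward pass expanded X of shape [2, 3] with expand_times = [2, 1]
// into Out of shape [4, 3]. The backward views dOut as
// [copies_0, x_0, copies_1, x_1], sums over the copy axes, and reshapes to dX's
// shape. dX and X have the same shape, so x_grad.dimensions() suffices and the
// buffer of X is never read.
#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

int main() {
  Eigen::Tensor<float, 2, Eigen::RowMajor> dout(4, 3);
  dout.setConstant(1.0f);  // stand-in for the upstream gradient

  Eigen::DSizes<Eigen::Index, 4> reshape_dims(2, 2, 1, 3);  // [copies_0, x_0, copies_1, x_1]
  Eigen::DSizes<Eigen::Index, 2> reduce_dims(0, 2);         // sum over the copy axes
  Eigen::DSizes<Eigen::Index, 2> x_shape(2, 3);             // same shape as dX (and X)

  Eigen::Tensor<float, 2, Eigen::RowMajor> dx =
      dout.reshape(reshape_dims).sum(reduce_dims).reshape(x_shape);

  std::cout << dx << "\n";  // every entry is 2: each element of X was copied twice
  return 0;
}

Because the buffer of X is no longer needed in the backward pass, the no-need-buffer declaration lets the memory-optimization passes free or reuse X's forward tensor before expand_grad runs; the inplace inferers added for sigmoid_cross_entropy_with_logits serve the analogous purpose of letting Out reuse X's buffer and dX reuse dOut's, which is presumably safe because both the forward and backward computations are elementwise.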