From 5c8f210ce3e7e654b03478d3d94ef45594420904 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 29 Aug 2019 09:30:16 +0800 Subject: [PATCH] refine inplace inference registry, test=develop (#19032) --- paddle/fluid/framework/inplace_op_inference.h | 10 +++++++ paddle/fluid/operators/activation_op.cc | 10 +++---- paddle/fluid/operators/batch_norm_op.cc | 27 ++----------------- .../operators/elementwise/elementwise_op.h | 19 +++---------- paddle/fluid/operators/flatten_op.cc | 19 +++---------- paddle/fluid/operators/group_norm_op.cc | 19 +++---------- paddle/fluid/operators/reshape_op.cc | 19 +++---------- .../softmax_with_cross_entropy_op.cc | 20 +++----------- paddle/fluid/operators/sum_op.cc | 8 +----- 9 files changed, 36 insertions(+), 115 deletions(-) diff --git a/paddle/fluid/framework/inplace_op_inference.h b/paddle/fluid/framework/inplace_op_inference.h index 95fd5b046a5..40026eaca9a 100644 --- a/paddle/fluid/framework/inplace_op_inference.h +++ b/paddle/fluid/framework/inplace_op_inference.h @@ -53,5 +53,15 @@ class SingleOpInplaceInToOut : public InplaceOpInference { } }; +#define DECLARE_INPLACE_OP_INFERER(class_name, ...) \ + class class_name final : public ::paddle::framework::InplaceOpInference { \ + public: \ + std::unordered_map operator()( \ + const ::paddle::framework::OpDesc& op_desc, \ + bool use_cuda) const final { \ + return {__VA_ARGS__}; \ + } \ + } + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 75e7e240eb4..531e89a5efd 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -774,13 +774,9 @@ class SquareDoubleGradMaker } }; -class ActivationGradOpInplaceInference : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc& op_desc, bool use_cuda) const override { - return {{framework::GradVarName("Out"), framework::GradVarName("X")}}; - } -}; +DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 56111c66339..bb76904bff0 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -598,36 +598,13 @@ std::unique_ptr BatchNormGradMaker::Apply() const { return std::unique_ptr(op); } -class BatchNormInplaceInToOut : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc &op_desc, bool use_cuda) const override { - return {{"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"}}; - } -}; - -class BatchNormGradInplaceInToOut : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc &op_desc, bool use_cuda) const override { - // Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C] - return { - {framework::GradVarName("Y"), framework::GradVarName("X")}, - {"SavedMean", framework::GradVarName("Scale")}, - {"SavedVariance", framework::GradVarName("Bias")}, - }; - } -}; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, - ops::BatchNormOpInferVarType, ops::BatchNormGradMaker) -// ops::BatchNormInplaceInToOut); -REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp) -// ops::BatchNormGradInplaceInToOut); + ops::BatchNormOpInferVarType, ops::BatchNormGradMaker); +REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp); REGISTER_OP_CPU_KERNEL( batch_norm, ops::BatchNormKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index c251cc72270..aa6375d3000 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -317,21 +317,10 @@ class ElemwiseGradKernel : public framework::OpKernel { } }; -class ElementwiseOpInplace : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc &op_desc, bool use_cuda) const override { - return {{"X", "Out"}}; - } -}; - -class ElementwiseGradOpInplace : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc &op_desc, bool use_cuda) const override { - return {{framework::GradVarName("Out"), framework::GradVarName("X")}}; - } -}; +DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplace, {"X", "Out"}); +DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "Y"); diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 3111cace66c..350d40ce832 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -268,21 +268,10 @@ class Flatten2GradOp : public framework::OperatorBase { } }; -class FlattenOpInplaceInToOut : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc &op_desc, bool use_cuda) const override { - return {{"X", "Out"}}; - } -}; - -class FlattenGradInplaceinToOut : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc &op_desc, bool use_cuda) const override { - return {{framework::GradVarName("Out"), framework::GradVarName("X")}}; - } -}; +DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInToOut, {"X", "Out"}); +DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceinToOut, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc index 2b1e8038fc4..92772f2bc39 100644 --- a/paddle/fluid/operators/group_norm_op.cc +++ b/paddle/fluid/operators/group_norm_op.cc @@ -170,21 +170,10 @@ class GroupNormGradMaker : public framework::SingleGradOpDescMaker { } }; -class GroupNormInplaceInToOut : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc &op_desc, bool use_cuda) const override { - return {{"X", "Y"}}; - } -}; - -class GroupNormGradInplaceInToOut : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc &op_desc, bool use_cuda) const override { - return {{framework::GradVarName("Y"), framework::GradVarName("X")}}; - } -}; +DECLARE_INPLACE_OP_INFERER(GroupNormInplaceInToOut, {"X", "Y"}); +DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut, + {framework::GradVarName("Y"), + framework::GradVarName("X")}); class GroupNormOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 9750bc87b00..6341fa935ec 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -393,21 +393,10 @@ class Reshape2GradOp : public framework::OperatorWithKernel { } }; -class ReshapeOpInplaceInToOut : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc &op_desc, bool use_cuda) const override { - return {{"X", "Out"}}; - } -}; - -class ReshapeGradInplaceInToOut : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc &op_desc, bool use_cuda) const override { - return {{framework::GradVarName("Out"), framework::GradVarName("X")}}; - } -}; +DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"}); +DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index 716826bf156..8cde72921cb 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -255,23 +255,11 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker { } }; -class SoftmaxWithCrossEntropyInplaceInference - : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc& op_desc, bool use_cuda) const { - return {{"Logits", "Softmax"}}; - } -}; +DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInference, + {"Logits", "Softmax"}); -class SoftmaxWithCrossEntropyGradInplaceInference - : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc& op_desc, bool use_cuda) const { - return {{"Softmax", framework::GradVarName("Logits")}}; - } -}; +DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInference, + {"Softmax", framework::GradVarName("Logits")}); } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 3c64ebe9950..37204fd72ae 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -238,13 +238,7 @@ class SumGradMaker : public framework::GradOpDescMakerBase { } }; -class SumInplace : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc& op_desc, bool use_cuda) const override { - return {{"X", "Out"}}; - } -}; +DECLARE_INPLACE_OP_INFERER(SumInplace, {"X", "Out"}); } // namespace operators } // namespace paddle -- GitLab