From 8dcae0c55d9d352f6ac62573165baa1348de6171 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Mon, 11 Jan 2021 16:22:45 +0800 Subject: [PATCH] register OPMaker and Infer Shape Check for fused_elementwise_add (#30259) --- .../fused/fused_elemwise_activation_op.cc | 66 +++++++++++++++++-- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index cde0912eb2..4ff66d0d2b 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -287,6 +287,15 @@ class FusedElemwiseActivationGradMaker } }; +class FusedElemwiseAddActivationMaker : public FusedElemwiseActivationMaker {}; + +template +class FusedElemwiseAddActivationGradMaker + : public FusedElemwiseActivationGradMaker { + public: + using FusedElemwiseActivationGradMaker::FusedElemwiseActivationGradMaker; +}; + class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -367,6 +376,53 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { } }; +class FusedElemwiseAddActivationOp : public FusedElemwiseActivationOp { + public: + using FusedElemwiseActivationOp::FusedElemwiseActivationOp; + void InferShape(framework::InferShapeContext *ctx) const override { + FusedElemwiseActivationOp::InferShape(ctx); + std::vector functor_names = + ctx->Attrs().Get>("functor_list"); + bool elemntwise_add_detected = false; + for (auto names : functor_names) { + if (names == "elementwise_add") { + elemntwise_add_detected = true; + break; + } + } + PADDLE_ENFORCE_EQ( + elemntwise_add_detected, true, + platform::errors::InvalidArgument( + "When the FusedElemwiseAddActivationOp Is used in fused pass, the " + "elementwise_add Op must be" + "detected and used, Please check the fuse pass pattern")); + } +}; + +class FusedElemwiseAddActivationOpGrad : public FusedElemwiseActivationOpGrad { + public: + using FusedElemwiseActivationOpGrad::FusedElemwiseActivationOpGrad; + + void InferShape(framework::InferShapeContext *ctx) const override { + FusedElemwiseActivationOpGrad::InferShape(ctx); + std::vector functor_names = + ctx->Attrs().Get>("functor_list"); + bool elemntwise_add_grad_detected = false; + for (auto names : functor_names) { + if (names == "elementwise_add_grad") { + elemntwise_add_grad_detected = true; + break; + } + } + PADDLE_ENFORCE_EQ( + elemntwise_add_grad_detected, true, + platform::errors::InvalidArgument( + "When the FusedElemwiseAddActivationOpGrad Is used in fused pass, " + "the elementwise_add_grad Op must be" + "detected and used, Please check the fuse pass pattern")); + } +}; + DECLARE_NO_NEED_BUFFER_VARS_INFERER( FusedElemwiseAddActivationNoNeddBufVarInferer, "X", "Y"); } // namespace operators @@ -397,13 +453,13 @@ REGISTER_OP_CPU_KERNEL( // for memory optimization, we register the fused_elemwise_add_activation OP REGISTER_OPERATOR( - fused_elemwise_add_activation, ops::FusedElemwiseActivationOp, - ops::FusedElemwiseActivationMaker, - ops::FusedElemwiseActivationGradMaker, - ops::FusedElemwiseActivationGradMaker); + fused_elemwise_add_activation, ops::FusedElemwiseAddActivationOp, + ops::FusedElemwiseAddActivationMaker, + ops::FusedElemwiseAddActivationGradMaker, + ops::FusedElemwiseAddActivationGradMaker); REGISTER_OPERATOR(fused_elemwise_add_activation_grad, ops::FusedElemwiseAddActivationNoNeddBufVarInferer, - ops::FusedElemwiseActivationOpGrad); + ops::FusedElemwiseAddActivationOpGrad); REGISTER_OP_CPU_KERNEL( fused_elemwise_add_activation, -- GitLab