From 11a1284c4ba78f862f00a7d9ab77dda22c78caff Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Thu, 25 Jul 2019 14:21:15 +0800 Subject: [PATCH] Refine Infershape in activation_op for double_grad (#18731) --- paddle/fluid/operators/activation_op.cc | 43 +++++++++++++++++++------ 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 57e3b7fcad..acd100a8a6 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -604,25 +604,48 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) { - // some op has no output DX, check HasOutputs("DX") here - if (HasOutputs("DX") && ctx->HasOutput("DX")) { + if (ctx->HasOutput("DX")) { ctx->ShareDim("X", "DX"); ctx->ShareLoD("X", "DX"); } - // some op has no output DDout, check HasOutputs("DDout") here - if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) { + if (ctx->HasOutput("DDOut")) { ctx->ShareDim("X", "DDOut"); ctx->ShareLoD("X", "DDOut"); } } if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) { - // some op has no output DOut, check HasOutputs("DOut") here - if (HasOutputs("DOut") && ctx->HasOutput("DOut")) { + if (ctx->HasOutput("DOut")) { ctx->ShareDim("Out", "DOut"); ctx->ShareLoD("Out", "DOut"); } - // some op has no output DDOut, check HasOutputs("DDOut") here - if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) { + if (ctx->HasOutput("DDOut")) { + ctx->ShareDim("Out", "DDOut"); + ctx->ShareLoD("Out", "DDOut"); + } + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return GetKernelType(ctx, *this, "DDX"); + } +}; + +template <ActBwdOpFwdDeps kDepValue> +class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + 
void InferShape(framework::InferShapeContext* ctx) const override { if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) { if (ctx->HasOutput("DDOut")) { ctx->ShareDim("X", "DDOut"); ctx->ShareLoD("X", "DDOut"); } } if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) { if (ctx->HasOutput("DDOut")) { ctx->ShareDim("Out", "DDOut"); ctx->ShareLoD("Out", "DDOut"); } } } @@ -775,7 +798,7 @@ REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad, ops::ReluDoubleGradMaker); REGISTER_OPERATOR( relu_grad_grad, - ops::ActivationOpDoubleGrad<ops::ReluGradGradFunctor<float>::FwdDeps()>); + ops::ActivationOpDoubleGrad2<ops::ReluGradGradFunctor<float>::FwdDeps()>); REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor); @@ -800,7 +823,7 @@ REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad, ops::LeakyReluDoubleGradMaker); REGISTER_OPERATOR( leaky_relu_grad_grad, - ops::ActivationOpDoubleGrad<ops::LeakyReluGradGradFunctor<float>::FwdDeps()>); + ops::ActivationOpDoubleGrad2<ops::LeakyReluGradGradFunctor<float>::FwdDeps()>); REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor, LeakyReluGradFunctor); -- GitLab