From 7ac4818a9804b0198e118d940920b03cbeb20a0e Mon Sep 17 00:00:00 2001
From: qingqing01 <dangqingqing@baidu.com>
Date: Thu, 4 Jul 2019 10:37:09 +0800
Subject: [PATCH] Refine Infershape in activation_op for double_grad. (#18485)

* Refine Infershape in activation_op for double_grad.
---
 paddle/fluid/operators/activation_op.cc | 39 +++++++++++++++++++++----
 1 file changed, 33 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index abf56b7a16..acd100a8a6 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -604,21 +604,48 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
-      if (HasOutputs("DX") && ctx->HasOutput("DX")) {
+      if (ctx->HasOutput("DX")) {
         ctx->ShareDim("X", "DX");
         ctx->ShareLoD("X", "DX");
       }
-      if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
+      if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("X", "DDOut");
         ctx->ShareLoD("X", "DDOut");
       }
     }
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
-      if (HasOutputs("DOut") && ctx->HasOutput("DOut")) {
+      if (ctx->HasOutput("DOut")) {
         ctx->ShareDim("Out", "DOut");
         ctx->ShareLoD("Out", "DOut");
       }
-      if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("Out", "DDOut");
+        ctx->ShareLoD("Out", "DDOut");
+      }
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this, "DDX");
+  }
+};
+
+template <ActBwdOpFwdDeps kDepValue>
+class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("X", "DDOut");
+        ctx->ShareLoD("X", "DDOut");
+      }
+    }
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+      if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("Out", "DDOut");
         ctx->ShareLoD("Out", "DDOut");
       }
@@ -771,7 +798,7 @@ REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                   ops::ReluDoubleGradMaker);
 REGISTER_OPERATOR(
     relu_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::ReluGradGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::ReluGradGradFunctor<float>::FwdDeps()>);
 
 REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
 
@@ -796,7 +823,7 @@ REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                   ops::LeakyReluDoubleGradMaker);
 REGISTER_OPERATOR(
     leaky_relu_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::LeakyReluGradGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradGradFunctor<float>::FwdDeps()>);
 
 REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                                LeakyReluGradFunctor);
-- 
GitLab