diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 57e3b7fcad4d8e0ca07abaacca2563db2974ac45..acd100a8a69779dab3b452cd7e2b1e4ff8765591 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -604,25 +604,48 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
-      // some op has no output DX, check HasOutputs("DX") here
-      if (HasOutputs("DX") && ctx->HasOutput("DX")) {
+      if (ctx->HasOutput("DX")) {
         ctx->ShareDim("X", "DX");
         ctx->ShareLoD("X", "DX");
       }
-      // some op has no output DDout, check HasOutputs("DDout") here
-      if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
+      if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("X", "DDOut");
         ctx->ShareLoD("X", "DDOut");
       }
     }
     if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
-      // some op has no output DOut, check HasOutputs("DOut") here
-      if (HasOutputs("DOut") && ctx->HasOutput("DOut")) {
+      if (ctx->HasOutput("DOut")) {
         ctx->ShareDim("Out", "DOut");
         ctx->ShareLoD("Out", "DOut");
       }
-      // some op has no output DDOut, check HasOutputs("DDOut") here
-      if (HasOutputs("DDOut") && ctx->HasOutput("DDOut")) {
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("Out", "DDOut");
+        ctx->ShareLoD("Out", "DDOut");
+      }
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this, "DDX");
+  }
+};
+
+template <ActBwdOpFwdDeps kDepValue>
+class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("X", "DDOut");
+        ctx->ShareLoD("X", "DDOut");
+      }
+    }
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+      if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("Out", "DDOut");
         ctx->ShareLoD("Out", "DDOut");
       }
     }
@@ -775,7 +798,7 @@ REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                   ops::ReluDoubleGradMaker);
 REGISTER_OPERATOR(
     relu_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::ReluGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);
 
 REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
@@ -800,7 +823,7 @@ REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                   ops::LeakyReluDoubleGradMaker);
 REGISTER_OPERATOR(
     leaky_relu_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
 
 REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                                LeakyReluGradFunctor);
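
Summary of the change: the doubled guard `HasOutputs(...) && ctx->HasOutput(...)` in `ActivationOpDoubleGrad::InferShape` is collapsed to the single context query, and ops whose double-grad emits only `DDOut` and never `DX`/`DOut` (here `relu_grad_grad` and `leaky_relu_grad_grad`) are rerouted to a slimmer `ActivationOpDoubleGrad2` that skips those branches entirely. Below is a minimal, self-contained sketch of the dependency-bitmask dispatch the new class uses; `ToyInferShapeContext`, `DoubleGrad2InferShape`, and `main` are illustrative stand-ins, not Paddle's real `framework::InferShapeContext` API.

```cpp
// Toy model (NOT Paddle's framework classes) of the ActivationOpDoubleGrad2
// pattern: the functor's FwdDeps() bitmask says whether the double-grad
// kernel reads X or Out, and InferShape only propagates shape to outputs
// the op actually declares.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

enum ActBwdOpFwdDeps { kDepX = 0x01, kDepOut = 0x02 };  // mirrors Paddle's enum

// Hypothetical stand-in for framework::InferShapeContext.
struct ToyInferShapeContext {
  std::map<std::string, std::vector<int64_t>> dims;  // declared tensors only
  bool HasOutput(const std::string& name) const { return dims.count(name) > 0; }
  void ShareDim(const std::string& src, const std::string& dst) {
    dims[dst] = dims.at(src);  // copy the source tensor's dims to the output
  }
};

template <ActBwdOpFwdDeps kDepValue>
void DoubleGrad2InferShape(ToyInferShapeContext* ctx) {
  // Mirrors the patched logic: only DDOut exists for relu/leaky_relu
  // double-grad, so no DX/DOut branches are needed.
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
    if (ctx->HasOutput("DDOut")) ctx->ShareDim("X", "DDOut");
  }
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
    if (ctx->HasOutput("DDOut")) ctx->ShareDim("Out", "DDOut");
  }
}

int main() {
  ToyInferShapeContext ctx;
  ctx.dims["Out"] = {8, 16};  // forward output the grad-grad depends on
  ctx.dims["DDOut"] = {};     // declared output, shape not yet inferred
  DoubleGrad2InferShape<kDepOut>(&ctx);
  std::cout << "DDOut dims: " << ctx.dims["DDOut"][0] << " x "
            << ctx.dims["DDOut"][1] << "\n";  // prints: DDOut dims: 8 x 16
}
```

The enum-valued template parameter is the design point worth noting: because each grad functor exposes its dependency as a constexpr `FwdDeps()`, one class body can serve both X-dependent and Out-dependent activations, with the registration line (e.g. `ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>`) selecting which branch is live at compile time.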