diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index b84833453c906d8ec6fc0cba8501a37b1cc328ce..80059ff14ca4d475b1a2c625ef1dcfe8912e6947 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -244,7 +244,8 @@ class MulDoubleGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput("DOut"), "Input(DOut) should not be null"); - if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) { + if (ctx->HasOutput("DDOut") && + (ctx->HasInput("DDX") || (ctx->HasInput("DDY")))) { ctx->ShareDim("DOut", "DDOut"); } if (ctx->HasOutput("DX") && ctx->HasInput("DDY")) { @@ -275,9 +276,9 @@ class MulDoubleGradMaker : public framework::SingleGradOpDescMaker { auto ddw = OutputGrad(framework::GradVarName("Y")); std::vector empty_str = {}; - retv->SetOutput("DDOut", (ddx.empty()) - ? empty_str - : InputGrad(framework::GradVarName("Out"))); + if (!ddx.empty() || !ddw.empty()) { + retv->SetOutput("DDOut", InputGrad(framework::GradVarName("Out"))); + } retv->SetOutput("DX", ddw.empty() ? empty_str : InputGrad("X")); retv->SetOutput("DY", ddx.empty() ? empty_str : InputGrad("Y"));