From 647ff784e2fee81dab4fd7c0c7f94a820c1e7e6f Mon Sep 17 00:00:00 2001 From: lvmengsi Date: Fri, 27 Sep 2019 13:53:47 +0800 Subject: [PATCH] fix mul double grad (#20040) --- paddle/fluid/operators/mul_op.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index b84833453c9..80059ff14ca 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -244,7 +244,8 @@ class MulDoubleGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput("DOut"), "Input(DOut) should not be null"); - if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) { + if (ctx->HasOutput("DDOut") && + (ctx->HasInput("DDX") || (ctx->HasInput("DDY")))) { ctx->ShareDim("DOut", "DDOut"); } if (ctx->HasOutput("DX") && ctx->HasInput("DDY")) { @@ -275,9 +276,9 @@ class MulDoubleGradMaker : public framework::SingleGradOpDescMaker { auto ddw = OutputGrad(framework::GradVarName("Y")); std::vector empty_str = {}; - retv->SetOutput("DDOut", (ddx.empty()) - ? empty_str - : InputGrad(framework::GradVarName("Out"))); + if (!ddx.empty() || !ddw.empty()) { + retv->SetOutput("DDOut", InputGrad(framework::GradVarName("Out"))); + } retv->SetOutput("DX", ddw.empty() ? empty_str : InputGrad("X")); retv->SetOutput("DY", ddx.empty() ? empty_str : InputGrad("Y")); -- GitLab