未验证 提交 647ff784 编写于 作者: L lvmengsi 提交者: GitHub

fix mul double grad (#20040)

上级 8f0b3c05
...@@ -244,7 +244,8 @@ class MulDoubleGradOp : public framework::OperatorWithKernel { ...@@ -244,7 +244,8 @@ class MulDoubleGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput("DOut"), "Input(DOut) should not be null"); PADDLE_ENFORCE(ctx->HasInput("DOut"), "Input(DOut) should not be null");
if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) { if (ctx->HasOutput("DDOut") &&
(ctx->HasInput("DDX") || (ctx->HasInput("DDY")))) {
ctx->ShareDim("DOut", "DDOut"); ctx->ShareDim("DOut", "DDOut");
} }
if (ctx->HasOutput("DX") && ctx->HasInput("DDY")) { if (ctx->HasOutput("DX") && ctx->HasInput("DDY")) {
...@@ -275,9 +276,9 @@ class MulDoubleGradMaker : public framework::SingleGradOpDescMaker { ...@@ -275,9 +276,9 @@ class MulDoubleGradMaker : public framework::SingleGradOpDescMaker {
auto ddw = OutputGrad(framework::GradVarName("Y")); auto ddw = OutputGrad(framework::GradVarName("Y"));
std::vector<std::string> empty_str = {}; std::vector<std::string> empty_str = {};
retv->SetOutput("DDOut", (ddx.empty()) if (!ddx.empty() || !ddw.empty()) {
? empty_str retv->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
: InputGrad(framework::GradVarName("Out"))); }
retv->SetOutput("DX", ddw.empty() ? empty_str : InputGrad("X")); retv->SetOutput("DX", ddw.empty() ? empty_str : InputGrad("X"));
retv->SetOutput("DY", ddx.empty() ? empty_str : InputGrad("Y")); retv->SetOutput("DY", ddx.empty() ? empty_str : InputGrad("Y"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册