未验证 提交 647ff784 编写于 作者: L lvmengsi 提交者: GitHub

fix mul double grad (#20040)

上级 8f0b3c05
......@@ -244,7 +244,8 @@ class MulDoubleGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput("DOut"), "Input(DOut) should not be null");
if (ctx->HasOutput("DDOut") && ctx->HasInput("DDX")) {
if (ctx->HasOutput("DDOut") &&
(ctx->HasInput("DDX") || (ctx->HasInput("DDY")))) {
ctx->ShareDim("DOut", "DDOut");
}
if (ctx->HasOutput("DX") && ctx->HasInput("DDY")) {
......@@ -275,9 +276,9 @@ class MulDoubleGradMaker : public framework::SingleGradOpDescMaker {
auto ddw = OutputGrad(framework::GradVarName("Y"));
std::vector<std::string> empty_str = {};
retv->SetOutput("DDOut", (ddx.empty())
? empty_str
: InputGrad(framework::GradVarName("Out")));
if (!ddx.empty() || !ddw.empty()) {
retv->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
}
retv->SetOutput("DX", ddw.empty() ? empty_str : InputGrad("X"));
retv->SetOutput("DY", ddx.empty() ? empty_str : InputGrad("Y"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册