未验证 提交 d9bb55a1 编写于 作者: J Jiabin Yang 提交者: GitHub

Merge pull request #14756 from JiabinYang/fix_hs_op

fix bug in dist train on hs, test=develop
...@@ -150,13 +150,12 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { ...@@ -150,13 +150,12 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
"Output(W@Grad should not be null."); "Output(W@Grad should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@Grad should not be null."); "Output(X@Grad should not be null.");
if (!ctx->Attrs().Get<bool>("is_sparse")) {
if (ctx->HasOutput(framework::GradVarName("Bias"))) { if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias")); ctx->GetInputDim("Bias"));
}
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
} }
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册