提交 2055c16d 编写于 作者: H Hongyu Liu 提交者: phlrain

Merge pull request #16890 from colourful-tree/dev

fix teacher_student op infer
上级 ece74510
...@@ -34,12 +34,14 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { ...@@ -34,12 +34,14 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
"Input(Label)'s rank should be 2."); "Input(Label)'s rank should be 2.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
"The 1st dimension of Input(X) and Input(Label) should " "The 1st dimension of Input(X) and Input(Label) should "
"be equal."); "be equal.");
PADDLE_ENFORCE_EQ(label_dims[1], 1UL, PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
"The 2nd dimension of " "The 2nd dimension of "
"Input(Label) should be 1."); "Input(Label) should be 1.");
}
ctx->SetOutputDim("Y", {x_dims[0], 1}); ctx->SetOutputDim("Y", {x_dims[0], 1});
ctx->ShareLoD("X", /*->*/ "Y"); ctx->ShareLoD("X", /*->*/ "Y");
} }
...@@ -74,10 +76,12 @@ class TeacherStudentSigmoidLossGradientOp ...@@ -74,10 +76,12 @@ class TeacherStudentSigmoidLossGradientOp
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
"The 1st dimension of Input(X) and Input(Label) should " "The 1st dimension of Input(X) and Input(Label) should "
"be equal."); "be equal.");
PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], PADDLE_ENFORCE_EQ(
x_dims[0], dy_dims[0],
"The 1st dimension of Input(X) and Input(Y@Grad) should " "The 1st dimension of Input(X) and Input(Y@Grad) should "
"be equal."); "be equal.");
PADDLE_ENFORCE_EQ(dy_dims[1], 1, PADDLE_ENFORCE_EQ(dy_dims[1], 1,
...@@ -85,6 +89,7 @@ class TeacherStudentSigmoidLossGradientOp ...@@ -85,6 +89,7 @@ class TeacherStudentSigmoidLossGradientOp
PADDLE_ENFORCE_EQ(label_dims[1], 1, PADDLE_ENFORCE_EQ(label_dims[1], 1,
"When Attr(soft_label) == false, the 2nd dimension of " "When Attr(soft_label) == false, the 2nd dimension of "
"Input(Label) should be 1."); "Input(Label) should be 1.");
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("X", framework::GradVarName("X")); ctx->ShareLoD("X", framework::GradVarName("X"));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册