提交 1cca7114 编写于 作者: H heqiaozhi

fix infer

test=develop
上级 aab9ea6c
...@@ -37,12 +37,14 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { ...@@ -37,12 +37,14 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
"Input(Label)'s rank should be 2."); "Input(Label)'s rank should be 2.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
"The 1st dimension of Input(X) and Input(Label) should " "The 1st dimension of Input(X) and Input(Label) should "
"be equal."); "be equal.");
PADDLE_ENFORCE_EQ(label_dims[1], 1UL, PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
"The 2nd dimension of " "The 2nd dimension of "
"Input(Label) should be 1."); "Input(Label) should be 1.");
}
ctx->SetOutputDim("Y", {x_dims[0], 1}); ctx->SetOutputDim("Y", {x_dims[0], 1});
ctx->ShareLoD("X", /*->*/ "Y"); ctx->ShareLoD("X", /*->*/ "Y");
} }
...@@ -99,10 +101,12 @@ class TeacherStudentSigmoidLossGradientOp ...@@ -99,10 +101,12 @@ class TeacherStudentSigmoidLossGradientOp
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
"The 1st dimension of Input(X) and Input(Label) should " "The 1st dimension of Input(X) and Input(Label) should "
"be equal."); "be equal.");
PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], PADDLE_ENFORCE_EQ(
x_dims[0], dy_dims[0],
"The 1st dimension of Input(X) and Input(Y@Grad) should " "The 1st dimension of Input(X) and Input(Y@Grad) should "
"be equal."); "be equal.");
PADDLE_ENFORCE_EQ(dy_dims[1], 1, PADDLE_ENFORCE_EQ(dy_dims[1], 1,
...@@ -110,6 +114,7 @@ class TeacherStudentSigmoidLossGradientOp ...@@ -110,6 +114,7 @@ class TeacherStudentSigmoidLossGradientOp
PADDLE_ENFORCE_EQ(label_dims[1], 1, PADDLE_ENFORCE_EQ(label_dims[1], 1,
"When Attr(soft_label) == false, the 2nd dimension of " "When Attr(soft_label) == false, the 2nd dimension of "
"Input(Label) should be 1."); "Input(Label) should be 1.");
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("X", framework::GradVarName("X")); ctx->ShareLoD("X", framework::GradVarName("X"));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册