diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc index 6a4bea94376bb66fcabc1fa9872f9dc9b6febac2..7f95d16f09b5182e4da33763751ac87b53f41cf3 100644 --- a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc @@ -37,12 +37,14 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], - "The 1st dimension of Input(X) and Input(Label) should " - "be equal."); - PADDLE_ENFORCE_EQ(label_dims[1], 1UL, - "The 2nd dimension of " - "Input(Label) should be 1."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], + "The 1st dimension of Input(X) and Input(Label) should " + "be equal."); + PADDLE_ENFORCE_EQ(label_dims[1], 1UL, + "The 2nd dimension of " + "Input(Label) should be 1."); + } ctx->SetOutputDim("Y", {x_dims[0], 1}); ctx->ShareLoD("X", /*->*/ "Y"); } @@ -99,17 +101,20 @@ class TeacherStudentSigmoidLossGradientOp PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], - "The 1st dimension of Input(X) and Input(Label) should " - "be equal."); - PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], - "The 1st dimension of Input(X) and Input(Y@Grad) should " - "be equal."); - PADDLE_ENFORCE_EQ(dy_dims[1], 1, - "The 2nd dimension of Input(Y@Grad) should be 1."); - PADDLE_ENFORCE_EQ(label_dims[1], 1, - "When Attr(soft_label) == false, the 2nd dimension of " - "Input(Label) should be 1."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], + "The 1st dimension of Input(X) and Input(Label) should " + "be equal."); + PADDLE_ENFORCE_EQ( + x_dims[0], dy_dims[0], + "The 1st dimension of Input(X) and Input(Y@Grad) should " + "be equal."); + PADDLE_ENFORCE_EQ(dy_dims[1], 1, + "The 2nd dimension of Input(Y@Grad) should be 1."); + PADDLE_ENFORCE_EQ(label_dims[1], 1, + "When Attr(soft_label) == false, the 2nd dimension of " + "Input(Label) should be 1."); + } ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->ShareLoD("X", framework::GradVarName("X")); }