diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc index 930c295a9cb31238954efeb87ff5ac2d3ca7bdc6..51b5bcb38f9d60b1246f818de62275dba5b087f9 100644 --- a/paddle/operators/compare_op.cc +++ b/paddle/operators/compare_op.cc @@ -58,8 +58,8 @@ class CompareOpInferShape : public framework::InferShapeBase { comment.type); auto dim_x = context->GetInputDim("X"); auto dim_y = context->GetInputDim("Y"); - PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y), - "The number of elements in X and Y should be same"); + PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(), + "The size of dim_y should not be greater than dim_x's."); context->SetOutputDim("Out", context->GetInputDim("X")); context->ShareLoD("X", "Out");