diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc
index 694b00e493149f332e94d2353d3c13501a59ebd0..dc30644a5e7e33d4289e48cac093aa5fde7e75e7 100644
--- a/paddle/operators/squared_l2_distance_op.cc
+++ b/paddle/operators/squared_l2_distance_op.cc
@@ -40,7 +40,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
                       "inputs must be same.");
 
     int rank = framework::arity(x_dims);
-    PADDLE_ENFORCE(rank >= 2, "Tensor rank should be at least equal to 2.");
+    PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2.");
     PADDLE_ENFORCE_EQ(framework::product(x_dims) / x_dims[0],
                       framework::product(y_dims) / y_dims[0],
                       "Product of dimensions expcet the first dimension of "
@@ -87,7 +87,6 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext& ctx) const override {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Gradient of Out should not be null");
-    // check out grad dimensions
     auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims();
     auto x_dims = ctx.Input("X")->dims();
     auto y_dims = ctx.Input("Y")->dims();
diff --git a/paddle/operators/squared_l2_distance_op.h b/paddle/operators/squared_l2_distance_op.h
index 1015513bdf51525b9ff47c90625ff32ec4495b41..77c5a0a5c91c93fa9cd8760873f88af657475b2a 100644
--- a/paddle/operators/squared_l2_distance_op.h
+++ b/paddle/operators/squared_l2_distance_op.h
@@ -101,9 +101,9 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
     auto y_grad =
         EigenMatrix::From(*y_g, framework::make_ddim({y_dims[0], cols}));
 
-    PADDLE_ENFORCE(sub_result.dimensions()[0] >= y_dims[0],
-                   "First dimension of gradient must be greater or "
-                   "equal than first dimension of target");
+    PADDLE_ENFORCE_GE(sub_result.dimensions()[0], y_dims[0],
+                      "First dimension of gradient must be greater or "
+                      "equal than first dimension of target.");
 
     if (sub_result.dimensions()[0] == y_dims[0]) {
       y_grad.device(eigen_place) = -1 * grad_mat;