提交 a377b419 编写于 作者: Y yangyaming

Follow GLOG enforcing style.

上级 3970f240
...@@ -40,7 +40,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { ...@@ -40,7 +40,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
"inputs must be same."); "inputs must be same.");
int rank = framework::arity(x_dims); int rank = framework::arity(x_dims);
PADDLE_ENFORCE(rank >= 2, "Tensor rank should be at least equal to 2."); PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2.");
PADDLE_ENFORCE_EQ(framework::product(x_dims) / x_dims[0], PADDLE_ENFORCE_EQ(framework::product(x_dims) / x_dims[0],
framework::product(y_dims) / y_dims[0], framework::product(y_dims) / y_dims[0],
"Product of dimensions expcet the first dimension of " "Product of dimensions expcet the first dimension of "
...@@ -87,7 +87,6 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { ...@@ -87,7 +87,6 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Gradient of Out should not be null"); "Gradient of Out should not be null");
// check out grad dimensions
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims(); auto y_dims = ctx.Input<Tensor>("Y")->dims();
......
...@@ -101,9 +101,9 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel { ...@@ -101,9 +101,9 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
auto y_grad = auto y_grad =
EigenMatrix<T>::From(*y_g, framework::make_ddim({y_dims[0], cols})); EigenMatrix<T>::From(*y_g, framework::make_ddim({y_dims[0], cols}));
PADDLE_ENFORCE(sub_result.dimensions()[0] >= y_dims[0], PADDLE_ENFORCE_GE(sub_result.dimensions()[0], y_dims[0],
"First dimension of gradient must be greater or " "First dimension of gradient must be greater or "
"equal than first dimension of target"); "equal than first dimension of target.");
if (sub_result.dimensions()[0] == y_dims[0]) { if (sub_result.dimensions()[0] == y_dims[0]) {
y_grad.device(eigen_place) = -1 * grad_mat; y_grad.device(eigen_place) = -1 * grad_mat;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册