diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 6968ffd838e070084a0ba585c84f6977f620f974..2b788a76cafe198abb9aed8ba842e37cc6ff73a6 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -53,11 +53,11 @@ class GreaterThanChecker { }; template -class EqualLargerThanChecker { +class EqualGreaterThanChecker { public: - explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} + explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} void operator()(T& value) const { - PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fails."); + PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails."); } private: @@ -127,8 +127,8 @@ class TypedAttrChecker { return *this; } - TypedAttrChecker& EqualLargerThan(const T& lower_bound) { - value_checkers_.push_back(EqualLargerThanChecker(lower_bound)); + TypedAttrChecker& EqualGreaterThan(const T& lower_bound) { + value_checkers_.push_back(EqualGreaterThanChecker(lower_bound)); return *this; } diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 34595adedda745cb8e1137aaa99009a07bb29433..710a56a0e8e2d17162d7d000df226f1537104eb9 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -65,14 +65,14 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { will be the product of tensor's first `rank - num_col_dims` dimensions. )DOC") .SetDefault(1) - .EqualLargerThan(1); + .EqualGreaterThan(1); AddAttr( "y_num_col_dims", R"DOC(mul_op can take tensors with more than two dimensions as input `Y`, in that case, tensors will be reshaped to a matrix. Just like input `X`. )DOC") .SetDefault(1) - .EqualLargerThan(1); + .EqualGreaterThan(1); AddComment(R"DOC( Two Element Mul Operator. diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index d8057f4ffaff0d29d20083596e6395cccb8e005e..8c827e242e866b267e0fc4b73c31bafa0ccc7c48 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -99,7 +99,5 @@ class TestMulGradTest2(GradientChecker): no_grad_set={"Y"}) -# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library - if __name__ == '__main__': unittest.main()