提交 b7444306 编写于 作者: F fengjiayi

Follow comments

上级 d7c8bdc8
...@@ -53,11 +53,11 @@ class GreaterThanChecker { ...@@ -53,11 +53,11 @@ class GreaterThanChecker {
}; };
template <typename T> template <typename T>
class EqualLargerThanChecker { class EqualGreaterThanChecker {
public: public:
explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(T& value) const { void operator()(T& value) const {
PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fails."); PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails.");
} }
private: private:
...@@ -127,8 +127,8 @@ class TypedAttrChecker { ...@@ -127,8 +127,8 @@ class TypedAttrChecker {
return *this; return *this;
} }
TypedAttrChecker& EqualLargerThan(const T& lower_bound) { TypedAttrChecker& EqualGreaterThan(const T& lower_bound) {
value_checkers_.push_back(EqualLargerThanChecker<T>(lower_bound)); value_checkers_.push_back(EqualGreaterThanChecker<T>(lower_bound));
return *this; return *this;
} }
......
...@@ -65,14 +65,14 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -65,14 +65,14 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
will be the product of tensor's first `rank - num_col_dims` dimensions. will be the product of tensor's first `rank - num_col_dims` dimensions.
)DOC") )DOC")
.SetDefault(1) .SetDefault(1)
.EqualLargerThan(1); .EqualGreaterThan(1);
AddAttr<int>( AddAttr<int>(
"y_num_col_dims", "y_num_col_dims",
R"DOC(mul_op can take tensors with more than two dimensions as input `Y`, R"DOC(mul_op can take tensors with more than two dimensions as input `Y`,
in that case, tensors will be reshaped to a matrix. Just like input `X`. in that case, tensors will be reshaped to a matrix. Just like input `X`.
)DOC") )DOC")
.SetDefault(1) .SetDefault(1)
.EqualLargerThan(1); .EqualGreaterThan(1);
AddComment(R"DOC( AddComment(R"DOC(
Two Element Mul Operator. Two Element Mul Operator.
......
...@@ -99,7 +99,5 @@ class TestMulGradTest2(GradientChecker): ...@@ -99,7 +99,5 @@ class TestMulGradTest2(GradientChecker):
no_grad_set={"Y"}) no_grad_set={"Y"})
# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册