diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index e51249beb913a8a75c3f2868c26019343102c226..9c1b7ea41dd1ed33b88799a665165faec847ae92 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -101,7 +101,8 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { "poolingType of pooling operator." "str constant equal to 'max' or 'avg'"); AddAttr>( - "ksize", "pooling size(height, width) of pooling operator."); + "ksize", "pooling size(height, width) of pooling operator.") + .AddCustomChecker(GreaterThanChecker_pool({0, 0})); AddAttr( "globalPooling", "whether to use the globalPooling." @@ -112,17 +113,49 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("strides", "strides(height, width) of pooling operator." "default {1,1}") - .SetDefault({1, 1}); + .SetDefault({1, 1}) + .AddCustomChecker(GreaterThanChecker_pool({0, 0})); AddAttr>("paddings", "paddings(height, width) of pooling operator." "default {0,0}") - .SetDefault({0, 0}); - + .SetDefault({0, 0}) + .AddCustomChecker(EqualGreaterThanChecker_pool({0, 0})); AddComment(R"DOC( The pooling2d operation calculates the output based on the input, poolingType and ksize, strides, paddings parameters. )DOC"); } + + private: + struct GreaterThanChecker_pool { + public: + explicit GreaterThanChecker_pool(std::vector lower_bound) + : lower_bound_(lower_bound) {} + void operator()(std::vector &value) const { + PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails."); + for (size_t i = 0; i < value.size(); ++i) { + PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails."); + } + } + + private: + std::vector lower_bound_; + }; + + struct EqualGreaterThanChecker_pool { + public: + explicit EqualGreaterThanChecker_pool(std::vector lower_bound) + : lower_bound_(lower_bound) {} + void operator()(std::vector &value) const { + PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails."); + for (size_t i = 0; i < value.size(); ++i) { + PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails."); + } + } + + private: + std::vector lower_bound_; + }; }; class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -142,7 +175,8 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { "poolingType of pooling operator." "str constant equal to 'max' or 'avg'"); AddAttr>( - "ksize", "pooling size(depth, height, width) of pooling operator."); + "ksize", "pooling size(depth, height, width) of pooling operator.") + .AddCustomChecker(GreaterThanChecker_pool({0, 0, 0})); AddAttr( "globalPooling", "whether to use the globalPooling." @@ -154,17 +188,50 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { "strides", "strides(depth, height, width) of pooling operator." "default {1,1,1}") - .SetDefault({1, 1, 1}); + .SetDefault({1, 1, 1}) + .AddCustomChecker(GreaterThanChecker_pool({0, 0, 0})); AddAttr>( "paddings", "paddings(depth, height, width) of pooling operator." "default {0,0,0}") - .SetDefault({0, 0, 0}); + .SetDefault({0, 0, 0}) + .AddCustomChecker(EqualGreaterThanChecker_pool({0, 0, 0})); AddComment(R"DOC( The pooling3d operation calculates the output based on the input, poolingType and ksize, strides, paddings parameters. )DOC"); } + + private: + struct GreaterThanChecker_pool { + public: + explicit GreaterThanChecker_pool(std::vector lower_bound) + : lower_bound_(lower_bound) {} + void operator()(std::vector &value) const { + PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails."); + for (size_t i = 0; i < value.size(); ++i) { + PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails."); + } + } + + private: + std::vector lower_bound_; + }; + + struct EqualGreaterThanChecker_pool { + public: + explicit EqualGreaterThanChecker_pool(std::vector lower_bound) + : lower_bound_(lower_bound) {} + void operator()(std::vector &value) const { + PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails."); + for (size_t i = 0; i < value.size(); ++i) { + PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails."); + } + } + + private: + std::vector lower_bound_; + }; }; } // namespace operators } // namespace paddle