From 8c478b36e2f1c723d48b5c0e96fd77b7d950c467 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 25 Sep 2017 21:21:52 +0800 Subject: [PATCH] fix Atrr check --- paddle/operators/pool_op.cc | 81 +++++++++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 7 deletions(-) diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index e51249beb91..9c1b7ea41dd 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -101,7 +101,8 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { "poolingType of pooling operator." "str constant equal to 'max' or 'avg'"); AddAttr>( - "ksize", "pooling size(height, width) of pooling operator."); + "ksize", "pooling size(height, width) of pooling operator.") + .AddCustomChecker(GreaterThanChecker_pool({0, 0})); AddAttr( "globalPooling", "whether to use the globalPooling." @@ -112,17 +113,49 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("strides", "strides(height, width) of pooling operator." "default {1,1}") - .SetDefault({1, 1}); + .SetDefault({1, 1}) + .AddCustomChecker(GreaterThanChecker_pool({0, 0})); AddAttr>("paddings", "paddings(height, width) of pooling operator." "default {0,0}") - .SetDefault({0, 0}); - + .SetDefault({0, 0}) + .AddCustomChecker(EqualGreaterThanChecker_pool({0, 0})); AddComment(R"DOC( The pooling2d operation calculates the output based on the input, poolingType and ksize, strides, paddings parameters. )DOC"); } + + private: + struct GreaterThanChecker_pool { + public: + explicit GreaterThanChecker_pool(std::vector lower_bound) + : lower_bound_(lower_bound) {} + void operator()(std::vector &value) const { + PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails."); + for (size_t i = 0; i < value.size(); ++i) { + PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails."); + } + } + + private: + std::vector lower_bound_; + }; + + struct EqualGreaterThanChecker_pool { + public: + explicit EqualGreaterThanChecker_pool(std::vector lower_bound) + : lower_bound_(lower_bound) {} + void operator()(std::vector &value) const { + PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails."); + for (size_t i = 0; i < value.size(); ++i) { + PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails."); + } + } + + private: + std::vector lower_bound_; + }; }; class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -142,7 +175,8 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { "poolingType of pooling operator." "str constant equal to 'max' or 'avg'"); AddAttr>( - "ksize", "pooling size(depth, height, width) of pooling operator."); + "ksize", "pooling size(depth, height, width) of pooling operator.") + .AddCustomChecker(GreaterThanChecker_pool({0, 0, 0})); AddAttr( "globalPooling", "whether to use the globalPooling." @@ -154,17 +188,50 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { "strides", "strides(depth, height, width) of pooling operator." "default {1,1,1}") - .SetDefault({1, 1, 1}); + .SetDefault({1, 1, 1}) + .AddCustomChecker(GreaterThanChecker_pool({0, 0, 0})); AddAttr>( "paddings", "paddings(depth, height, width) of pooling operator." "default {0,0,0}") - .SetDefault({0, 0, 0}); + .SetDefault({0, 0, 0}) + .AddCustomChecker(EqualGreaterThanChecker_pool({0, 0, 0})); AddComment(R"DOC( The pooling3d operation calculates the output based on the input, poolingType and ksize, strides, paddings parameters. )DOC"); } + + private: + struct GreaterThanChecker_pool { + public: + explicit GreaterThanChecker_pool(std::vector lower_bound) + : lower_bound_(lower_bound) {} + void operator()(std::vector &value) const { + PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails."); + for (size_t i = 0; i < value.size(); ++i) { + PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails."); + } + } + + private: + std::vector lower_bound_; + }; + + struct EqualGreaterThanChecker_pool { + public: + explicit EqualGreaterThanChecker_pool(std::vector lower_bound) + : lower_bound_(lower_bound) {} + void operator()(std::vector &value) const { + PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails."); + for (size_t i = 0; i < value.size(); ++i) { + PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails."); + } + } + + private: + std::vector lower_bound_; + }; }; } // namespace operators } // namespace paddle -- GitLab