提交 8c478b36 编写于 作者: C chengduoZH

fix Atrr check

上级 b7285438
...@@ -101,7 +101,8 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -101,7 +101,8 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"poolingType of pooling operator." "poolingType of pooling operator."
"str constant equal to 'max' or 'avg'"); "str constant equal to 'max' or 'avg'");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ksize", "pooling size(height, width) of pooling operator."); "ksize", "pooling size(height, width) of pooling operator.")
.AddCustomChecker(GreaterThanChecker_pool({0, 0}));
AddAttr<int>( AddAttr<int>(
"globalPooling", "globalPooling",
"whether to use the globalPooling." "whether to use the globalPooling."
...@@ -112,17 +113,49 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -112,17 +113,49 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"strides(height, width) of pooling operator." "strides(height, width) of pooling operator."
"default {1,1}") "default {1,1}")
.SetDefault({1, 1}); .SetDefault({1, 1})
.AddCustomChecker(GreaterThanChecker_pool({0, 0}));
AddAttr<std::vector<int>>("paddings", AddAttr<std::vector<int>>("paddings",
"paddings(height, width) of pooling operator." "paddings(height, width) of pooling operator."
"default {0,0}") "default {0,0}")
.SetDefault({0, 0}); .SetDefault({0, 0})
.AddCustomChecker(EqualGreaterThanChecker_pool({0, 0}));
AddComment(R"DOC( AddComment(R"DOC(
The pooling2d operation calculates the output based on The pooling2d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters. the input, poolingType and ksize, strides, paddings parameters.
)DOC"); )DOC");
} }
private:
struct GreaterThanChecker_pool {
public:
explicit GreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
struct EqualGreaterThanChecker_pool {
public:
explicit EqualGreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
}; };
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
...@@ -142,7 +175,8 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -142,7 +175,8 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
"poolingType of pooling operator." "poolingType of pooling operator."
"str constant equal to 'max' or 'avg'"); "str constant equal to 'max' or 'avg'");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ksize", "pooling size(depth, height, width) of pooling operator."); "ksize", "pooling size(depth, height, width) of pooling operator.")
.AddCustomChecker(GreaterThanChecker_pool({0, 0, 0}));
AddAttr<int>( AddAttr<int>(
"globalPooling", "globalPooling",
"whether to use the globalPooling." "whether to use the globalPooling."
...@@ -154,17 +188,50 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -154,17 +188,50 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
"strides", "strides",
"strides(depth, height, width) of pooling operator." "strides(depth, height, width) of pooling operator."
"default {1,1,1}") "default {1,1,1}")
.SetDefault({1, 1, 1}); .SetDefault({1, 1, 1})
.AddCustomChecker(GreaterThanChecker_pool({0, 0, 0}));
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"paddings", "paddings",
"paddings(depth, height, width) of pooling operator." "paddings(depth, height, width) of pooling operator."
"default {0,0,0}") "default {0,0,0}")
.SetDefault({0, 0, 0}); .SetDefault({0, 0, 0})
.AddCustomChecker(EqualGreaterThanChecker_pool({0, 0, 0}));
AddComment(R"DOC( AddComment(R"DOC(
The pooling3d operation calculates the output based on The pooling3d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters. the input, poolingType and ksize, strides, paddings parameters.
)DOC"); )DOC");
} }
private:
struct GreaterThanChecker_pool {
public:
explicit GreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
struct EqualGreaterThanChecker_pool {
public:
explicit EqualGreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册