未验证 提交 dc8eca82 编写于 作者: T tangwei12 提交者: GitHub

code style fix, test=develop (#15045)

* code style fix, test=develop
上级 55e3c651
......@@ -165,7 +165,7 @@ template <typename T>
class GreaterThanChecker {
public:
explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(T& value) const {
void operator()(const T& value) const {
PADDLE_ENFORCE(value > lower_bound_, "larger_than check fails.");
}
......@@ -177,7 +177,7 @@ template <typename T>
class EqualGreaterThanChecker {
public:
explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(T& value) const {
void operator()(const T& value) const {
PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails.");
}
......@@ -193,7 +193,7 @@ class DefaultValueSetter {
public:
explicit DefaultValueSetter(T default_value)
: default_value_(default_value) {}
void operator()(T& value) const { value = default_value_; } // NOLINT
void operator()(T* value) const { *value = default_value_; }
private:
T default_value_;
......@@ -203,7 +203,7 @@ template <typename T>
class EnumInContainer {
public:
explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
void operator()(T& val) const {
void operator()(const T& val) const {
PADDLE_ENFORCE(container_.find(val) != container_.end(),
"Value %s is not in enum container %s", val,
ContainerDebugString());
......@@ -232,7 +232,8 @@ class EnumInContainer {
// an attribute can have more than one limits
template <typename T>
class TypedAttrChecker {
typedef std::function<void(T&)> ValueChecker;
typedef std::function<void(T*)> DefaultValueChecker;
typedef std::function<void(const T&)> ValueChecker;
public:
explicit TypedAttrChecker(const std::string& attr_name)
......@@ -268,17 +269,17 @@ class TypedAttrChecker {
return *this;
}
void operator()(AttributeMap& attr_map) const { // NOLINT
if (!attr_map.count(attr_name_)) {
void operator()(AttributeMap* attr_map) const {
if (!attr_map->count(attr_name_)) {
// user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(),
"Attribute '%s' is required!", attr_name_);
// default_value_setter_ has no more than one element
T val;
(default_value_setter_[0])(val);
attr_map[attr_name_] = val;
(default_value_setter_[0])(&val);
(*attr_map)[attr_name_] = val;
}
Attribute& attr = attr_map.at(attr_name_);
Attribute& attr = attr_map->at(attr_name_);
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(attr);
for (const auto& checker : value_checkers_) {
......@@ -289,12 +290,12 @@ class TypedAttrChecker {
private:
std::string attr_name_;
std::vector<ValueChecker> value_checkers_;
std::vector<ValueChecker> default_value_setter_;
std::vector<DefaultValueChecker> default_value_setter_;
};
// check whether op's all attributes fit their own limits
class OpAttrChecker {
typedef std::function<void(AttributeMap&)> AttrChecker;
typedef std::function<void(AttributeMap*)> AttrChecker;
public:
template <typename T>
......@@ -304,7 +305,7 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>());
}
void Check(AttributeMap& attr_map) const { // NOLINT
void Check(AttributeMap* attr_map) const {
for (const auto& checker : attr_checkers_) {
checker(attr_map);
}
......
......@@ -643,7 +643,7 @@ void OpDesc::CheckAttrs() {
// not by users.
return;
}
checker->Check(attrs_);
checker->Check(&attrs_);
}
void OpDesc::InferShape(const BlockDesc &block) const {
......
......@@ -24,7 +24,7 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const VariableNameMap& outputs, AttributeMap attrs) {
auto& info = OpInfoMap::Instance().Get(type);
if (info.Checker() != nullptr) {
info.Checker()->Check(attrs);
info.Checker()->Check(&attrs);
}
auto op = info.Creator()(type, inputs, outputs, attrs);
return std::unique_ptr<OperatorBase>(op);
......
......@@ -52,7 +52,7 @@ class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker {
"The maximum length of the sequence. If maxlen < 0, maxlen "
"= max(Input(X)).")
.SetDefault(-1)
.AddCustomChecker([](int &v) {
.AddCustomChecker([](const int &v) {
PADDLE_ENFORCE(v < 0 || v >= 1,
"Attr(maxlen) must be less than 0 or larger than 1");
});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册