diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index d9c76881b7e98d0b7cd29024b98c8f7720398c66..67054eccb3397ea40f0fb3e2ff2530ee1ea64736 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -165,7 +165,7 @@ template class GreaterThanChecker { public: explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} - void operator()(T& value) const { + void operator()(const T& value) const { PADDLE_ENFORCE(value > lower_bound_, "larger_than check fails."); } @@ -177,7 +177,7 @@ template class EqualGreaterThanChecker { public: explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} - void operator()(T& value) const { + void operator()(const T& value) const { PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails."); } @@ -193,7 +193,7 @@ class DefaultValueSetter { public: explicit DefaultValueSetter(T default_value) : default_value_(default_value) {} - void operator()(T& value) const { value = default_value_; } // NOLINT + void operator()(T* value) const { *value = default_value_; } private: T default_value_; @@ -203,7 +203,7 @@ template class EnumInContainer { public: explicit EnumInContainer(const std::unordered_set& c) : container_(c) {} - void operator()(T& val) const { + void operator()(const T& val) const { PADDLE_ENFORCE(container_.find(val) != container_.end(), "Value %s is not in enum container %s", val, ContainerDebugString()); @@ -232,7 +232,8 @@ class EnumInContainer { // an attribute can have more than one limits template class TypedAttrChecker { - typedef std::function ValueChecker; + typedef std::function DefaultValueChecker; + typedef std::function ValueChecker; public: explicit TypedAttrChecker(const std::string& attr_name) @@ -268,17 +269,17 @@ class TypedAttrChecker { return *this; } - void operator()(AttributeMap& attr_map) const { // NOLINT - if (!attr_map.count(attr_name_)) { + void operator()(AttributeMap* attr_map) const { + if (!attr_map->count(attr_name_)) { // user do not set this attr PADDLE_ENFORCE(!default_value_setter_.empty(), "Attribute '%s' is required!", attr_name_); // default_value_setter_ has no more than one element T val; - (default_value_setter_[0])(val); - attr_map[attr_name_] = val; + (default_value_setter_[0])(&val); + (*attr_map)[attr_name_] = val; } - Attribute& attr = attr_map.at(attr_name_); + Attribute& attr = attr_map->at(attr_name_); ExtractAttribute extract_attr(attr_name_); T* attr_value = extract_attr(attr); for (const auto& checker : value_checkers_) { @@ -289,12 +290,12 @@ class TypedAttrChecker { private: std::string attr_name_; std::vector value_checkers_; - std::vector default_value_setter_; + std::vector default_value_setter_; }; // check whether op's all attributes fit their own limits class OpAttrChecker { - typedef std::function AttrChecker; + typedef std::function AttrChecker; public: template @@ -304,7 +305,7 @@ class OpAttrChecker { return *(checker.target>()); } - void Check(AttributeMap& attr_map) const { // NOLINT + void Check(AttributeMap* attr_map) const { for (const auto& checker : attr_checkers_) { checker(attr_map); } diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 2fe1c94ec02e8ff0a4acb81868ba2124ea89e506..0e7b0cbeb98f3b6bbf0b37f507fc6022be692bb1 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -643,7 +643,7 @@ void OpDesc::CheckAttrs() { // not by users. return; } - checker->Check(attrs_); + checker->Check(&attrs_); } void OpDesc::InferShape(const BlockDesc &block) const { diff --git a/paddle/fluid/framework/op_registry.cc b/paddle/fluid/framework/op_registry.cc index bfc411ca2c4a483e344b368da089392d8e4a87c1..346d14d408ea1ed2cfbdbed5f48e56902e6e95b2 100644 --- a/paddle/fluid/framework/op_registry.cc +++ b/paddle/fluid/framework/op_registry.cc @@ -24,7 +24,7 @@ std::unique_ptr OpRegistry::CreateOp( const VariableNameMap& outputs, AttributeMap attrs) { auto& info = OpInfoMap::Instance().Get(type); if (info.Checker() != nullptr) { - info.Checker()->Check(attrs); + info.Checker()->Check(&attrs); } auto op = info.Creator()(type, inputs, outputs, attrs); return std::unique_ptr(op); diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op.h b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h index 8fceed3558b4357b7863368c18add329ea9922b3..57d6f4b3ea98d7437f7fa72ed724384a19bcea4a 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h @@ -52,7 +52,7 @@ class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker { "The maximum length of the sequence. If maxlen < 0, maxlen " "= max(Input(X)).") .SetDefault(-1) - .AddCustomChecker([](int &v) { + .AddCustomChecker([](const int &v) { PADDLE_ENFORCE(v < 0 || v >= 1, "Attr(maxlen) must be less than 0 or larger than 1"); });