diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index d13530e3408a54c7ecab87c3bd9e6288e342f9af..8a7a949346e73ca9d2a813ca2888755a23bb7d7b 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -120,6 +120,57 @@ class EnumInContainer { std::unordered_set container_; }; +template +struct ExtractAttribute { + explicit ExtractAttribute(const std::string& attr_name) + : attr_name_(attr_name) {} + + T* operator()(Attribute& attr) const { + T* attr_value = nullptr; + try { + attr_value = &boost::get(attr); + } catch (boost::bad_get& bad_get) { + PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s", + attr_name_, typeid(T).name(), attr.type().name()); + } + return attr_value; + } + + const std::string& attr_name_; +}; + +// special handle bool +// FIXME(yuyang18): Currently we cast bool into int in python binding. It is +// hard to change the logic there. In another way, we should correct handle +// if the user set `some_flag=1`. +// +// FIX ME anytime if there is a better solution. +template <> +struct ExtractAttribute { + explicit ExtractAttribute(const std::string& attr_name) + : attr_name_(attr_name) {} + + bool* operator()(Attribute& attr) const { + if (attr.type() == typeid(int)) { // NOLINT + int val = boost::get(attr); + attr = static_cast(val); + } else if (attr.type() == typeid(float)) { // NOLINT + float val = boost::get(attr); + attr = static_cast(val); + } + bool* attr_value = nullptr; + try { + attr_value = &boost::get(attr); + } catch (boost::bad_get& bad_get) { + PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s", + attr_name_, attr.type().name()); + } + return attr_value; + } + + const std::string& attr_name_; +}; + // check whether a certain attribute fit its limits // an attribute can have more than one limits template @@ -171,9 +222,10 @@ class TypedAttrChecker { attr_map[attr_name_] = val; } Attribute& attr = attr_map.at(attr_name_); - T& attr_value = boost::get(attr); + ExtractAttribute extract_attr(attr_name_); + T* attr_value = extract_attr(attr); for (const auto& checker : value_checkers_) { - checker(attr_value); + checker(*attr_value); } }