#pragma once #include #include #include #include #include #include "paddle/framework/enforce.h" namespace paddle { namespace framework { typedef boost::variant, std::vector, std::vector> Attribute; typedef std::unordered_map AttributeMap; // check whether a value(attribute) fit a certain limit template class LargerThanChecker { public: LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} void operator()(T& value) const { PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail"); } private: T lower_bound_; }; // we can provide users more common Checker, like 'LessThanChecker', // 'BetweenChecker'... template class DefaultValueSetter { public: DefaultValueSetter(T default_value) : default_value_(default_value) {} void operator()(T& value) const { value = default_value_; } private: T default_value_; }; // check whether a certain attribute fit its limits // an attribute can have more than one limits template class TypedAttrChecker { typedef std::function ValueChecker; public: TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {} TypedAttrChecker& LargerThan(const T& lower_bound) { value_checkers_.push_back(LargerThanChecker(lower_bound)); return *this; } // we can add more common limits, like LessThan(), Between()... TypedAttrChecker& SetDefault(const T& default_value) { PADDLE_ENFORCE(default_value_setter_.empty(), "%s can't have more than one default value!", attr_name_); default_value_setter_.push_back(DefaultValueSetter(default_value)); return *this; } // allow users provide their own checker TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) { value_checkers_.push_back(checker); return *this; } void operator()(AttributeMap& attr_map) const { if (!attr_map.count(attr_name_)) { // user do not set this attr PADDLE_ENFORCE(!default_value_setter_.empty(), "Attribute '%s' is required!", attr_name_); // default_value_setter_ has no more than one element T val; (default_value_setter_[0])(val); attr_map[attr_name_] = val; } Attribute& attr = attr_map.at(attr_name_); T& attr_value = boost::get(attr); for (const auto& checker : value_checkers_) { checker(attr_value); } } private: std::string attr_name_; std::vector value_checkers_; std::vector default_value_setter_; }; // check whether op's all attributes fit their own limits class OpAttrChecker { typedef std::function AttrChecker; public: template TypedAttrChecker& AddAttrChecker(const std::string& attr_name) { attr_checkers_.push_back(TypedAttrChecker(attr_name)); AttrChecker& checker = attr_checkers_.back(); return *(checker.target>()); } void Check(AttributeMap& attr_map) const { for (const auto& checker : attr_checkers_) { checker(attr_map); } } private: std::vector attr_checkers_; }; } // namespace framework } // namespace paddle