/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include #include "paddle/framework/framework.pb.h" #include "paddle/platform/enforce.h" #include "paddle/platform/variant.h" namespace paddle { namespace framework { // The order should be as same as framework.proto typedef boost::variant, std::vector, std::vector, bool, std::vector, BlockDesc*> Attribute; typedef std::unordered_map AttributeMap; ProgramDesc& GetProgramDesc(); template inline AttrType AttrTypeID() { Attribute tmp = T(); return static_cast(tmp.which() - 1); } Attribute GetAttrValue(const OpDesc::Attr& attr_desc); class AttrReader { public: explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {} template inline const T& Get(const std::string& name) const { PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", name); return boost::get(attrs_.at(name)); } private: const AttributeMap& attrs_; }; // check whether a value(attribute) fit a certain limit template class GreaterThanChecker { public: explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} void operator()(T& value) const { PADDLE_ENFORCE(value > lower_bound_, "larger_than check fails."); } private: T lower_bound_; }; template class EqualGreaterThanChecker { public: explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} void operator()(T& value) const { PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails."); } private: T lower_bound_; }; // we can provide users more common Checker, like 'LessThanChecker', // 'BetweenChecker'... template class DefaultValueSetter { public: explicit DefaultValueSetter(T default_value) : default_value_(default_value) {} void operator()(T& value) const { value = default_value_; } private: T default_value_; }; template class EnumInContainer { public: explicit EnumInContainer(const std::unordered_set& c) : container_(c) {} void operator()(T& val) const { PADDLE_ENFORCE(container_.find(val) != container_.end(), "Value %s is not in enum container %s", val, ContainerDebugString()); } private: std::string ContainerDebugString() const { std::ostringstream sout; sout << "["; size_t cnt = 0; for (auto& v : container_) { sout << v; ++cnt; if (cnt != container_.size()) { sout << " ,"; } } sout << "]"; return sout.str(); } std::unordered_set container_; }; // check whether a certain attribute fit its limits // an attribute can have more than one limits template class TypedAttrChecker { typedef std::function ValueChecker; public: explicit TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {} TypedAttrChecker& InEnum(const std::unordered_set& range) { value_checkers_.push_back(EnumInContainer(range)); return *this; } TypedAttrChecker& GreaterThan(const T& lower_bound) { value_checkers_.push_back(GreaterThanChecker(lower_bound)); return *this; } TypedAttrChecker& EqualGreaterThan(const T& lower_bound) { value_checkers_.push_back(EqualGreaterThanChecker(lower_bound)); return *this; } // we can add more common limits, like LessThan(), Between()... TypedAttrChecker& SetDefault(const T& default_value) { PADDLE_ENFORCE(default_value_setter_.empty(), "%s can't have more than one default value!", attr_name_); default_value_setter_.push_back(DefaultValueSetter(default_value)); return *this; } // allow users provide their own checker TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) { value_checkers_.push_back(checker); return *this; } void operator()(AttributeMap& attr_map) const { if (!attr_map.count(attr_name_)) { // user do not set this attr PADDLE_ENFORCE(!default_value_setter_.empty(), "Attribute '%s' is required!", attr_name_); // default_value_setter_ has no more than one element T val; (default_value_setter_[0])(val); attr_map[attr_name_] = val; } Attribute& attr = attr_map.at(attr_name_); T& attr_value = boost::get(attr); for (const auto& checker : value_checkers_) { checker(attr_value); } } private: std::string attr_name_; std::vector value_checkers_; std::vector default_value_setter_; }; // check whether op's all attributes fit their own limits class OpAttrChecker { typedef std::function AttrChecker; public: template TypedAttrChecker& AddAttrChecker(const std::string& attr_name) { attr_checkers_.push_back(TypedAttrChecker(attr_name)); AttrChecker& checker = attr_checkers_.back(); return *(checker.target>()); } void Check(AttributeMap& attr_map) const { for (const auto& checker : attr_checkers_) { checker(attr_map); } } private: std::vector attr_checkers_; }; } // namespace framework } // namespace paddle