attr_checker.h 3.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
#pragma once

#include <boost/variant.hpp>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/enforce.h"

namespace paddle {
namespace framework {

// A single attribute value. It holds exactly one of the listed alternatives;
// boost::blank is the first alternative in the list.
using Attribute =
    boost::variant<boost::blank, int, float, std::string, std::vector<int>,
                   std::vector<float>, std::vector<std::string>>;

// Maps an attribute name to its value.
using AttributeMap = std::unordered_map<std::string, Attribute>;

// Functor that checks whether a value (an attribute) is strictly greater than
// a fixed lower bound.
//
// T must support `operator>` between two values of type T. The bound is
// copied into the checker at construction time.
template <typename T>
class LargerThanChecker {
 public:
  // Stores a copy of `lower_bound` as the exclusive lower limit.
  LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}

  // Passes the condition `value > lower_bound_` to PADDLE_ENFORCE together
  // with a fixed failure message. `value` itself is not modified here.
  // NOTE(review): how PADDLE_ENFORCE reports a failed condition is defined in
  // paddle/framework/enforce.h, which is not visible from this file.
  void operator()(T& value) const {
    PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail");
  }

 private:
  // Exclusive lower limit that checked values are compared against.
  T lower_bound_;
};

// More common checkers, such as 'LessThanChecker' or 'BetweenChecker', could
// be provided to users here as well.

// Functor that overwrites the value it is given with a preset default.
//
// The default is copied into the setter at construction time and copied out
// again (via T's assignment operator) on every call.
template <typename T>
class DefaultValueSetter {
 public:
  // Remembers a copy of `default_value` for later assignment.
  DefaultValueSetter(T default_value) : preset_(default_value) {}

  // Replaces the current content of `value` with the stored default.
  void operator()(T& value) const {
    value = preset_;
  }

 private:
  // The default that gets assigned on every call.
  T preset_;
};

// check whether a certain attribute fit its limits
// an attribute can have more than one limits
template <typename T>
class TypedAttrChecker {
  typedef std::function<void(T&)> ValueChecker;

 public:
  TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}

  TypedAttrChecker& LargerThan(const T& lower_bound) {
    value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
    return *this;
  }

  // we can add more common limits, like LessThan(), Between()...

  TypedAttrChecker& SetDefault(const T& default_value) {
    PADDLE_ENFORCE(default_value_setter_.empty(),
                   "%s can't have more than one default value!", attr_name_);
    default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
    return *this;
  }

  // allow users provide their own checker
  TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) {
    value_checkers_.push_back(checker);
    return *this;
  }

  void operator()(AttributeMap& attr_map) const {
    if (!attr_map.count(attr_name_)) {
      // user do not set this attr
      PADDLE_ENFORCE(!default_value_setter_.empty(),
                     "Attribute '%s' is required!", attr_name_);
      // default_value_setter_ has no more than one element
      T val;
      (default_value_setter_[0])(val);
      attr_map[attr_name_] = val;
    }
    Attribute& attr = attr_map.at(attr_name_);
    T& attr_value = boost::get<T>(attr);
    for (const auto& checker : value_checkers_) {
      checker(attr_value);
    }
  }

 private:
  std::string attr_name_;
  std::vector<ValueChecker> value_checkers_;
  std::vector<ValueChecker> default_value_setter_;
};

// check whether op's all attributes fit their own limits
class OpAttrChecker {
  typedef std::function<void(AttributeMap&)> AttrChecker;

 public:
  template <typename T>
  TypedAttrChecker<T>& AddAttrChecker(const std::string& attr_name) {
    attr_checkers_.push_back(TypedAttrChecker<T>(attr_name));
    AttrChecker& checker = attr_checkers_.back();
    return *(checker.target<TypedAttrChecker<T>>());
  }

  void Check(AttributeMap& attr_map) const {
    for (const auto& checker : attr_checkers_) {
      checker(attr_map);
    }
  }

 private:
  std::vector<AttrChecker> attr_checkers_;
};

}  // namespace framework
}  // namespace paddle