global_value_getter_setter.cc 8.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/global_value_getter_setter.h"
16

17
#include <cctype>
18 19 20 21 22 23
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
24

25 26 27 28 29 30 31
#include "gflags/gflags.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
#include "pybind11/stl.h"

S
sneaxiy 已提交
32
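// NOTE(review): these two flags are only declared (DECLARE_int32) here; their
// definitions must come from another translation unit in distributed builds.
// TODO: record which file defines them.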
G
guofei 已提交
33 34 35
#ifdef PADDLE_WITH_DISTRIBUTE
DECLARE_int32(rpc_get_thread_num);
DECLARE_int32(rpc_prefetch_thread_num);
36
#endif
37 38 39 40 41 42 43 44 45 46 47 48 49

namespace paddle {
namespace pybind {

namespace py = pybind11;

class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
  DISABLE_COPY_AND_ASSIGN(GlobalVarGetterSetterRegistry);

  GlobalVarGetterSetterRegistry() = default;

 public:
  using Getter = std::function<py::object()>;
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
  using Setter = std::function<void(const py::object &)>;

  template <typename T>
  static Getter CreateGetter(const T &var) {
    return [&]() -> py::object { return py::cast(var); };
  }

  template <typename T>
  static Setter CreateSetter(T *var) {
    return [var](const py::object &obj) { *var = py::cast<T>(obj); };
  }

 private:
  struct VarInfo {
    VarInfo(bool is_public, const Getter &getter)
        : is_public(is_public), getter(getter) {}

    VarInfo(bool is_public, const Getter &getter, const Setter &setter)
        : is_public(is_public), getter(getter), setter(setter) {}
69

70 71 72 73 74 75
    const bool is_public;
    const Getter getter;
    const Setter setter;
  };

 public:
76 77 78 79
  static const GlobalVarGetterSetterRegistry &Instance() { return instance_; }

  static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; }

80
  void Register(const std::string &name, bool is_public, const Getter &getter) {
81
    PADDLE_ENFORCE_EQ(
82
        HasGetterMethod(name), false,
83 84
        platform::errors::AlreadyExists(
            "Getter of global variable %s has been registered", name));
85 86 87 88
    PADDLE_ENFORCE_NOT_NULL(getter,
                            platform::errors::InvalidArgument(
                                "Getter of %s should not be null", name));
    var_infos_.insert({name, VarInfo(is_public, getter)});
89 90
  }

91 92
  void Register(const std::string &name, bool is_public, const Getter &getter,
                const Setter &setter) {
93
    PADDLE_ENFORCE_EQ(
94 95 96
        HasGetterMethod(name), false,
        platform::errors::AlreadyExists(
            "Getter of global variable %s has been registered", name));
97 98

    PADDLE_ENFORCE_EQ(
99
        HasSetterMethod(name), false,
100 101
        platform::errors::AlreadyExists(
            "Setter of global variable %s has been registered", name));
102 103 104 105 106 107 108 109 110

    PADDLE_ENFORCE_NOT_NULL(getter,
                            platform::errors::InvalidArgument(
                                "Getter of %s should not be null", name));

    PADDLE_ENFORCE_NOT_NULL(setter,
                            platform::errors::InvalidArgument(
                                "Setter of %s should not be null", name));
    var_infos_.insert({name, VarInfo(is_public, getter, setter)});
111 112 113 114 115 116
  }

  const Getter &GetterMethod(const std::string &name) const {
    PADDLE_ENFORCE_EQ(
        HasGetterMethod(name), true,
        platform::errors::NotFound("Cannot find global variable %s", name));
117
    return var_infos_.at(name).getter;
118 119 120
  }

  py::object GetOrReturnDefaultValue(const std::string &name,
121
                                     const py::object &default_value) const {
122 123 124 125 126 127 128
    if (HasGetterMethod(name)) {
      return GetterMethod(name)();
    } else {
      return default_value;
    }
  }

129
  py::object Get(const std::string &name) const { return GetterMethod(name)(); }
130 131 132 133 134

  const Setter &SetterMethod(const std::string &name) const {
    PADDLE_ENFORCE_EQ(
        HasSetterMethod(name), true,
        platform::errors::NotFound("Global variable %s is not writable", name));
135
    return var_infos_.at(name).setter;
136 137 138 139 140 141 142
  }

  void Set(const std::string &name, const py::object &value) const {
    SetterMethod(name)(value);
  }

  bool HasGetterMethod(const std::string &name) const {
143
    return var_infos_.count(name) > 0;
144 145 146
  }

  bool HasSetterMethod(const std::string &name) const {
147 148 149 150 151
    return var_infos_.count(name) > 0 && var_infos_.at(name).setter;
  }

  bool IsPublic(const std::string &name) const {
    return var_infos_.count(name) > 0 && var_infos_.at(name).is_public;
152 153 154 155
  }

  std::unordered_set<std::string> Keys() const {
    std::unordered_set<std::string> keys;
156 157
    keys.reserve(var_infos_.size());
    for (auto &pair : var_infos_) {
158 159 160 161 162 163 164 165
      keys.insert(pair.first);
    }
    return keys;
  }

 private:
  static GlobalVarGetterSetterRegistry instance_;

166
  std::unordered_map<std::string, VarInfo> var_infos_;
167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
};

// Forward declaration: the registration routine is defined further down in
// this file, after the macro and visitor it uses.
static void RegisterGlobalVarGetterSetter();

// Storage for the registry singleton handed out by
// GlobalVarGetterSetterRegistry::Instance() / MutableInstance().
GlobalVarGetterSetterRegistry GlobalVarGetterSetterRegistry::instance_;

// Fills the global registry, then exposes it to Python as the class
// `GlobalVarGetterSetterRegistry` (dict-like: [], in, keys(), get(),
// is_public()) plus a module-level `globals()` function returning the
// singleton by reference.
void BindGlobalValueGetterSetter(pybind11::module *module) {
  RegisterGlobalVarGetterSetter();

  using Registry = GlobalVarGetterSetterRegistry;
  py::class_<Registry> registry_cls(*module, "GlobalVarGetterSetterRegistry");
  registry_cls.def("__getitem__", &Registry::Get);
  registry_cls.def("__setitem__", &Registry::Set);
  registry_cls.def("__contains__", &Registry::HasGetterMethod);
  registry_cls.def("keys", &Registry::Keys);
  registry_cls.def("is_public", &Registry::IsPublic);
  registry_cls.def("get", &Registry::GetOrReturnDefaultValue, py::arg("key"),
                   py::arg("default") = py::none());

  // The singleton is owned by C++; Python only borrows a reference to it.
  module->def("globals", &Registry::Instance,
              py::return_value_policy::reference);
}

190
/* Registers the global `var` under its own spelling as a public variable.
 * Public vars are designed to be writable, so both a getter and a setter
 * bound to `var`'s address are installed. */
#define REGISTER_PUBLIC_GLOBAL_VAR(var)                         \
  do {                                                          \
    GlobalVarGetterSetterRegistry::MutableInstance()->Register( \
        #var, /*is_public=*/true,                               \
        GlobalVarGetterSetterRegistry::CreateGetter(var),       \
        GlobalVarGetterSetterRegistry::CreateSetter(&var));     \
  } while (0)

S
sneaxiy 已提交
199
// boost::static_visitor that registers one variable in the global registry.
// The alternative type held by the visited variant selects T, and the
// type-erased `value_ptr` is reinterpreted as a T*. Writable variables get a
// getter and a setter; the others are registered read-only.
struct RegisterGetterSetterVisitor : public boost::static_visitor<void> {
  RegisterGetterSetterVisitor(const std::string &name, bool is_writable,
                              void *value_ptr)
      : name_(name), is_writable_(is_writable), value_ptr_(value_ptr) {}

  template <typename T>
  void operator()(const T &) const {
    T *typed_ptr = static_cast<T *>(value_ptr_);
    auto *registry = GlobalVarGetterSetterRegistry::MutableInstance();
    // Currently "public" and "writable" coincide for these variables.
    const bool is_public = is_writable_;
    if (!is_public) {
      registry->Register(
          name_, is_public,
          GlobalVarGetterSetterRegistry::CreateGetter(*typed_ptr));
      return;
    }
    registry->Register(name_, is_public,
                       GlobalVarGetterSetterRegistry::CreateGetter(*typed_ptr),
                       GlobalVarGetterSetterRegistry::CreateSetter(typed_ptr));
  }

 private:
  std::string name_;
  bool is_writable_;
  // Assumed to point at a T matching the visited variant alternative — the
  // caller is responsible for that pairing.
  void *value_ptr_;
};
224

S
sneaxiy 已提交
225
static void RegisterGlobalVarGetterSetter() {
G
guofei 已提交
226
#ifdef PADDLE_WITH_DITRIBUTE
S
sneaxiy 已提交
227 228
  REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_get_thread_num);
  REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_prefetch_thread_num);
229
#endif
S
sneaxiy 已提交
230 231 232 233 234 235 236 237 238 239 240

  const auto &flag_map = platform::GetExportedFlagInfoMap();
  for (const auto &pair : flag_map) {
    const std::string &name = pair.second.name;
    bool is_writable = pair.second.is_writable;
    void *value_ptr = const_cast<void *>(pair.second.value_ptr);
    const auto &default_value = pair.second.default_value;
    RegisterGetterSetterVisitor visitor("FLAGS_" + name, is_writable,
                                        value_ptr);
    boost::apply_visitor(visitor, default_value);
  }
241
}
S
sneaxiy 已提交
242

243 244
}  // namespace pybind
}  // namespace paddle