global_value_getter_setter.cc 8.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/global_value_getter_setter.h"
16

17
#include <cctype>
18 19 20 21 22 23
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
24

25 26 27 28 29 30 31
#include "gflags/gflags.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
#include "pybind11/stl.h"

Z
Zeng Jinle 已提交
32 33 34 35 36 37 38 39 40 41
// FIXME(zengjinle): The two flags below may be dropped by the linker when
// compiling CPU-only Paddle. Their only user is AutoGrowthBestFitAllocator,
// and AutoGrowthBestFitAllocator is not referenced (at the translation-unit
// level) in CPU-only builds. PADDLE_FORCE_LINK_FLAG is used here as a
// workaround to keep them available; no cleaner alternative has been found
// so far.
PADDLE_FORCE_LINK_FLAG(free_idle_chunk);
PADDLE_FORCE_LINK_FLAG(free_when_no_cache_hit);

// The two RPC thread-count flags below are only *declared* here
// (DECLARE_int32); their definitions (DEFINE_int32) live in some other
// translation unit -- presumably the distributed/RPC code that is built only
// when PADDLE_WITH_DISTRIBUTE is set (TODO: confirm the defining file).
// They are referenced by RegisterGlobalVarGetterSetter() further down in this
// file, which must guard those uses with the same PADDLE_WITH_DISTRIBUTE macro.
#ifdef PADDLE_WITH_DISTRIBUTE
DECLARE_int32(rpc_get_thread_num);
DECLARE_int32(rpc_prefetch_thread_num);
#endif
46 47 48 49 50 51 52 53 54 55 56 57 58

namespace paddle {
namespace pybind {

namespace py = pybind11;

class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
  DISABLE_COPY_AND_ASSIGN(GlobalVarGetterSetterRegistry);

  GlobalVarGetterSetterRegistry() = default;

 public:
  using Getter = std::function<py::object()>;
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
  using Setter = std::function<void(const py::object &)>;

  template <typename T>
  static Getter CreateGetter(const T &var) {
    return [&]() -> py::object { return py::cast(var); };
  }

  template <typename T>
  static Setter CreateSetter(T *var) {
    return [var](const py::object &obj) { *var = py::cast<T>(obj); };
  }

 private:
  struct VarInfo {
    VarInfo(bool is_public, const Getter &getter)
        : is_public(is_public), getter(getter) {}

    VarInfo(bool is_public, const Getter &getter, const Setter &setter)
        : is_public(is_public), getter(getter), setter(setter) {}
78

79 80 81 82 83 84
    const bool is_public;
    const Getter getter;
    const Setter setter;
  };

 public:
85 86 87 88
  static const GlobalVarGetterSetterRegistry &Instance() { return instance_; }

  static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; }

89
  void Register(const std::string &name, bool is_public, const Getter &getter) {
90
    PADDLE_ENFORCE_EQ(
91
        HasGetterMethod(name), false,
92 93
        platform::errors::AlreadyExists(
            "Getter of global variable %s has been registered", name));
94 95 96 97
    PADDLE_ENFORCE_NOT_NULL(getter,
                            platform::errors::InvalidArgument(
                                "Getter of %s should not be null", name));
    var_infos_.insert({name, VarInfo(is_public, getter)});
98 99
  }

100 101
  void Register(const std::string &name, bool is_public, const Getter &getter,
                const Setter &setter) {
102
    PADDLE_ENFORCE_EQ(
103 104 105
        HasGetterMethod(name), false,
        platform::errors::AlreadyExists(
            "Getter of global variable %s has been registered", name));
106 107

    PADDLE_ENFORCE_EQ(
108
        HasSetterMethod(name), false,
109 110
        platform::errors::AlreadyExists(
            "Setter of global variable %s has been registered", name));
111 112 113 114 115 116 117 118 119

    PADDLE_ENFORCE_NOT_NULL(getter,
                            platform::errors::InvalidArgument(
                                "Getter of %s should not be null", name));

    PADDLE_ENFORCE_NOT_NULL(setter,
                            platform::errors::InvalidArgument(
                                "Setter of %s should not be null", name));
    var_infos_.insert({name, VarInfo(is_public, getter, setter)});
120 121 122 123 124 125
  }

  const Getter &GetterMethod(const std::string &name) const {
    PADDLE_ENFORCE_EQ(
        HasGetterMethod(name), true,
        platform::errors::NotFound("Cannot find global variable %s", name));
126
    return var_infos_.at(name).getter;
127 128 129
  }

  py::object GetOrReturnDefaultValue(const std::string &name,
130
                                     const py::object &default_value) const {
131 132 133 134 135 136 137
    if (HasGetterMethod(name)) {
      return GetterMethod(name)();
    } else {
      return default_value;
    }
  }

138
  py::object Get(const std::string &name) const { return GetterMethod(name)(); }
139 140 141 142 143

  const Setter &SetterMethod(const std::string &name) const {
    PADDLE_ENFORCE_EQ(
        HasSetterMethod(name), true,
        platform::errors::NotFound("Global variable %s is not writable", name));
144
    return var_infos_.at(name).setter;
145 146 147 148 149 150 151
  }

  void Set(const std::string &name, const py::object &value) const {
    SetterMethod(name)(value);
  }

  bool HasGetterMethod(const std::string &name) const {
152
    return var_infos_.count(name) > 0;
153 154 155
  }

  bool HasSetterMethod(const std::string &name) const {
156 157 158 159 160
    return var_infos_.count(name) > 0 && var_infos_.at(name).setter;
  }

  bool IsPublic(const std::string &name) const {
    return var_infos_.count(name) > 0 && var_infos_.at(name).is_public;
161 162 163 164
  }

  std::unordered_set<std::string> Keys() const {
    std::unordered_set<std::string> keys;
165 166
    keys.reserve(var_infos_.size());
    for (auto &pair : var_infos_) {
167 168 169 170 171 172 173 174
      keys.insert(pair.first);
    }
    return keys;
  }

 private:
  static GlobalVarGetterSetterRegistry instance_;

175
  std::unordered_map<std::string, VarInfo> var_infos_;
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
};

GlobalVarGetterSetterRegistry GlobalVarGetterSetterRegistry::instance_;

// Defined later in this file; registers the global variables exposed to
// Python before the bindings below are created.
static void RegisterGlobalVarGetterSetter();

// Exposes the registry to Python as class `GlobalVarGetterSetterRegistry`
// with dict-like methods (__getitem__, __setitem__, __contains__, keys, get)
// plus is_public, and adds a module-level `globals()` that returns the
// singleton by reference.
void BindGlobalValueGetterSetter(pybind11::module *module) {
  RegisterGlobalVarGetterSetter();

  using Registry = GlobalVarGetterSetterRegistry;
  py::class_<Registry> registry_class(*module,
                                      "GlobalVarGetterSetterRegistry");
  registry_class.def("__getitem__", &Registry::Get);
  registry_class.def("__setitem__", &Registry::Set);
  registry_class.def("__contains__", &Registry::HasGetterMethod);
  registry_class.def("keys", &Registry::Keys);
  registry_class.def("is_public", &Registry::IsPublic);
  // `default` falls back to Python None when omitted.
  registry_class.def("get", &Registry::GetOrReturnDefaultValue,
                     py::arg("key"), py::arg("default") = py::none());

  // The registry is a static singleton, so Python must not take ownership.
  module->def("globals", &Registry::Instance,
              py::return_value_policy::reference);
}

199
/* Public vars are designed to be writable. */
// Registers the global variable `var` under its own spelling (e.g.
// "FLAGS_rpc_get_thread_num") as a public entry whose getter reads `var` and
// whose setter assigns to it. `var` must be an lvalue with static storage,
// since only its address is captured. No local variable is declared inside
// the macro, so an argument can never be shadowed by a macro-internal name.
#define REGISTER_PUBLIC_GLOBAL_VAR(var)                         \
  do {                                                          \
    GlobalVarGetterSetterRegistry::MutableInstance()->Register( \
        #var, /*is_public=*/true,                               \
        GlobalVarGetterSetterRegistry::CreateGetter((var)),     \
        GlobalVarGetterSetterRegistry::CreateSetter(&(var)));   \
  } while (0)

Z
Zeng Jinle 已提交
208 209 210 211
// boost visitor that registers one exported flag in the registry.
// The visited variant (the flag's default value) is used only to select the
// type T; the flag's storage is the untyped `value_ptr` passed at
// construction, which is reinterpreted as T*. Writable flags get a getter and
// a setter and are marked public; read-only flags get only a getter and are
// marked non-public.
struct RegisterGetterSetterVisitor : public boost::static_visitor<void> {
  RegisterGetterSetterVisitor(const std::string &name, bool is_writable,
                              void *value_ptr)
      : name_(name), is_writable_(is_writable), value_ptr_(value_ptr) {}

  template <typename T>
  void operator()(const T &) const {
    // NOTE(review): assumes value_ptr_ really points at a T, i.e. the
    // variant's active alternative matches the flag's storage type.
    T *typed_value = static_cast<T *>(value_ptr_);
    GlobalVarGetterSetterRegistry *registry =
        GlobalVarGetterSetterRegistry::MutableInstance();
    // Currently, every writable var is public and every read-only one is not.
    const bool is_public = is_writable_;
    const GlobalVarGetterSetterRegistry::Getter getter =
        GlobalVarGetterSetterRegistry::CreateGetter(*typed_value);

    if (!is_writable_) {
      registry->Register(name_, is_public, getter);
      return;
    }
    registry->Register(
        name_, is_public, getter,
        GlobalVarGetterSetterRegistry::CreateSetter(typed_value));
  }

 private:
  std::string name_;
  bool is_writable_;
  void *value_ptr_;
};

static void RegisterGlobalVarGetterSetter() {
G
guofei 已提交
235
#ifdef PADDLE_WITH_DITRIBUTE
Z
Zeng Jinle 已提交
236 237
  REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_get_thread_num);
  REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_prefetch_thread_num);
238
#endif
Z
Zeng Jinle 已提交
239 240 241 242 243 244 245 246 247 248 249

  const auto &flag_map = platform::GetExportedFlagInfoMap();
  for (const auto &pair : flag_map) {
    const std::string &name = pair.second.name;
    bool is_writable = pair.second.is_writable;
    void *value_ptr = pair.second.value_ptr;
    const auto &default_value = pair.second.default_value;
    RegisterGetterSetterVisitor visitor("FLAGS_" + name, is_writable,
                                        value_ptr);
    boost::apply_visitor(visitor, default_value);
  }
250
}
Z
Zeng Jinle 已提交
251

252 253
}  // namespace pybind
}  // namespace paddle