global_value_getter_setter.cc 9.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/global_value_getter_setter.h"
16

17
#include <cctype>
18 19 20 21 22 23
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
24

25 26 27 28 29
#include "gflags/gflags.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
30
#include "paddle/phi/core/macros.h"
31 32
#include "pybind11/stl.h"

Z
Zeng Jinle 已提交
33 34 35 36 37 38 39 40 41 42
// FIXME(zengjinle): in CPU-only builds the linker may discard the two flags
// below. Their only user is AutoGrowthBestFitAllocator, and that allocator is
// not referenced (at translation-unit level) when Paddle is built for CPU
// only. PADDLE_FORCE_LINK_FLAG is a workaround rather than a preferred fix,
// but no other way to keep the flags alive has been found so far.
PADDLE_FORCE_LINK_FLAG(free_idle_chunk);
PADDLE_FORCE_LINK_FLAG(free_when_no_cache_hit);

// NOTE(review): the two RPC flags below are only declared here
// (DECLARE_int32); their definitions must come from another translation unit
// in distributed builds. TODO: confirm where they are defined and that they
// still exist.
#ifdef PADDLE_WITH_DISTRIBUTE
DECLARE_int32(rpc_get_thread_num);
DECLARE_int32(rpc_prefetch_thread_num);
#endif

namespace paddle {
namespace pybind {

namespace py = pybind11;

class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
  DISABLE_COPY_AND_ASSIGN(GlobalVarGetterSetterRegistry);

  GlobalVarGetterSetterRegistry() = default;

 public:
  using Getter = std::function<py::object()>;
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
  using Setter = std::function<void(const py::object &)>;

  template <typename T>
  static Getter CreateGetter(const T &var) {
    return [&]() -> py::object { return py::cast(var); };
  }

  template <typename T>
  static Setter CreateSetter(T *var) {
    return [var](const py::object &obj) { *var = py::cast<T>(obj); };
  }

 private:
  struct VarInfo {
    VarInfo(bool is_public, const Getter &getter)
        : is_public(is_public), getter(getter) {}

    VarInfo(bool is_public, const Getter &getter, const Setter &setter)
        : is_public(is_public), getter(getter), setter(setter) {}
79

80 81 82 83 84 85
    const bool is_public;
    const Getter getter;
    const Setter setter;
  };

 public:
86 87 88 89
  static const GlobalVarGetterSetterRegistry &Instance() { return instance_; }

  static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; }

90
  void Register(const std::string &name, bool is_public, const Getter &getter) {
91
    PADDLE_ENFORCE_EQ(
92 93
        HasGetterMethod(name),
        false,
94 95
        platform::errors::AlreadyExists(
            "Getter of global variable %s has been registered", name));
96 97 98 99
    PADDLE_ENFORCE_NOT_NULL(getter,
                            platform::errors::InvalidArgument(
                                "Getter of %s should not be null", name));
    var_infos_.insert({name, VarInfo(is_public, getter)});
100 101
  }

102 103 104
  void Register(const std::string &name,
                bool is_public,
                const Getter &getter,
105
                const Setter &setter) {
106
    PADDLE_ENFORCE_EQ(
107 108
        HasGetterMethod(name),
        false,
109 110
        platform::errors::AlreadyExists(
            "Getter of global variable %s has been registered", name));
111 112

    PADDLE_ENFORCE_EQ(
113 114
        HasSetterMethod(name),
        false,
115 116
        platform::errors::AlreadyExists(
            "Setter of global variable %s has been registered", name));
117 118 119 120 121 122 123 124 125

    PADDLE_ENFORCE_NOT_NULL(getter,
                            platform::errors::InvalidArgument(
                                "Getter of %s should not be null", name));

    PADDLE_ENFORCE_NOT_NULL(setter,
                            platform::errors::InvalidArgument(
                                "Setter of %s should not be null", name));
    var_infos_.insert({name, VarInfo(is_public, getter, setter)});
126 127 128 129
  }

  const Getter &GetterMethod(const std::string &name) const {
    PADDLE_ENFORCE_EQ(
130 131
        HasGetterMethod(name),
        true,
132
        platform::errors::NotFound("Cannot find global variable %s", name));
133
    return var_infos_.at(name).getter;
134 135 136
  }

  py::object GetOrReturnDefaultValue(const std::string &name,
137
                                     const py::object &default_value) const {
138 139 140 141 142 143 144
    if (HasGetterMethod(name)) {
      return GetterMethod(name)();
    } else {
      return default_value;
    }
  }

145
  py::object Get(const std::string &name) const { return GetterMethod(name)(); }
146 147 148

  const Setter &SetterMethod(const std::string &name) const {
    PADDLE_ENFORCE_EQ(
149 150
        HasSetterMethod(name),
        true,
151
        platform::errors::NotFound("Global variable %s is not writable", name));
152
    return var_infos_.at(name).setter;
153 154 155
  }

  void Set(const std::string &name, const py::object &value) const {
156
    VLOG(4) << "set " << name << " to " << value;
157 158 159 160
    SetterMethod(name)(value);
  }

  bool HasGetterMethod(const std::string &name) const {
161
    return var_infos_.count(name) > 0;
162 163 164
  }

  bool HasSetterMethod(const std::string &name) const {
165 166 167 168 169
    return var_infos_.count(name) > 0 && var_infos_.at(name).setter;
  }

  bool IsPublic(const std::string &name) const {
    return var_infos_.count(name) > 0 && var_infos_.at(name).is_public;
170 171 172 173
  }

  std::unordered_set<std::string> Keys() const {
    std::unordered_set<std::string> keys;
174 175
    keys.reserve(var_infos_.size());
    for (auto &pair : var_infos_) {
176 177 178 179 180 181 182 183
      keys.insert(pair.first);
    }
    return keys;
  }

 private:
  static GlobalVarGetterSetterRegistry instance_;

184
  std::unordered_map<std::string, VarInfo> var_infos_;
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
};

// Forward declaration: the routine that fills the registry is defined later in
// this file and is invoked from BindGlobalValueGetterSetter().
static void RegisterGlobalVarGetterSetter();

// Storage for the singleton handed out by Instance() / MutableInstance().
GlobalVarGetterSetterRegistry GlobalVarGetterSetterRegistry::instance_;

void BindGlobalValueGetterSetter(pybind11::module *module) {
  // Fill the registry before any of it is exposed to Python.
  RegisterGlobalVarGetterSetter();

  using Registry = GlobalVarGetterSetterRegistry;

  // Dict-like Python view over the registry.
  py::class_<Registry> registry_class(*module, "GlobalVarGetterSetterRegistry");
  registry_class.def("__getitem__", &Registry::Get);
  registry_class.def("__setitem__", &Registry::Set);
  registry_class.def("__contains__", &Registry::HasGetterMethod);
  registry_class.def("keys", &Registry::Keys);
  registry_class.def("is_public", &Registry::IsPublic);
  registry_class.def("get",
                     &Registry::GetOrReturnDefaultValue,
                     py::arg("key"),
                     py::arg("default") = py::none());

  // `globals()` returns a non-owning reference to the singleton, so Python
  // neither copies nor takes ownership of it.
  module->def(
      "globals", &Registry::Instance, py::return_value_policy::reference);
}

211
/* Public vars are designed to be writable. */
// Registers `var` as a public, writable global under its own spelling (e.g.
// FLAGS_foo is exposed as "FLAGS_foo"). The getter refers to `var` and the
// setter holds `&var`, so `var` must be an lvalue that outlives the registry.
#define REGISTER_PUBLIC_GLOBAL_VAR(var)                         \
  do {                                                          \
    GlobalVarGetterSetterRegistry::MutableInstance()->Register( \
        #var,                                                   \
        /*is_public=*/true,                                     \
        GlobalVarGetterSetterRegistry::CreateGetter(var),       \
        GlobalVarGetterSetterRegistry::CreateSetter(&var));     \
  } while (0)

221
struct RegisterGetterSetterVisitor {
222 223
  RegisterGetterSetterVisitor(const std::string &name,
                              bool is_writable,
Z
Zeng Jinle 已提交
224 225
                              void *value_ptr)
      : name_(name), is_writable_(is_writable), value_ptr_(value_ptr) {}
226

Z
Zeng Jinle 已提交
227 228 229 230 231 232
  template <typename T>
  void operator()(const T &) const {
    auto &value = *static_cast<T *>(value_ptr_);
    auto *instance = GlobalVarGetterSetterRegistry::MutableInstance();
    bool is_public = is_writable_;  // currently, all writable vars are public
    if (is_writable_) {
233 234
      instance->Register(name_,
                         is_public,
Z
Zeng Jinle 已提交
235 236 237
                         GlobalVarGetterSetterRegistry::CreateGetter(value),
                         GlobalVarGetterSetterRegistry::CreateSetter(&value));
    } else {
238 239
      instance->Register(
          name_, is_public, GlobalVarGetterSetterRegistry::CreateGetter(value));
Z
Zeng Jinle 已提交
240 241
    }
  }
242

Z
Zeng Jinle 已提交
243 244 245 246 247 248 249
 private:
  std::string name_;
  bool is_writable_;
  void *value_ptr_;
};

static void RegisterGlobalVarGetterSetter() {
G
guofei 已提交
250
#ifdef PADDLE_WITH_DITRIBUTE
Z
Zeng Jinle 已提交
251 252
  REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_get_thread_num);
  REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_prefetch_thread_num);
253
#endif
Z
Zeng Jinle 已提交
254

255
  const auto &flag_map = phi::GetExportedFlagInfoMap();
Z
Zeng Jinle 已提交
256 257 258 259 260
  for (const auto &pair : flag_map) {
    const std::string &name = pair.second.name;
    bool is_writable = pair.second.is_writable;
    void *value_ptr = pair.second.value_ptr;
    const auto &default_value = pair.second.default_value;
261 262
    RegisterGetterSetterVisitor visitor(
        "FLAGS_" + name, is_writable, value_ptr);
R
Ruibiao Chen 已提交
263
    paddle::visit(visitor, default_value);
Z
Zeng Jinle 已提交
264
  }
265
}
Z
Zeng Jinle 已提交
266

267 268
}  // namespace pybind
}  // namespace paddle