global_value_getter_setter.cc 9.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/global_value_getter_setter.h"
16

17
#include <cctype>
18 19 20 21 22 23
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
24

25 26 27 28 29 30 31
#include "gflags/gflags.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
#include "pybind11/stl.h"

// FIXME(zengjinle): these 2 flags may be removed by the linker when compiling
// CPU-only Paddle. This is because they are only used in
// AutoGrowthBestFitAllocator, but AutoGrowthBestFitAllocator is not used
// (at the translation-unit level) when compiling CPU-only Paddle. I do not
// want to add PADDLE_FORCE_LINK_FLAG, but I have not found any other method
// to solve this problem.
PADDLE_FORCE_LINK_FLAG(free_idle_chunk);
PADDLE_FORCE_LINK_FLAG(free_when_no_cache_hit);

// NOTE: where do the following 2 rpc flags come from?
#ifdef PADDLE_WITH_DISTRIBUTE
DECLARE_int32(rpc_get_thread_num);
DECLARE_int32(rpc_prefetch_thread_num);
45
#endif
46 47 48 49 50 51 52 53 54 55 56 57 58

namespace paddle {
namespace pybind {

namespace py = pybind11;

class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
  DISABLE_COPY_AND_ASSIGN(GlobalVarGetterSetterRegistry);

  GlobalVarGetterSetterRegistry() = default;

 public:
  using Getter = std::function<py::object()>;
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
  using Setter = std::function<void(const py::object &)>;

  template <typename T>
  static Getter CreateGetter(const T &var) {
    return [&]() -> py::object { return py::cast(var); };
  }

  template <typename T>
  static Setter CreateSetter(T *var) {
    return [var](const py::object &obj) { *var = py::cast<T>(obj); };
  }

 private:
  struct VarInfo {
    VarInfo(bool is_public, const Getter &getter)
        : is_public(is_public), getter(getter) {}

    VarInfo(bool is_public, const Getter &getter, const Setter &setter)
        : is_public(is_public), getter(getter), setter(setter) {}
78

79 80 81 82 83 84
    const bool is_public;
    const Getter getter;
    const Setter setter;
  };

 public:
85 86 87 88
  static const GlobalVarGetterSetterRegistry &Instance() { return instance_; }

  static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; }

89
  void Register(const std::string &name, bool is_public, const Getter &getter) {
90
    PADDLE_ENFORCE_EQ(
91 92
        HasGetterMethod(name),
        false,
93 94
        platform::errors::AlreadyExists(
            "Getter of global variable %s has been registered", name));
95 96 97 98
    PADDLE_ENFORCE_NOT_NULL(getter,
                            platform::errors::InvalidArgument(
                                "Getter of %s should not be null", name));
    var_infos_.insert({name, VarInfo(is_public, getter)});
99 100
  }

101 102 103
  void Register(const std::string &name,
                bool is_public,
                const Getter &getter,
104
                const Setter &setter) {
105
    PADDLE_ENFORCE_EQ(
106 107
        HasGetterMethod(name),
        false,
108 109
        platform::errors::AlreadyExists(
            "Getter of global variable %s has been registered", name));
110 111

    PADDLE_ENFORCE_EQ(
112 113
        HasSetterMethod(name),
        false,
114 115
        platform::errors::AlreadyExists(
            "Setter of global variable %s has been registered", name));
116 117 118 119 120 121 122 123 124

    PADDLE_ENFORCE_NOT_NULL(getter,
                            platform::errors::InvalidArgument(
                                "Getter of %s should not be null", name));

    PADDLE_ENFORCE_NOT_NULL(setter,
                            platform::errors::InvalidArgument(
                                "Setter of %s should not be null", name));
    var_infos_.insert({name, VarInfo(is_public, getter, setter)});
125 126 127 128
  }

  const Getter &GetterMethod(const std::string &name) const {
    PADDLE_ENFORCE_EQ(
129 130
        HasGetterMethod(name),
        true,
131
        platform::errors::NotFound("Cannot find global variable %s", name));
132
    return var_infos_.at(name).getter;
133 134 135
  }

  py::object GetOrReturnDefaultValue(const std::string &name,
136
                                     const py::object &default_value) const {
137 138 139 140 141 142 143
    if (HasGetterMethod(name)) {
      return GetterMethod(name)();
    } else {
      return default_value;
    }
  }

144
  py::object Get(const std::string &name) const { return GetterMethod(name)(); }
145 146 147

  const Setter &SetterMethod(const std::string &name) const {
    PADDLE_ENFORCE_EQ(
148 149
        HasSetterMethod(name),
        true,
150
        platform::errors::NotFound("Global variable %s is not writable", name));
151
    return var_infos_.at(name).setter;
152 153 154
  }

  void Set(const std::string &name, const py::object &value) const {
155
    VLOG(4) << "set " << name << " to " << value;
156 157 158 159
    SetterMethod(name)(value);
  }

  bool HasGetterMethod(const std::string &name) const {
160
    return var_infos_.count(name) > 0;
161 162 163
  }

  bool HasSetterMethod(const std::string &name) const {
164 165 166 167 168
    return var_infos_.count(name) > 0 && var_infos_.at(name).setter;
  }

  bool IsPublic(const std::string &name) const {
    return var_infos_.count(name) > 0 && var_infos_.at(name).is_public;
169 170 171 172
  }

  std::unordered_set<std::string> Keys() const {
    std::unordered_set<std::string> keys;
173 174
    keys.reserve(var_infos_.size());
    for (auto &pair : var_infos_) {
175 176 177 178 179 180 181 182
      keys.insert(pair.first);
    }
    return keys;
  }

 private:
  static GlobalVarGetterSetterRegistry instance_;

183
  std::unordered_map<std::string, VarInfo> var_infos_;
184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
};

GlobalVarGetterSetterRegistry GlobalVarGetterSetterRegistry::instance_;

static void RegisterGlobalVarGetterSetter();

// Populates the registry and then exposes it to Python: the class
// "GlobalVarGetterSetterRegistry" with dict-like methods, plus a module-level
// function "globals" that returns the registry instance by reference.
void BindGlobalValueGetterSetter(pybind11::module *module) {
  using Registry = GlobalVarGetterSetterRegistry;

  // Register all variables first so they are present once the bindings exist.
  RegisterGlobalVarGetterSetter();

  py::class_<Registry> registry_class(*module, "GlobalVarGetterSetterRegistry");
  registry_class.def("__getitem__", &Registry::Get);
  registry_class.def("__setitem__", &Registry::Set);
  registry_class.def("__contains__", &Registry::HasGetterMethod);
  registry_class.def("keys", &Registry::Keys);
  registry_class.def("is_public", &Registry::IsPublic);
  // dict.get-like lookup: the second argument defaults to None.
  registry_class.def("get",
                     &Registry::GetOrReturnDefaultValue,
                     py::arg("key"),
                     py::arg("default") = py::cast<py::none>(Py_None));

  module->def(
      "globals", &Registry::Instance, py::return_value_policy::reference);
}

210
/* Public vars are designed to be writable. */
// Registers the global variable `var` under its own spelling (#var) as a
// public entry that has both a getter and a setter bound to `var`.
#define REGISTER_PUBLIC_GLOBAL_VAR(var)                         \
  do {                                                          \
    GlobalVarGetterSetterRegistry::MutableInstance()->Register( \
        #var,                                                   \
        /*is_public=*/true,                                     \
        GlobalVarGetterSetterRegistry::CreateGetter(var),       \
        GlobalVarGetterSetterRegistry::CreateSetter(&var));     \
  } while (0)

Z
Zeng Jinle 已提交
220
// Visitor that registers one variable with GlobalVarGetterSetterRegistry.
// The visited value is used only to select the static type T; the storage
// that the getter/setter operate on is `value_ptr`, reinterpreted as T*.
struct RegisterGetterSetterVisitor : public boost::static_visitor<void> {
  RegisterGetterSetterVisitor(const std::string &name,
                              bool is_writable,
                              void *value_ptr)
      : name_(name), is_writable_(is_writable), value_ptr_(value_ptr) {}

  template <typename T>
  void operator()(const T &) const {
    using Registry = GlobalVarGetterSetterRegistry;

    T *typed_ptr = static_cast<T *>(value_ptr_);
    Registry *registry = Registry::MutableInstance();
    // currently, all writable vars are public
    const bool is_public = is_writable_;

    // Read-only variables get a getter only.
    if (!is_writable_) {
      registry->Register(
          name_, is_public, Registry::CreateGetter(*typed_ptr));
      return;
    }

    registry->Register(name_,
                       is_public,
                       Registry::CreateGetter(*typed_ptr),
                       Registry::CreateSetter(typed_ptr));
  }

 private:
  std::string name_;   // name the variable is registered under
  bool is_writable_;   // whether a setter is registered as well
  void *value_ptr_;    // type-erased pointer to the variable's storage
};

static void RegisterGlobalVarGetterSetter() {
G
guofei 已提交
249
#ifdef PADDLE_WITH_DITRIBUTE
Z
Zeng Jinle 已提交
250 251
  REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_get_thread_num);
  REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_prefetch_thread_num);
252
#endif
Z
Zeng Jinle 已提交
253 254 255 256 257 258 259

  const auto &flag_map = platform::GetExportedFlagInfoMap();
  for (const auto &pair : flag_map) {
    const std::string &name = pair.second.name;
    bool is_writable = pair.second.is_writable;
    void *value_ptr = pair.second.value_ptr;
    const auto &default_value = pair.second.default_value;
260 261
    RegisterGetterSetterVisitor visitor(
        "FLAGS_" + name, is_writable, value_ptr);
R
Ruibiao Chen 已提交
262
    paddle::visit(visitor, default_value);
Z
Zeng Jinle 已提交
263
  }
264
}
Z
Zeng Jinle 已提交
265

266 267
}  // namespace pybind
}  // namespace paddle