diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc index 2bc419f49b1be61d5032c6e6c3e448c97ebe9e57..e39ff42b44795ac1c49c4a712d02ee2e5c710b1f 100644 --- a/paddle/fluid/pybind/global_value_getter_setter.cc +++ b/paddle/fluid/pybind/global_value_getter_setter.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/pybind/global_value_getter_setter.h" +#include #include #include #include @@ -50,43 +51,76 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { GlobalVarGetterSetterRegistry() = default; public: - using Setter = std::function; using Getter = std::function; + using Setter = std::function; + + template + static Getter CreateGetter(const T &var) { + return [&]() -> py::object { return py::cast(var); }; + } + + template + static Setter CreateSetter(T *var) { + return [var](const py::object &obj) { *var = py::cast(obj); }; + } + + private: + struct VarInfo { + VarInfo(bool is_public, const Getter &getter) + : is_public(is_public), getter(getter) {} + + VarInfo(bool is_public, const Getter &getter, const Setter &setter) + : is_public(is_public), getter(getter), setter(setter) {} + const bool is_public; + const Getter getter; + const Setter setter; + }; + + public: static const GlobalVarGetterSetterRegistry &Instance() { return instance_; } static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; } - void RegisterGetter(const std::string &name, Getter func) { + void Register(const std::string &name, bool is_public, const Getter &getter) { PADDLE_ENFORCE_EQ( - getters_.count(name), 0, + HasGetterMethod(name), false, platform::errors::AlreadyExists( "Getter of global variable %s has been registered", name)); - PADDLE_ENFORCE_NOT_NULL(func, platform::errors::InvalidArgument( - "Getter of %s should not be null", name)); - getters_[name] = std::move(func); + PADDLE_ENFORCE_NOT_NULL(getter, + platform::errors::InvalidArgument( + "Getter of %s should not be null", name)); + var_infos_.insert({name, VarInfo(is_public, getter)}); } - void RegisterSetter(const std::string &name, Setter func) { + void Register(const std::string &name, bool is_public, const Getter &getter, + const Setter &setter) { PADDLE_ENFORCE_EQ( - HasGetterMethod(name), true, - platform::errors::NotFound( - "Cannot register setter for %s before register getter", name)); + HasGetterMethod(name), false, + platform::errors::AlreadyExists( + "Getter of global variable %s has been registered", name)); PADDLE_ENFORCE_EQ( - setters_.count(name), 0, + HasSetterMethod(name), false, platform::errors::AlreadyExists( "Setter of global variable %s has been registered", name)); - PADDLE_ENFORCE_NOT_NULL(func, platform::errors::InvalidArgument( - "Setter of %s should not be null", name)); - setters_[name] = std::move(func); + + PADDLE_ENFORCE_NOT_NULL(getter, + platform::errors::InvalidArgument( + "Getter of %s should not be null", name)); + + PADDLE_ENFORCE_NOT_NULL(setter, + platform::errors::InvalidArgument( + "Setter of %s should not be null", name)); + + var_infos_.insert({name, VarInfo(is_public, getter, setter)}); } const Getter &GetterMethod(const std::string &name) const { PADDLE_ENFORCE_EQ( HasGetterMethod(name), true, platform::errors::NotFound("Cannot find global variable %s", name)); - return getters_.at(name); + return var_infos_.at(name).getter; } py::object GetOrReturnDefaultValue(const std::string &name, @@ -104,7 +138,7 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { PADDLE_ENFORCE_EQ( HasSetterMethod(name), true, platform::errors::NotFound("Global variable %s is not writable", name)); - return setters_.at(name); + return var_infos_.at(name).setter; } void Set(const std::string &name, const py::object &value) const { @@ -112,17 +146,21 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { } bool HasGetterMethod(const std::string &name) const { - return getters_.count(name) > 0; + return var_infos_.count(name) > 0; } bool HasSetterMethod(const std::string &name) const { - return setters_.count(name) > 0; + return var_infos_.count(name) > 0 && var_infos_.at(name).setter; + } + + bool IsPublic(const std::string &name) const { + return var_infos_.count(name) > 0 && var_infos_.at(name).is_public; } std::unordered_set Keys() const { std::unordered_set keys; - keys.reserve(getters_.size()); - for (auto &pair : getters_) { + keys.reserve(var_infos_.size()); + for (auto &pair : var_infos_) { keys.insert(pair.first); } return keys; @@ -131,12 +169,86 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { private: static GlobalVarGetterSetterRegistry instance_; - std::unordered_map getters_; - std::unordered_map setters_; + std::unordered_map var_infos_; }; GlobalVarGetterSetterRegistry GlobalVarGetterSetterRegistry::instance_; +class GlobalVarGetterSetterRegistryHelper { + public: + GlobalVarGetterSetterRegistryHelper(bool is_public, bool is_writable, + const std::string &var_names) + : is_public_(is_public), + is_writable_(is_writable), + var_names_(SplitVarNames(var_names)) {} + + template + void Register(Args &&... args) const { + Impl<0, sizeof...(args) == 1, Args...>::Register( + is_public_, is_writable_, var_names_, std::forward(args)...); + } + + private: + static std::vector SplitVarNames(const std::string &names) { + auto valid_char = [](char ch) { return !std::isspace(ch) && ch != ','; }; + + std::vector ret; + size_t i = 0, j = 0, n = names.size(); + while (i < n) { + for (; i < n && !valid_char(names[i]); ++i) { + } + for (j = i + 1; j < n && valid_char(names[j]); ++j) { + } + + if (i < n && j <= n) { + auto substring = names.substr(i, j - i); + VLOG(10) << "Get substring: \"" << substring << "\""; + ret.emplace_back(substring); + } + i = j + 1; + } + return ret; + } + + private: + template + struct Impl { + static void Register(bool is_public, bool is_writable, + const std::vector &var_names, T &&var, + Args &&... args) { + PADDLE_ENFORCE_EQ(kIdx + 1 + sizeof...(args), var_names.size(), + platform::errors::InvalidArgument( + "Argument number not match name number")); + Impl::Register(is_public, is_writable, var_names, var); + Impl::Register( + is_public, is_writable, var_names, std::forward(args)...); + } + }; + + template + struct Impl { + static void Register(bool is_public, bool is_writable, + const std::vector &var_names, T &&var) { + auto *instance = GlobalVarGetterSetterRegistry::MutableInstance(); + if (is_writable) { + instance->Register( + var_names[kIdx], is_public, + GlobalVarGetterSetterRegistry::CreateGetter(std::forward(var)), + GlobalVarGetterSetterRegistry::CreateSetter(&var)); + } else { + instance->Register( + var_names[kIdx], is_public, + GlobalVarGetterSetterRegistry::CreateGetter(std::forward(var))); + } + } + }; + + private: + const bool is_public_; + const bool is_writable_; + const std::vector var_names_; +}; + static void RegisterGlobalVarGetterSetter(); void BindGlobalValueGetterSetter(pybind11::module *module) { @@ -148,6 +260,7 @@ void BindGlobalValueGetterSetter(pybind11::module *module) { .def("__setitem__", &GlobalVarGetterSetterRegistry::Set) .def("__contains__", &GlobalVarGetterSetterRegistry::HasGetterMethod) .def("keys", &GlobalVarGetterSetterRegistry::Keys) + .def("is_public", &GlobalVarGetterSetterRegistry::IsPublic) .def("get", &GlobalVarGetterSetterRegistry::GetOrReturnDefaultValue, py::arg("key"), py::arg("default") = py::cast(Py_None)); @@ -155,33 +268,33 @@ void BindGlobalValueGetterSetter(pybind11::module *module) { py::return_value_policy::reference); } -#define REGISTER_GLOBAL_VAR_GETTER_ONLY(var) \ - GlobalVarGetterSetterRegistry::MutableInstance()->RegisterGetter( \ - #var, []() -> py::object { return py::cast(var); }) +/* Public vars are designed to be writable. */ +#define REGISTER_PUBLIC_GLOBAL_VAR(...) \ + do { \ + GlobalVarGetterSetterRegistryHelper(/*is_public=*/true, \ + /*is_writable=*/true, "" #__VA_ARGS__) \ + .Register(__VA_ARGS__); \ + } while (0) + +#define REGISTER_PRIVATE_GLOBAL_VAR(is_writable, ...) \ + do { \ + GlobalVarGetterSetterRegistryHelper(/*is_public=*/false, is_writable, \ + "" #__VA_ARGS__) \ + .Register(__VA_ARGS__); \ + } while (0) -#define REGISTER_GLOBAL_VAR_SETTER_ONLY(var) \ - GlobalVarGetterSetterRegistry::MutableInstance()->RegisterSetter( \ - #var, [](const py::object &obj) { \ - using ValueType = std::remove_reference::type; \ - var = py::cast(obj); \ - }) +static void RegisterGlobalVarGetterSetter() { + REGISTER_PRIVATE_GLOBAL_VAR(/*is_writable=*/false, FLAGS_use_mkldnn, + FLAGS_use_ngraph, FLAGS_free_idle_chunk, + FLAGS_free_when_no_cache_hit); -#define REGISTER_GLOBAL_VAR_GETTER_SETTER(var) \ - REGISTER_GLOBAL_VAR_GETTER_ONLY(var); \ - REGISTER_GLOBAL_VAR_SETTER_ONLY(var) + REGISTER_PUBLIC_GLOBAL_VAR( + FLAGS_eager_delete_tensor_gb, FLAGS_enable_parallel_graph, + FLAGS_allocator_strategy, FLAGS_use_system_allocator); -static void RegisterGlobalVarGetterSetter() { - REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_mkldnn); - REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_ngraph); - REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_eager_delete_tensor_gb); - REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_use_system_allocator); - REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_allocator_strategy); - REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_enable_parallel_graph); - REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_idle_chunk); - REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_when_no_cache_hit); #ifdef PADDLE_WITH_CUDA - REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_gpu_memory_limit_mb); - REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_cudnn_deterministic); + REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_gpu_memory_limit_mb, + FLAGS_cudnn_deterministic); #endif }