未验证 提交 babda94c 编写于 作者: Z Zeng Jinle 提交者: GitHub

Distinguish public/private global vars (#23269)

* distinguish public/private vars, test=develop

* fix windows issues, test=develop
上级 f8205ffa
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/pybind/global_value_getter_setter.h"
#include <cctype>
#include <functional>
#include <string>
#include <unordered_map>
......@@ -50,43 +51,76 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
GlobalVarGetterSetterRegistry() = default;
public:
using Setter = std::function<void(const py::object &)>;
using Getter = std::function<py::object()>;
using Setter = std::function<void(const py::object &)>;
template <typename T>
static Getter CreateGetter(const T &var) {
return [&]() -> py::object { return py::cast(var); };
}
template <typename T>
static Setter CreateSetter(T *var) {
return [var](const py::object &obj) { *var = py::cast<T>(obj); };
}
private:
struct VarInfo {
VarInfo(bool is_public, const Getter &getter)
: is_public(is_public), getter(getter) {}
VarInfo(bool is_public, const Getter &getter, const Setter &setter)
: is_public(is_public), getter(getter), setter(setter) {}
const bool is_public;
const Getter getter;
const Setter setter;
};
public:
static const GlobalVarGetterSetterRegistry &Instance() { return instance_; }
static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; }
void RegisterGetter(const std::string &name, Getter func) {
void Register(const std::string &name, bool is_public, const Getter &getter) {
PADDLE_ENFORCE_EQ(
getters_.count(name), 0,
HasGetterMethod(name), false,
platform::errors::AlreadyExists(
"Getter of global variable %s has been registered", name));
PADDLE_ENFORCE_NOT_NULL(func, platform::errors::InvalidArgument(
"Getter of %s should not be null", name));
getters_[name] = std::move(func);
PADDLE_ENFORCE_NOT_NULL(getter,
platform::errors::InvalidArgument(
"Getter of %s should not be null", name));
var_infos_.insert({name, VarInfo(is_public, getter)});
}
void RegisterSetter(const std::string &name, Setter func) {
void Register(const std::string &name, bool is_public, const Getter &getter,
const Setter &setter) {
PADDLE_ENFORCE_EQ(
HasGetterMethod(name), true,
platform::errors::NotFound(
"Cannot register setter for %s before register getter", name));
HasGetterMethod(name), false,
platform::errors::AlreadyExists(
"Getter of global variable %s has been registered", name));
PADDLE_ENFORCE_EQ(
setters_.count(name), 0,
HasSetterMethod(name), false,
platform::errors::AlreadyExists(
"Setter of global variable %s has been registered", name));
PADDLE_ENFORCE_NOT_NULL(func, platform::errors::InvalidArgument(
"Setter of %s should not be null", name));
setters_[name] = std::move(func);
PADDLE_ENFORCE_NOT_NULL(getter,
platform::errors::InvalidArgument(
"Getter of %s should not be null", name));
PADDLE_ENFORCE_NOT_NULL(setter,
platform::errors::InvalidArgument(
"Setter of %s should not be null", name));
var_infos_.insert({name, VarInfo(is_public, getter, setter)});
}
const Getter &GetterMethod(const std::string &name) const {
PADDLE_ENFORCE_EQ(
HasGetterMethod(name), true,
platform::errors::NotFound("Cannot find global variable %s", name));
return getters_.at(name);
return var_infos_.at(name).getter;
}
py::object GetOrReturnDefaultValue(const std::string &name,
......@@ -104,7 +138,7 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
PADDLE_ENFORCE_EQ(
HasSetterMethod(name), true,
platform::errors::NotFound("Global variable %s is not writable", name));
return setters_.at(name);
return var_infos_.at(name).setter;
}
void Set(const std::string &name, const py::object &value) const {
......@@ -112,17 +146,21 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
}
bool HasGetterMethod(const std::string &name) const {
return getters_.count(name) > 0;
return var_infos_.count(name) > 0;
}
bool HasSetterMethod(const std::string &name) const {
return setters_.count(name) > 0;
return var_infos_.count(name) > 0 && var_infos_.at(name).setter;
}
bool IsPublic(const std::string &name) const {
return var_infos_.count(name) > 0 && var_infos_.at(name).is_public;
}
std::unordered_set<std::string> Keys() const {
std::unordered_set<std::string> keys;
keys.reserve(getters_.size());
for (auto &pair : getters_) {
keys.reserve(var_infos_.size());
for (auto &pair : var_infos_) {
keys.insert(pair.first);
}
return keys;
......@@ -131,12 +169,86 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
private:
static GlobalVarGetterSetterRegistry instance_;
std::unordered_map<std::string, Getter> getters_;
std::unordered_map<std::string, Setter> setters_;
std::unordered_map<std::string, VarInfo> var_infos_;
};
GlobalVarGetterSetterRegistry GlobalVarGetterSetterRegistry::instance_;
class GlobalVarGetterSetterRegistryHelper {
public:
GlobalVarGetterSetterRegistryHelper(bool is_public, bool is_writable,
const std::string &var_names)
: is_public_(is_public),
is_writable_(is_writable),
var_names_(SplitVarNames(var_names)) {}
template <typename... Args>
void Register(Args &&... args) const {
Impl<0, sizeof...(args) == 1, Args...>::Register(
is_public_, is_writable_, var_names_, std::forward<Args>(args)...);
}
private:
static std::vector<std::string> SplitVarNames(const std::string &names) {
auto valid_char = [](char ch) { return !std::isspace(ch) && ch != ','; };
std::vector<std::string> ret;
size_t i = 0, j = 0, n = names.size();
while (i < n) {
for (; i < n && !valid_char(names[i]); ++i) {
}
for (j = i + 1; j < n && valid_char(names[j]); ++j) {
}
if (i < n && j <= n) {
auto substring = names.substr(i, j - i);
VLOG(10) << "Get substring: \"" << substring << "\"";
ret.emplace_back(substring);
}
i = j + 1;
}
return ret;
}
private:
template <size_t kIdx, bool kIsStop, typename T, typename... Args>
struct Impl {
static void Register(bool is_public, bool is_writable,
const std::vector<std::string> &var_names, T &&var,
Args &&... args) {
PADDLE_ENFORCE_EQ(kIdx + 1 + sizeof...(args), var_names.size(),
platform::errors::InvalidArgument(
"Argument number not match name number"));
Impl<kIdx, true, T>::Register(is_public, is_writable, var_names, var);
Impl<kIdx + 1, sizeof...(Args) == 1, Args...>::Register(
is_public, is_writable, var_names, std::forward<Args>(args)...);
}
};
template <size_t kIdx, typename T>
struct Impl<kIdx, true, T> {
static void Register(bool is_public, bool is_writable,
const std::vector<std::string> &var_names, T &&var) {
auto *instance = GlobalVarGetterSetterRegistry::MutableInstance();
if (is_writable) {
instance->Register(
var_names[kIdx], is_public,
GlobalVarGetterSetterRegistry::CreateGetter(std::forward<T>(var)),
GlobalVarGetterSetterRegistry::CreateSetter(&var));
} else {
instance->Register(
var_names[kIdx], is_public,
GlobalVarGetterSetterRegistry::CreateGetter(std::forward<T>(var)));
}
}
};
private:
const bool is_public_;
const bool is_writable_;
const std::vector<std::string> var_names_;
};
static void RegisterGlobalVarGetterSetter();
void BindGlobalValueGetterSetter(pybind11::module *module) {
......@@ -148,6 +260,7 @@ void BindGlobalValueGetterSetter(pybind11::module *module) {
.def("__setitem__", &GlobalVarGetterSetterRegistry::Set)
.def("__contains__", &GlobalVarGetterSetterRegistry::HasGetterMethod)
.def("keys", &GlobalVarGetterSetterRegistry::Keys)
.def("is_public", &GlobalVarGetterSetterRegistry::IsPublic)
.def("get", &GlobalVarGetterSetterRegistry::GetOrReturnDefaultValue,
py::arg("key"), py::arg("default") = py::cast<py::none>(Py_None));
......@@ -155,33 +268,33 @@ void BindGlobalValueGetterSetter(pybind11::module *module) {
py::return_value_policy::reference);
}
#define REGISTER_GLOBAL_VAR_GETTER_ONLY(var) \
GlobalVarGetterSetterRegistry::MutableInstance()->RegisterGetter( \
#var, []() -> py::object { return py::cast(var); })
/* Public vars are designed to be writable. */
#define REGISTER_PUBLIC_GLOBAL_VAR(...) \
do { \
GlobalVarGetterSetterRegistryHelper(/*is_public=*/true, \
/*is_writable=*/true, "" #__VA_ARGS__) \
.Register(__VA_ARGS__); \
} while (0)
#define REGISTER_PRIVATE_GLOBAL_VAR(is_writable, ...) \
do { \
GlobalVarGetterSetterRegistryHelper(/*is_public=*/false, is_writable, \
"" #__VA_ARGS__) \
.Register(__VA_ARGS__); \
} while (0)
#define REGISTER_GLOBAL_VAR_SETTER_ONLY(var) \
GlobalVarGetterSetterRegistry::MutableInstance()->RegisterSetter( \
#var, [](const py::object &obj) { \
using ValueType = std::remove_reference<decltype(var)>::type; \
var = py::cast<ValueType>(obj); \
})
static void RegisterGlobalVarGetterSetter() {
REGISTER_PRIVATE_GLOBAL_VAR(/*is_writable=*/false, FLAGS_use_mkldnn,
FLAGS_use_ngraph, FLAGS_free_idle_chunk,
FLAGS_free_when_no_cache_hit);
#define REGISTER_GLOBAL_VAR_GETTER_SETTER(var) \
REGISTER_GLOBAL_VAR_GETTER_ONLY(var); \
REGISTER_GLOBAL_VAR_SETTER_ONLY(var)
REGISTER_PUBLIC_GLOBAL_VAR(
FLAGS_eager_delete_tensor_gb, FLAGS_enable_parallel_graph,
FLAGS_allocator_strategy, FLAGS_use_system_allocator);
static void RegisterGlobalVarGetterSetter() {
REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_mkldnn);
REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_use_ngraph);
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_eager_delete_tensor_gb);
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_use_system_allocator);
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_allocator_strategy);
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_enable_parallel_graph);
REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_idle_chunk);
REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_when_no_cache_hit);
#ifdef PADDLE_WITH_CUDA
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_gpu_memory_limit_mb);
REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_cudnn_deterministic);
REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_gpu_memory_limit_mb,
FLAGS_cudnn_deterministic);
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册