未验证 提交 8d4b64e8 编写于 作者: C Chitsing KUI 提交者: GitHub

[DEBUG] print modifed flags (#53243)

* print modifed flags

* fix ref, opt print

* fix default getter

* fix ut
上级 eb677102
...@@ -64,6 +64,11 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { ...@@ -64,6 +64,11 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
return [&]() -> py::object { return py::cast(var); }; return [&]() -> py::object { return py::cast(var); };
} }
template <typename T>
static Getter CreateDefaultValueGetter(const T &var) {
return [=]() -> py::object { return py::cast(var); };
}
template <typename T> template <typename T>
static Setter CreateSetter(T *var) { static Setter CreateSetter(T *var) {
return [var](const py::object &obj) { *var = py::cast<T>(obj); }; return [var](const py::object &obj) { *var = py::cast<T>(obj); };
...@@ -71,14 +76,23 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { ...@@ -71,14 +76,23 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
private: private:
struct VarInfo { struct VarInfo {
VarInfo(bool is_public, const Getter &getter) VarInfo(bool is_public, const Getter &getter, const Getter &default_getter)
: is_public(is_public), getter(getter) {} : is_public(is_public),
getter(getter),
default_getter(default_getter) {}
VarInfo(bool is_public, const Getter &getter, const Setter &setter) VarInfo(bool is_public,
: is_public(is_public), getter(getter), setter(setter) {} const Getter &getter,
const Getter &default_getter,
const Setter &setter)
: is_public(is_public),
getter(getter),
default_getter(default_getter),
setter(setter) {}
const bool is_public; const bool is_public;
const Getter getter; const Getter getter;
const Getter default_getter;
const Setter setter; const Setter setter;
}; };
...@@ -87,7 +101,10 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { ...@@ -87,7 +101,10 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; } static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; }
void Register(const std::string &name, bool is_public, const Getter &getter) { void Register(const std::string &name,
bool is_public,
const Getter &getter,
const Getter &default_getter) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
HasGetterMethod(name), HasGetterMethod(name),
false, false,
...@@ -96,12 +113,13 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { ...@@ -96,12 +113,13 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
PADDLE_ENFORCE_NOT_NULL(getter, PADDLE_ENFORCE_NOT_NULL(getter,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Getter of %s should not be null", name)); "Getter of %s should not be null", name));
var_infos_.insert({name, VarInfo(is_public, getter)}); var_infos_.insert({name, VarInfo(is_public, getter, default_getter)});
} }
void Register(const std::string &name, void Register(const std::string &name,
bool is_public, bool is_public,
const Getter &getter, const Getter &getter,
const Getter &default_getter,
const Setter &setter) { const Setter &setter) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
HasGetterMethod(name), HasGetterMethod(name),
...@@ -122,7 +140,8 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { ...@@ -122,7 +140,8 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
PADDLE_ENFORCE_NOT_NULL(setter, PADDLE_ENFORCE_NOT_NULL(setter,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Setter of %s should not be null", name)); "Setter of %s should not be null", name));
var_infos_.insert({name, VarInfo(is_public, getter, setter)}); var_infos_.insert(
{name, VarInfo(is_public, getter, default_getter, setter)});
} }
const Getter &GetterMethod(const std::string &name) const { const Getter &GetterMethod(const std::string &name) const {
...@@ -133,6 +152,14 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { ...@@ -133,6 +152,14 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
return var_infos_.at(name).getter; return var_infos_.at(name).getter;
} }
const Getter &DefaultGetterMethod(const std::string &name) const {
PADDLE_ENFORCE_EQ(
HasGetterMethod(name),
true,
platform::errors::NotFound("Cannot find global variable %s", name));
return var_infos_.at(name).default_getter;
}
py::object GetOrReturnDefaultValue(const std::string &name, py::object GetOrReturnDefaultValue(const std::string &name,
const py::object &default_value) const { const py::object &default_value) const {
if (HasGetterMethod(name)) { if (HasGetterMethod(name)) {
...@@ -142,6 +169,14 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { ...@@ -142,6 +169,14 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
} }
} }
py::object GetDefaultValue(const std::string &name) const {
if (HasGetterMethod(name)) {
return DefaultGetterMethod(name)();
} else {
return py::cast(Py_None);
}
}
py::object Get(const std::string &name) const { return GetterMethod(name)(); } py::object Get(const std::string &name) const { return GetterMethod(name)(); }
const Setter &SetterMethod(const std::string &name) const { const Setter &SetterMethod(const std::string &name) const {
...@@ -198,6 +233,9 @@ void BindGlobalValueGetterSetter(pybind11::module *module) { ...@@ -198,6 +233,9 @@ void BindGlobalValueGetterSetter(pybind11::module *module) {
.def("__contains__", &GlobalVarGetterSetterRegistry::HasGetterMethod) .def("__contains__", &GlobalVarGetterSetterRegistry::HasGetterMethod)
.def("keys", &GlobalVarGetterSetterRegistry::Keys) .def("keys", &GlobalVarGetterSetterRegistry::Keys)
.def("is_public", &GlobalVarGetterSetterRegistry::IsPublic) .def("is_public", &GlobalVarGetterSetterRegistry::IsPublic)
.def("get_default",
&GlobalVarGetterSetterRegistry::GetDefaultValue,
py::arg("key"))
.def("get", .def("get",
&GlobalVarGetterSetterRegistry::GetOrReturnDefaultValue, &GlobalVarGetterSetterRegistry::GetOrReturnDefaultValue,
py::arg("key"), py::arg("key"),
...@@ -212,9 +250,11 @@ void BindGlobalValueGetterSetter(pybind11::module *module) { ...@@ -212,9 +250,11 @@ void BindGlobalValueGetterSetter(pybind11::module *module) {
#define REGISTER_PUBLIC_GLOBAL_VAR(var) \ #define REGISTER_PUBLIC_GLOBAL_VAR(var) \
do { \ do { \
auto *instance = GlobalVarGetterSetterRegistry::MutableInstance(); \ auto *instance = GlobalVarGetterSetterRegistry::MutableInstance(); \
instance->Register(#var, \ instance->Register( \
#var, \
/*is_public=*/true, \ /*is_public=*/true, \
GlobalVarGetterSetterRegistry::CreateGetter(var), \ GlobalVarGetterSetterRegistry::CreateGetter(var), \
GlobalVarGetterSetterRegistry::CreateDefaultValueGetter(var), \
GlobalVarGetterSetterRegistry::CreateSetter(&var)); \ GlobalVarGetterSetterRegistry::CreateSetter(&var)); \
} while (0) } while (0)
...@@ -225,18 +265,25 @@ struct RegisterGetterSetterVisitor { ...@@ -225,18 +265,25 @@ struct RegisterGetterSetterVisitor {
: name_(name), is_writable_(is_writable), value_ptr_(value_ptr) {} : name_(name), is_writable_(is_writable), value_ptr_(value_ptr) {}
template <typename T> template <typename T>
void operator()(const T &) const { void operator()(const T &default_value) const {
auto &value = *static_cast<T *>(value_ptr_); auto &value = *static_cast<T *>(value_ptr_);
auto *instance = GlobalVarGetterSetterRegistry::MutableInstance(); auto *instance = GlobalVarGetterSetterRegistry::MutableInstance();
bool is_public = is_writable_; // currently, all writable vars are public bool is_public = is_writable_; // currently, all writable vars are public
if (is_writable_) { if (is_writable_) {
instance->Register(name_, instance->Register(
name_,
is_public, is_public,
GlobalVarGetterSetterRegistry::CreateGetter(value), GlobalVarGetterSetterRegistry::CreateGetter(value),
GlobalVarGetterSetterRegistry::CreateDefaultValueGetter(
default_value),
GlobalVarGetterSetterRegistry::CreateSetter(&value)); GlobalVarGetterSetterRegistry::CreateSetter(&value));
} else { } else {
instance->Register( instance->Register(
name_, is_public, GlobalVarGetterSetterRegistry::CreateGetter(value)); name_,
is_public,
GlobalVarGetterSetterRegistry::CreateGetter(value),
GlobalVarGetterSetterRegistry::CreateDefaultValueGetter(
default_value));
} }
} }
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
import itertools import itertools
import os import os
import sys
import time import time
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict, namedtuple
from contextlib import contextmanager from contextlib import contextmanager
from multiprocessing import Manager # noqa: F401 from multiprocessing import Manager # noqa: F401
from multiprocessing import Process # noqa: F401 from multiprocessing import Process # noqa: F401
...@@ -905,6 +906,31 @@ def _check_var_exists(var_name): ...@@ -905,6 +906,31 @@ def _check_var_exists(var_name):
) )
def _get_modified_flags():
ret = []
FLAGS = namedtuple('FLAGS', ['name', 'current_value', 'default_value'])
global_flags = core.globals()
for key in global_flags.keys():
value = global_flags.get(key)
default_value = global_flags.get_default(key)
if not value == default_value:
ret.append(FLAGS(key, value, default_value))
return ret
def _print_modified_flags(modified_flags):
if len(modified_flags) > 0:
sys.stderr.write(
"======================= Modified FLAGS detected =======================\n"
)
for flag in modified_flags:
sys.stderr.write(str(flag))
sys.stderr.write("\n")
sys.stderr.write(
"=======================================================================\n"
)
def init_parallel_env(): def init_parallel_env():
""" """
...@@ -967,6 +993,9 @@ def init_parallel_env(): ...@@ -967,6 +993,9 @@ def init_parallel_env():
""" """
modified_flags = _get_modified_flags()
_print_modified_flags(modified_flags)
# 0. get env & check world size # 0. get env & check world size
global _global_parallel_env global _global_parallel_env
# when call init_parallel_env, need update `_global_parallel_env` # when call init_parallel_env, need update `_global_parallel_env`
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册