未验证 提交 8d4b64e8 编写于 作者: C Chitsing KUI 提交者: GitHub

[DEBUG] print modifed flags (#53243)

* print modifed flags

* fix ref, opt print

* fix default getter

* fix ut
上级 eb677102
......@@ -64,6 +64,11 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
return [&]() -> py::object { return py::cast(var); };
}
template <typename T>
static Getter CreateDefaultValueGetter(const T &var) {
return [=]() -> py::object { return py::cast(var); };
}
template <typename T>
static Setter CreateSetter(T *var) {
return [var](const py::object &obj) { *var = py::cast<T>(obj); };
......@@ -71,14 +76,23 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
private:
struct VarInfo {
VarInfo(bool is_public, const Getter &getter)
: is_public(is_public), getter(getter) {}
VarInfo(bool is_public, const Getter &getter, const Getter &default_getter)
: is_public(is_public),
getter(getter),
default_getter(default_getter) {}
VarInfo(bool is_public, const Getter &getter, const Setter &setter)
: is_public(is_public), getter(getter), setter(setter) {}
VarInfo(bool is_public,
const Getter &getter,
const Getter &default_getter,
const Setter &setter)
: is_public(is_public),
getter(getter),
default_getter(default_getter),
setter(setter) {}
const bool is_public;
const Getter getter;
const Getter default_getter;
const Setter setter;
};
......@@ -87,7 +101,10 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; }
void Register(const std::string &name, bool is_public, const Getter &getter) {
void Register(const std::string &name,
bool is_public,
const Getter &getter,
const Getter &default_getter) {
PADDLE_ENFORCE_EQ(
HasGetterMethod(name),
false,
......@@ -96,12 +113,13 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
PADDLE_ENFORCE_NOT_NULL(getter,
platform::errors::InvalidArgument(
"Getter of %s should not be null", name));
var_infos_.insert({name, VarInfo(is_public, getter)});
var_infos_.insert({name, VarInfo(is_public, getter, default_getter)});
}
void Register(const std::string &name,
bool is_public,
const Getter &getter,
const Getter &default_getter,
const Setter &setter) {
PADDLE_ENFORCE_EQ(
HasGetterMethod(name),
......@@ -122,7 +140,8 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
PADDLE_ENFORCE_NOT_NULL(setter,
platform::errors::InvalidArgument(
"Setter of %s should not be null", name));
var_infos_.insert({name, VarInfo(is_public, getter, setter)});
var_infos_.insert(
{name, VarInfo(is_public, getter, default_getter, setter)});
}
const Getter &GetterMethod(const std::string &name) const {
......@@ -133,6 +152,14 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
return var_infos_.at(name).getter;
}
const Getter &DefaultGetterMethod(const std::string &name) const {
PADDLE_ENFORCE_EQ(
HasGetterMethod(name),
true,
platform::errors::NotFound("Cannot find global variable %s", name));
return var_infos_.at(name).default_getter;
}
py::object GetOrReturnDefaultValue(const std::string &name,
const py::object &default_value) const {
if (HasGetterMethod(name)) {
......@@ -142,6 +169,14 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry {
}
}
py::object GetDefaultValue(const std::string &name) const {
if (HasGetterMethod(name)) {
return DefaultGetterMethod(name)();
} else {
return py::cast(Py_None);
}
}
py::object Get(const std::string &name) const { return GetterMethod(name)(); }
const Setter &SetterMethod(const std::string &name) const {
......@@ -198,6 +233,9 @@ void BindGlobalValueGetterSetter(pybind11::module *module) {
.def("__contains__", &GlobalVarGetterSetterRegistry::HasGetterMethod)
.def("keys", &GlobalVarGetterSetterRegistry::Keys)
.def("is_public", &GlobalVarGetterSetterRegistry::IsPublic)
.def("get_default",
&GlobalVarGetterSetterRegistry::GetDefaultValue,
py::arg("key"))
.def("get",
&GlobalVarGetterSetterRegistry::GetOrReturnDefaultValue,
py::arg("key"),
......@@ -212,9 +250,11 @@ void BindGlobalValueGetterSetter(pybind11::module *module) {
#define REGISTER_PUBLIC_GLOBAL_VAR(var) \
do { \
auto *instance = GlobalVarGetterSetterRegistry::MutableInstance(); \
instance->Register(#var, \
instance->Register( \
#var, \
/*is_public=*/true, \
GlobalVarGetterSetterRegistry::CreateGetter(var), \
GlobalVarGetterSetterRegistry::CreateDefaultValueGetter(var), \
GlobalVarGetterSetterRegistry::CreateSetter(&var)); \
} while (0)
......@@ -225,18 +265,25 @@ struct RegisterGetterSetterVisitor {
: name_(name), is_writable_(is_writable), value_ptr_(value_ptr) {}
template <typename T>
void operator()(const T &) const {
void operator()(const T &default_value) const {
auto &value = *static_cast<T *>(value_ptr_);
auto *instance = GlobalVarGetterSetterRegistry::MutableInstance();
bool is_public = is_writable_; // currently, all writable vars are public
if (is_writable_) {
instance->Register(name_,
instance->Register(
name_,
is_public,
GlobalVarGetterSetterRegistry::CreateGetter(value),
GlobalVarGetterSetterRegistry::CreateDefaultValueGetter(
default_value),
GlobalVarGetterSetterRegistry::CreateSetter(&value));
} else {
instance->Register(
name_, is_public, GlobalVarGetterSetterRegistry::CreateGetter(value));
name_,
is_public,
GlobalVarGetterSetterRegistry::CreateGetter(value),
GlobalVarGetterSetterRegistry::CreateDefaultValueGetter(
default_value));
}
}
......
......@@ -14,9 +14,10 @@
import itertools
import os
import sys
import time
import warnings
from collections import OrderedDict
from collections import OrderedDict, namedtuple
from contextlib import contextmanager
from multiprocessing import Manager # noqa: F401
from multiprocessing import Process # noqa: F401
......@@ -905,6 +906,31 @@ def _check_var_exists(var_name):
)
def _get_modified_flags():
ret = []
FLAGS = namedtuple('FLAGS', ['name', 'current_value', 'default_value'])
global_flags = core.globals()
for key in global_flags.keys():
value = global_flags.get(key)
default_value = global_flags.get_default(key)
if not value == default_value:
ret.append(FLAGS(key, value, default_value))
return ret
def _print_modified_flags(modified_flags):
if len(modified_flags) > 0:
sys.stderr.write(
"======================= Modified FLAGS detected =======================\n"
)
for flag in modified_flags:
sys.stderr.write(str(flag))
sys.stderr.write("\n")
sys.stderr.write(
"=======================================================================\n"
)
def init_parallel_env():
"""
......@@ -967,6 +993,9 @@ def init_parallel_env():
"""
modified_flags = _get_modified_flags()
_print_modified_flags(modified_flags)
# 0. get env & check world size
global _global_parallel_env
# when call init_parallel_env, need update `_global_parallel_env`
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册