diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc index e6935b577d737705bffabb8f47c43b6cd5d21ce0..94e3ca1ba41bdbb944f8b031479d347c423b9fe0 100644 --- a/paddle/fluid/pybind/global_value_getter_setter.cc +++ b/paddle/fluid/pybind/global_value_getter_setter.cc @@ -64,6 +64,11 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { return [&]() -> py::object { return py::cast(var); }; } + template + static Getter CreateDefaultValueGetter(const T &var) { + return [=]() -> py::object { return py::cast(var); }; + } + template static Setter CreateSetter(T *var) { return [var](const py::object &obj) { *var = py::cast(obj); }; @@ -71,14 +76,23 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { private: struct VarInfo { - VarInfo(bool is_public, const Getter &getter) - : is_public(is_public), getter(getter) {} - - VarInfo(bool is_public, const Getter &getter, const Setter &setter) - : is_public(is_public), getter(getter), setter(setter) {} + VarInfo(bool is_public, const Getter &getter, const Getter &default_getter) + : is_public(is_public), + getter(getter), + default_getter(default_getter) {} + + VarInfo(bool is_public, + const Getter &getter, + const Getter &default_getter, + const Setter &setter) + : is_public(is_public), + getter(getter), + default_getter(default_getter), + setter(setter) {} const bool is_public; const Getter getter; + const Getter default_getter; const Setter setter; }; @@ -87,7 +101,10 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; } - void Register(const std::string &name, bool is_public, const Getter &getter) { + void Register(const std::string &name, + bool is_public, + const Getter &getter, + const Getter &default_getter) { PADDLE_ENFORCE_EQ( HasGetterMethod(name), false, @@ -96,12 +113,13 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { PADDLE_ENFORCE_NOT_NULL(getter, platform::errors::InvalidArgument( "Getter of %s should not be null", name)); - var_infos_.insert({name, VarInfo(is_public, getter)}); + var_infos_.insert({name, VarInfo(is_public, getter, default_getter)}); } void Register(const std::string &name, bool is_public, const Getter &getter, + const Getter &default_getter, const Setter &setter) { PADDLE_ENFORCE_EQ( HasGetterMethod(name), @@ -122,7 +140,8 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { PADDLE_ENFORCE_NOT_NULL(setter, platform::errors::InvalidArgument( "Setter of %s should not be null", name)); - var_infos_.insert({name, VarInfo(is_public, getter, setter)}); + var_infos_.insert( + {name, VarInfo(is_public, getter, default_getter, setter)}); } const Getter &GetterMethod(const std::string &name) const { @@ -133,6 +152,14 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { return var_infos_.at(name).getter; } + const Getter &DefaultGetterMethod(const std::string &name) const { + PADDLE_ENFORCE_EQ( + HasGetterMethod(name), + true, + platform::errors::NotFound("Cannot find global variable %s", name)); + return var_infos_.at(name).default_getter; + } + py::object GetOrReturnDefaultValue(const std::string &name, const py::object &default_value) const { if (HasGetterMethod(name)) { @@ -142,6 +169,14 @@ class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { } } + py::object GetDefaultValue(const std::string &name) const { + if (HasGetterMethod(name)) { + return DefaultGetterMethod(name)(); + } else { + return py::cast(Py_None); + } + } + py::object Get(const std::string &name) const { return GetterMethod(name)(); } const Setter &SetterMethod(const std::string &name) const { @@ -198,6 +233,9 @@ void BindGlobalValueGetterSetter(pybind11::module *module) { .def("__contains__", &GlobalVarGetterSetterRegistry::HasGetterMethod) .def("keys", &GlobalVarGetterSetterRegistry::Keys) .def("is_public", &GlobalVarGetterSetterRegistry::IsPublic) + .def("get_default", + &GlobalVarGetterSetterRegistry::GetDefaultValue, + py::arg("key")) .def("get", &GlobalVarGetterSetterRegistry::GetOrReturnDefaultValue, py::arg("key"), @@ -209,13 +247,15 @@ void BindGlobalValueGetterSetter(pybind11::module *module) { } /* Public vars are designed to be writable. */ -#define REGISTER_PUBLIC_GLOBAL_VAR(var) \ - do { \ - auto *instance = GlobalVarGetterSetterRegistry::MutableInstance(); \ - instance->Register(#var, \ - /*is_public=*/true, \ - GlobalVarGetterSetterRegistry::CreateGetter(var), \ - GlobalVarGetterSetterRegistry::CreateSetter(&var)); \ +#define REGISTER_PUBLIC_GLOBAL_VAR(var) \ + do { \ + auto *instance = GlobalVarGetterSetterRegistry::MutableInstance(); \ + instance->Register( \ + #var, \ + /*is_public=*/true, \ + GlobalVarGetterSetterRegistry::CreateGetter(var), \ + GlobalVarGetterSetterRegistry::CreateDefaultValueGetter(var), \ + GlobalVarGetterSetterRegistry::CreateSetter(&var)); \ } while (0) struct RegisterGetterSetterVisitor { @@ -225,18 +265,25 @@ struct RegisterGetterSetterVisitor { : name_(name), is_writable_(is_writable), value_ptr_(value_ptr) {} template - void operator()(const T &) const { + void operator()(const T &default_value) const { auto &value = *static_cast(value_ptr_); auto *instance = GlobalVarGetterSetterRegistry::MutableInstance(); bool is_public = is_writable_; // currently, all writable vars are public if (is_writable_) { - instance->Register(name_, - is_public, - GlobalVarGetterSetterRegistry::CreateGetter(value), - GlobalVarGetterSetterRegistry::CreateSetter(&value)); + instance->Register( + name_, + is_public, + GlobalVarGetterSetterRegistry::CreateGetter(value), + GlobalVarGetterSetterRegistry::CreateDefaultValueGetter( + default_value), + GlobalVarGetterSetterRegistry::CreateSetter(&value)); } else { instance->Register( - name_, is_public, GlobalVarGetterSetterRegistry::CreateGetter(value)); + name_, + is_public, + GlobalVarGetterSetterRegistry::CreateGetter(value), + GlobalVarGetterSetterRegistry::CreateDefaultValueGetter( + default_value)); } } diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index c016f9d743c7ff91c14610319ec52a9877a72306..3fdf7cdcdd9542913cd57bbe15aa39ad9349098e 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -14,9 +14,10 @@ import itertools import os +import sys import time import warnings -from collections import OrderedDict +from collections import OrderedDict, namedtuple from contextlib import contextmanager from multiprocessing import Manager # noqa: F401 from multiprocessing import Process # noqa: F401 @@ -905,6 +906,31 @@ def _check_var_exists(var_name): ) +def _get_modified_flags(): + ret = [] + FLAGS = namedtuple('FLAGS', ['name', 'current_value', 'default_value']) + global_flags = core.globals() + for key in global_flags.keys(): + value = global_flags.get(key) + default_value = global_flags.get_default(key) + if not value == default_value: + ret.append(FLAGS(key, value, default_value)) + return ret + + +def _print_modified_flags(modified_flags): + if len(modified_flags) > 0: + sys.stderr.write( + "======================= Modified FLAGS detected =======================\n" + ) + for flag in modified_flags: + sys.stderr.write(str(flag)) + sys.stderr.write("\n") + sys.stderr.write( + "=======================================================================\n" + ) + + def init_parallel_env(): """ @@ -967,6 +993,9 @@ def init_parallel_env(): """ + modified_flags = _get_modified_flags() + _print_modified_flags(modified_flags) + # 0. get env & check world size global _global_parallel_env # when call init_parallel_env, need update `_global_parallel_env`