diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc index a82bce97992a9516fb83b94b8a7b27c09b99195a..806e0ec0a15aefb6d2542f608ecd9a1e823520e2 100644 --- a/paddle/fluid/pybind/global_value_getter_setter.cc +++ b/paddle/fluid/pybind/global_value_getter_setter.cc @@ -29,7 +29,7 @@ #include "paddle/fluid/platform/macros.h" #include "pybind11/stl.h" -// NOTE: where is these 2 flags from? +// NOTE: where are these 2 flags from? #ifdef PADDLE_WITH_DISTRIBUTE DECLARE_int32(rpc_get_thread_num); DECLARE_int32(rpc_prefetch_thread_num); @@ -197,22 +197,28 @@ void BindGlobalValueGetterSetter(pybind11::module *module) { } while (0) struct RegisterGetterSetterVisitor : public boost::static_visitor { - RegisterGetterSetterVisitor(const std::string &name, bool is_public, + RegisterGetterSetterVisitor(const std::string &name, bool is_writable, void *value_ptr) - : name_(name), value_ptr_(value_ptr) {} + : name_(name), is_writable_(is_writable), value_ptr_(value_ptr) {} template void operator()(const T &) const { auto &value = *static_cast(value_ptr_); auto *instance = GlobalVarGetterSetterRegistry::MutableInstance(); - instance->Register(name_, is_public_, - GlobalVarGetterSetterRegistry::CreateGetter(value), - GlobalVarGetterSetterRegistry::CreateSetter(&value)); + bool is_public = is_writable_; // currently, all writable vars are public + if (is_public) { + instance->Register(name_, is_public, + GlobalVarGetterSetterRegistry::CreateGetter(value), + GlobalVarGetterSetterRegistry::CreateSetter(&value)); + } else { + instance->Register(name_, is_public, + GlobalVarGetterSetterRegistry::CreateGetter(value)); + } } private: std::string name_; - bool is_public_; + bool is_writable_; void *value_ptr_; };