// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/pybind/global_value_getter_setter.h" #include #include #include #include #include #include #include #include "gflags/gflags.h" #include "paddle/fluid/framework/python_headers.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/macros.h" #include "pybind11/stl.h" // data processing DECLARE_bool(use_mkldnn); DECLARE_string(tracer_mkldnn_ops_on); DECLARE_string(tracer_mkldnn_ops_off); // debug DECLARE_bool(check_nan_inf); DECLARE_bool(cpu_deterministic); DECLARE_bool(enable_rpc_profiler); DECLARE_int32(multiple_of_cupti_buffer_size); DECLARE_bool(reader_queue_speed_test_mode); DECLARE_int32(call_stack_level); DECLARE_bool(sort_sum_gradient); // device management DECLARE_int32(paddle_num_threads); // executor DECLARE_bool(enable_parallel_graph); DECLARE_string(pe_profile_fname); DECLARE_string(print_sub_graph_dir); DECLARE_bool(use_ngraph); // memory management DECLARE_string(allocator_strategy); DECLARE_double(eager_delete_tensor_gb); DECLARE_double(fraction_of_cpu_memory_to_use); DECLARE_bool(free_idle_chunk); DECLARE_bool(free_when_no_cache_hit); DECLARE_int32(fuse_parameter_groups_size); DECLARE_double(fuse_parameter_memory_size); DECLARE_bool(init_allocated_mem); DECLARE_uint64(initial_cpu_memory_in_mb); DECLARE_double(memory_fraction_of_eager_deletion); DECLARE_bool(use_pinned_memory); DECLARE_bool(use_system_allocator); // others DECLARE_bool(benchmark); DECLARE_int32(inner_op_parallelism); DECLARE_int32(max_inplace_grad_add); DECLARE_string(tracer_profile_fname); #ifdef PADDLE_WITH_CUDA // cudnn DECLARE_uint64(conv_workspace_size_limit); DECLARE_bool(cudnn_batchnorm_spatial_persistent); DECLARE_bool(cudnn_deterministic); DECLARE_bool(cudnn_exhaustive_search); // data processing DECLARE_bool(enable_cublas_tensor_op_math); // device management DECLARE_string(selected_gpus); // memory management DECLARE_bool(eager_delete_scope); DECLARE_bool(fast_eager_deletion_mode); DECLARE_double(fraction_of_cuda_pinned_memory_to_use); DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_uint64(gpu_memory_limit_mb); DECLARE_uint64(initial_gpu_memory_in_mb); DECLARE_uint64(reallocate_gpu_memory_in_mb); // others DECLARE_bool(sync_nccl_allreduce); #endif #ifdef PADDLE_WITH_ASCEND_CL // device management DECLARE_string(selected_npus); #endif #ifdef PADDLE_WITH_DISTRIBUTE DECLARE_int32(rpc_send_thread_num); DECLARE_int32(rpc_get_thread_num); DECLARE_int32(rpc_prefetch_thread_num); #endif namespace paddle { namespace pybind { namespace py = pybind11; class PYBIND11_HIDDEN GlobalVarGetterSetterRegistry { DISABLE_COPY_AND_ASSIGN(GlobalVarGetterSetterRegistry); GlobalVarGetterSetterRegistry() = default; public: using Getter = std::function; using Setter = std::function; template static Getter CreateGetter(const T &var) { return [&]() -> py::object { return py::cast(var); }; } template static Setter CreateSetter(T *var) { return [var](const py::object &obj) { *var = py::cast(obj); }; } private: struct VarInfo { VarInfo(bool is_public, const Getter &getter) : is_public(is_public), getter(getter) {} VarInfo(bool is_public, const Getter &getter, const Setter &setter) : is_public(is_public), getter(getter), setter(setter) {} const bool is_public; const Getter getter; const Setter setter; }; public: static const GlobalVarGetterSetterRegistry &Instance() { return instance_; } static GlobalVarGetterSetterRegistry *MutableInstance() { return &instance_; } void Register(const std::string &name, bool is_public, const Getter &getter) { PADDLE_ENFORCE_EQ( HasGetterMethod(name), false, platform::errors::AlreadyExists( "Getter of global variable %s has been registered", name)); PADDLE_ENFORCE_NOT_NULL(getter, platform::errors::InvalidArgument( "Getter of %s should not be null", name)); var_infos_.insert({name, VarInfo(is_public, getter)}); } void Register(const std::string &name, bool is_public, const Getter &getter, const Setter &setter) { PADDLE_ENFORCE_EQ( HasGetterMethod(name), false, platform::errors::AlreadyExists( "Getter of global variable %s has been registered", name)); PADDLE_ENFORCE_EQ( HasSetterMethod(name), false, platform::errors::AlreadyExists( "Setter of global variable %s has been registered", name)); PADDLE_ENFORCE_NOT_NULL(getter, platform::errors::InvalidArgument( "Getter of %s should not be null", name)); PADDLE_ENFORCE_NOT_NULL(setter, platform::errors::InvalidArgument( "Setter of %s should not be null", name)); var_infos_.insert({name, VarInfo(is_public, getter, setter)}); } const Getter &GetterMethod(const std::string &name) const { PADDLE_ENFORCE_EQ( HasGetterMethod(name), true, platform::errors::NotFound("Cannot find global variable %s", name)); return var_infos_.at(name).getter; } py::object GetOrReturnDefaultValue(const std::string &name, const py::object &default_value) const { if (HasGetterMethod(name)) { return GetterMethod(name)(); } else { return default_value; } } py::object Get(const std::string &name) const { return GetterMethod(name)(); } const Setter &SetterMethod(const std::string &name) const { PADDLE_ENFORCE_EQ( HasSetterMethod(name), true, platform::errors::NotFound("Global variable %s is not writable", name)); return var_infos_.at(name).setter; } void Set(const std::string &name, const py::object &value) const { SetterMethod(name)(value); } bool HasGetterMethod(const std::string &name) const { return var_infos_.count(name) > 0; } bool HasSetterMethod(const std::string &name) const { return var_infos_.count(name) > 0 && var_infos_.at(name).setter; } bool IsPublic(const std::string &name) const { return var_infos_.count(name) > 0 && var_infos_.at(name).is_public; } std::unordered_set Keys() const { std::unordered_set keys; keys.reserve(var_infos_.size()); for (auto &pair : var_infos_) { keys.insert(pair.first); } return keys; } private: static GlobalVarGetterSetterRegistry instance_; std::unordered_map var_infos_; }; GlobalVarGetterSetterRegistry GlobalVarGetterSetterRegistry::instance_; class GlobalVarGetterSetterRegistryHelper { public: GlobalVarGetterSetterRegistryHelper(bool is_public, bool is_writable, const std::string &var_names) : is_public_(is_public), is_writable_(is_writable), var_names_(SplitVarNames(var_names)) {} template void Register(Args &&... args) const { Impl<0, sizeof...(args) == 1, Args...>::Register( is_public_, is_writable_, var_names_, std::forward(args)...); } private: static std::vector SplitVarNames(const std::string &names) { auto valid_char = [](char ch) { return !std::isspace(ch) && ch != ','; }; std::vector ret; size_t i = 0, j = 0, n = names.size(); while (i < n) { for (; i < n && !valid_char(names[i]); ++i) { } for (j = i + 1; j < n && valid_char(names[j]); ++j) { } if (i < n && j <= n) { auto substring = names.substr(i, j - i); VLOG(10) << "Get substring: \"" << substring << "\""; ret.emplace_back(substring); } i = j + 1; } return ret; } private: template struct Impl { static void Register(bool is_public, bool is_writable, const std::vector &var_names, T &&var, Args &&... args) { PADDLE_ENFORCE_EQ(kIdx + 1 + sizeof...(args), var_names.size(), platform::errors::InvalidArgument( "Argument number not match name number")); Impl::Register(is_public, is_writable, var_names, var); Impl::Register( is_public, is_writable, var_names, std::forward(args)...); } }; template struct Impl { static void Register(bool is_public, bool is_writable, const std::vector &var_names, T &&var) { auto *instance = GlobalVarGetterSetterRegistry::MutableInstance(); if (is_writable) { instance->Register( var_names[kIdx], is_public, GlobalVarGetterSetterRegistry::CreateGetter(std::forward(var)), GlobalVarGetterSetterRegistry::CreateSetter(&var)); } else { instance->Register( var_names[kIdx], is_public, GlobalVarGetterSetterRegistry::CreateGetter(std::forward(var))); } } }; private: const bool is_public_; const bool is_writable_; const std::vector var_names_; }; static void RegisterGlobalVarGetterSetter(); void BindGlobalValueGetterSetter(pybind11::module *module) { RegisterGlobalVarGetterSetter(); py::class_(*module, "GlobalVarGetterSetterRegistry") .def("__getitem__", &GlobalVarGetterSetterRegistry::Get) .def("__setitem__", &GlobalVarGetterSetterRegistry::Set) .def("__contains__", &GlobalVarGetterSetterRegistry::HasGetterMethod) .def("keys", &GlobalVarGetterSetterRegistry::Keys) .def("is_public", &GlobalVarGetterSetterRegistry::IsPublic) .def("get", &GlobalVarGetterSetterRegistry::GetOrReturnDefaultValue, py::arg("key"), py::arg("default") = py::cast(Py_None)); module->def("globals", &GlobalVarGetterSetterRegistry::Instance, py::return_value_policy::reference); } /* Public vars are designed to be writable. */ #define REGISTER_PUBLIC_GLOBAL_VAR(...) \ do { \ GlobalVarGetterSetterRegistryHelper(/*is_public=*/true, \ /*is_writable=*/true, "" #__VA_ARGS__) \ .Register(__VA_ARGS__); \ } while (0) #define REGISTER_PRIVATE_GLOBAL_VAR(is_writable, ...) \ do { \ GlobalVarGetterSetterRegistryHelper(/*is_public=*/false, is_writable, \ "" #__VA_ARGS__) \ .Register(__VA_ARGS__); \ } while (0) static void RegisterGlobalVarGetterSetter() { REGISTER_PRIVATE_GLOBAL_VAR(/*is_writable=*/false, FLAGS_free_idle_chunk, FLAGS_free_when_no_cache_hit); REGISTER_PUBLIC_GLOBAL_VAR( FLAGS_eager_delete_tensor_gb, FLAGS_enable_parallel_graph, FLAGS_allocator_strategy, FLAGS_use_system_allocator, FLAGS_check_nan_inf, FLAGS_call_stack_level, FLAGS_sort_sum_gradient, FLAGS_cpu_deterministic, FLAGS_enable_rpc_profiler, FLAGS_multiple_of_cupti_buffer_size, FLAGS_reader_queue_speed_test_mode, FLAGS_pe_profile_fname, FLAGS_print_sub_graph_dir, FLAGS_fraction_of_cpu_memory_to_use, FLAGS_fuse_parameter_groups_size, FLAGS_fuse_parameter_memory_size, FLAGS_init_allocated_mem, FLAGS_initial_cpu_memory_in_mb, FLAGS_memory_fraction_of_eager_deletion, FLAGS_use_pinned_memory, FLAGS_benchmark, FLAGS_inner_op_parallelism, FLAGS_tracer_profile_fname, FLAGS_paddle_num_threads, FLAGS_use_mkldnn, FLAGS_max_inplace_grad_add, FLAGS_tracer_mkldnn_ops_on, FLAGS_tracer_mkldnn_ops_off); #ifdef PADDLE_WITH_CUDA REGISTER_PUBLIC_GLOBAL_VAR( FLAGS_gpu_memory_limit_mb, FLAGS_cudnn_deterministic, FLAGS_conv_workspace_size_limit, FLAGS_cudnn_batchnorm_spatial_persistent, FLAGS_cudnn_exhaustive_search, FLAGS_eager_delete_scope, FLAGS_fast_eager_deletion_mode, FLAGS_fraction_of_cuda_pinned_memory_to_use, FLAGS_fraction_of_gpu_memory_to_use, FLAGS_initial_gpu_memory_in_mb, FLAGS_reallocate_gpu_memory_in_mb, FLAGS_enable_cublas_tensor_op_math, FLAGS_selected_gpus, FLAGS_sync_nccl_allreduce); #endif #ifdef PADDLE_WITH_ASCEND_CL REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_selected_npus); #endif #ifdef PADDLE_WITH_DITRIBUTE REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_rpc_send_thread_num, FLAGS_rpc_get_thread_num, FLAGS_rpc_prefetch_thread_num); #endif } } // namespace pybind } // namespace paddle