From 28cb0067d48167695af1b3b66de77fed8e12afea Mon Sep 17 00:00:00 2001 From: zmxdream Date: Wed, 20 Jul 2022 11:07:28 +0800 Subject: [PATCH] [GPUPS]FleetWrapper initialize (#44441) * fix FleetWrapper initialize --- paddle/fluid/framework/fleet/fleet_wrapper.cc | 1 + paddle/fluid/framework/fleet/fleet_wrapper.h | 9 +++++++-- paddle/fluid/pybind/fleet_wrapper_py.cc | 5 +++-- .../parameter_server/distribute_transpiler/__init__.py | 3 ++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index a93917a5a3..8099251351 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -37,6 +37,7 @@ namespace framework { const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100; std::shared_ptr FleetWrapper::s_instance_ = NULL; bool FleetWrapper::is_initialized_ = false; +std::mutex FleetWrapper::ins_mutex; #ifdef PADDLE_WITH_PSLIB std::shared_ptr FleetWrapper::pslib_ptr_ = NULL; diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index c9c03fb66f..990178feab 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -24,6 +24,7 @@ limitations under the License. */ #include #include #include +#include #include #include #include @@ -381,8 +382,11 @@ class FleetWrapper { void Revert(); // FleetWrapper singleton static std::shared_ptr GetInstance() { - if (NULL == s_instance_) { - s_instance_.reset(new paddle::framework::FleetWrapper()); + { + std::lock_guard lk(ins_mutex); + if (NULL == s_instance_) { + s_instance_.reset(new paddle::framework::FleetWrapper()); + } } return s_instance_; } @@ -397,6 +401,7 @@ class FleetWrapper { private: static std::shared_ptr s_instance_; + static std::mutex ins_mutex; #ifdef PADDLE_WITH_PSLIB std::map> _regions; #endif diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 8626659d86..c78e5f2b5f 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -41,8 +41,9 @@ namespace py = pybind11; namespace paddle { namespace pybind { void BindFleetWrapper(py::module* m) { - py::class_(*m, "Fleet") - .def(py::init()) + py::class_>( + *m, "Fleet") + .def(py::init([]() { return framework::FleetWrapper::GetInstance(); })) .def("push_dense", &framework::FleetWrapper::PushDenseVarsSync) .def("pull_dense", &framework::FleetWrapper::PullDenseVarsSync) .def("init_server", &framework::FleetWrapper::InitServer) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index 1354c317b0..67c1a0f0f8 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -88,7 +88,8 @@ class FleetTranspiler(Fleet): if role_maker is None: role_maker = MPISymetricRoleMaker() super(FleetTranspiler, self).init(role_maker) - self._fleet_ptr = core.Fleet() + if self._fleet_ptr is None: + self._fleet_ptr = core.Fleet() def _init_transpiler_worker(self): """ -- GitLab