diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index a93917a5a31a0d6b208b1feef622a5613a38906f..8099251351068d4edc6fb647d65fc36af1f587ee 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -37,6 +37,7 @@ namespace framework { const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100; std::shared_ptr FleetWrapper::s_instance_ = NULL; bool FleetWrapper::is_initialized_ = false; +std::mutex FleetWrapper::ins_mutex; #ifdef PADDLE_WITH_PSLIB std::shared_ptr FleetWrapper::pslib_ptr_ = NULL; diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index c9c03fb66f8fa7860f1f746d2587d1b516e7b45d..990178feaba48e7a539ffc933193973ef441227b 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -24,6 +24,7 @@ limitations under the License. */ #include #include #include +#include #include #include #include @@ -381,8 +382,11 @@ class FleetWrapper { void Revert(); // FleetWrapper singleton static std::shared_ptr GetInstance() { - if (NULL == s_instance_) { - s_instance_.reset(new paddle::framework::FleetWrapper()); + { + std::lock_guard lk(ins_mutex); + if (NULL == s_instance_) { + s_instance_.reset(new paddle::framework::FleetWrapper()); + } } return s_instance_; } @@ -397,6 +401,7 @@ class FleetWrapper { private: static std::shared_ptr s_instance_; + static std::mutex ins_mutex; #ifdef PADDLE_WITH_PSLIB std::map> _regions; #endif diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 8626659d8633afec4c09cd9c609c181307a9a570..c78e5f2b5fd2c7dd09954f61353e847c37d6436d 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -41,8 +41,9 @@ namespace py = pybind11; namespace paddle { namespace pybind { void BindFleetWrapper(py::module* m) { - py::class_(*m, "Fleet") - .def(py::init()) + py::class_>( + *m, "Fleet") + .def(py::init([]() { return framework::FleetWrapper::GetInstance(); })) .def("push_dense", &framework::FleetWrapper::PushDenseVarsSync) .def("pull_dense", &framework::FleetWrapper::PullDenseVarsSync) .def("init_server", &framework::FleetWrapper::InitServer) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index 1354c317b0a85dd935489eefe78475110374c207..67c1a0f0f8b47470973fab7c6ea071be01a4ca69 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -88,7 +88,8 @@ class FleetTranspiler(Fleet): if role_maker is None: role_maker = MPISymetricRoleMaker() super(FleetTranspiler, self).init(role_maker) - self._fleet_ptr = core.Fleet() + if self._fleet_ptr is None: + self._fleet_ptr = core.Fleet() def _init_transpiler_worker(self): """