diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index d46a7ec1fcd1d0de4f104948a9bef87707c706a7..c4c64881a82b56be57d85e53b9abde1c8082d7bb 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -91,6 +91,13 @@ void FleetWrapper::StopServer() { #endif } +void FleetWrapper::FinalizeWorker() { +#ifdef PADDLE_WITH_PSLIB + VLOG(3) << "Going to finalize worker"; + pslib_ptr_->finalize_worker(); +#endif +} + uint64_t FleetWrapper::RunServer() { #ifdef PADDLE_WITH_PSLIB VLOG(3) << "Going to run server"; diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index fc98cba853b7b6edab73d6b077f288ca084f0f6c..aa93e8d28bcfc5286ce777e4410e211fc12e719c 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -147,6 +147,8 @@ class FleetWrapper { int index); // stop server void StopServer(); + // finalize worker to make worker can be stop + void FinalizeWorker(); // run server uint64_t RunServer(); // gather server ip diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 31268f5e1826a6be63a23cbe29e8a960b1ac5705..679c91e8d8dbd9cb1e9fa5e038c3b4138eec2a6b 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -55,6 +55,7 @@ void BindFleetWrapper(py::module* m) { .def("load_model", &framework::FleetWrapper::LoadModel) .def("clear_model", &framework::FleetWrapper::ClearModel) .def("stop_server", &framework::FleetWrapper::StopServer) + .def("finalize_worker", &framework::FleetWrapper::FinalizeWorker) .def("gather_servers", &framework::FleetWrapper::GatherServers) .def("gather_clients", &framework::FleetWrapper::GatherClients) .def("get_clients_info", &framework::FleetWrapper::GetClientsInfo) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index c6d62b1d02742a8167f48eb69e2691f2db6a20b8..acebbe8251620140e9c6037accf73b4fb311eaeb 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -182,6 +182,10 @@ class PSLib(Fleet): destroyed when stop() is called. """ self._role_maker._barrier_worker() + # all worker should be finalize first + if self._role_maker.is_worker(): + self._fleet_ptr.finalize_worker() + self._role_maker._barrier_worker() if self._role_maker.is_first_worker(): self._fleet_ptr.stop_server() self._role_maker._barrier_worker()