diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 0c5eb0da1911dbeefbbd54406ca64ca33f0d51d0..3f4f345912467881ba0e83650c9ba1ee9aeee7b7 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -521,6 +521,15 @@ void FleetWrapper::ShrinkSparseTable(int table_id) { #endif } +void FleetWrapper::ClearModel() { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->clear(); + ret.wait(); +#else + VLOG(0) << "FleetWrapper::ClearModel does nothing when no pslib"; +#endif +} + void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope, std::vector var_list, float decay, int emb_dim) { diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 4779978689dfeccfe4f7138d0554281bebdcb5eb..17b58e575950edc61fd1ae6ba982f47ce15b03f6 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -148,6 +148,7 @@ class FleetWrapper { // mode = 1, save delta feature, which means save diff void SaveModel(const std::string& path, const int mode); + void ClearModel(); void ShrinkSparseTable(int table_id); void ShrinkDenseTable(int table_id, Scope* scope, std::vector var_list, float decay, diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index e96258796232c5cebee5125dbc2b8d2b8a7f59e1..36fc0822e8257d0dadef0d1bd6ad4dbc6263fcd8 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -49,6 +49,7 @@ void BindFleetWrapper(py::module* m) { .def("init_model", &framework::FleetWrapper::PushDenseParamSync) .def("save_model", &framework::FleetWrapper::SaveModel) .def("load_model", &framework::FleetWrapper::LoadModel) + .def("clear_model", &framework::FleetWrapper::ClearModel) .def("stop_server", &framework::FleetWrapper::StopServer) .def("gather_servers", &framework::FleetWrapper::GatherServers) .def("gather_clients", &framework::FleetWrapper::GatherClients) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index b70a4f5558e5484375a7f645f07aa4258550e818..92f1314816250781254d580e5265fb22981eaa41 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -264,6 +264,21 @@ class PSLib(Fleet): decay, emb_dim) self._role_maker._barrier_worker() + def clear_model(self): + """ + clear_model() will be called by user. It will clear sparse model. + + Examples: + .. code-block:: python + + fleet.clear_model() + + """ + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.clear_model() + self._role_maker._barrier_worker() + def load_one_table(self, table_id, model_path, **kwargs): """ load pslib model for one table or load params from paddle model