diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index fc52e1a4c930bd41571de4416a6b413923f0e94e..371a5507f1fb06a106c5337d38ebcbd8d25658ed 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -804,6 +804,15 @@ void FleetWrapper::ClearModel() { #endif } +void FleetWrapper::ClearOneTable(const uint64_t table_id) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->clear(table_id); + ret.wait(); +#else + VLOG(0) << "FleetWrapper::ClearOneTable does nothing when no pslib"; +#endif +} + void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope, std::vector var_list, float decay, int emb_dim) { diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 5d831f31c7f6a6f7887e2d1f425a34416a6206ce..a54aea034d2fbfe0d867a6fe28eaa676c8ab3c5c 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -220,6 +220,8 @@ class FleetWrapper { const std::vector& feasign_list); // clear all models, release their memory void ClearModel(); + // clear one table + void ClearOneTable(const uint64_t table_id); // shrink sparse table void ShrinkSparseTable(int table_id); // shrink dense table diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index fac6de452aed018f1397c536c40d7a55b5f188b0..3b4505c611b283648f4da1d36f0200bb3e439d8a 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -59,6 +59,7 @@ void BindFleetWrapper(py::module* m) { .def("save_cache", &framework::FleetWrapper::SaveCache) .def("load_model", &framework::FleetWrapper::LoadModel) .def("clear_model", &framework::FleetWrapper::ClearModel) + .def("clear_one_table", &framework::FleetWrapper::ClearOneTable) .def("stop_server", &framework::FleetWrapper::StopServer) .def("finalize_worker", &framework::FleetWrapper::FinalizeWorker) .def("gather_servers", &framework::FleetWrapper::GatherServers) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 40337110cfe966511050b78e3e463e7653c3caba..27fecf495ade0e8b66ceed83619726ac5d938401 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -402,6 +402,23 @@ class PSLib(Fleet): var_list, decay, emb_dim) self._role_maker._barrier_worker() + def clear_one_table(self, table_id): + """ + clear_one_table() will be called by user. It will clear one table. + + Args: + table_id(int): table id + + Examples: + .. code-block:: python + + fleet.clear_one_table(0) + """ + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.clear_one_table(table_id) + self._role_maker._barrier_worker() + def clear_model(self): """ clear_model() will be called by user. It will clear sparse model. diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py index 7322891338f3b75a9dc04bc5ae5ed6e4515d4869..47aeee95921346fe61c83ff6cc2b4f6a1c7fb07e 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker.py @@ -99,6 +99,7 @@ class TestCloudRoleMaker(unittest.TestCase): except: print("do not support pslib test, skip") return + fleet.clear_one_table(0) from paddle.fluid.incubate.fleet.base.role_maker import \ MPISymetricRoleMaker try: