From 991dc67df6fd68c63f0816231d33e011401d2a3a Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Sun, 26 Sep 2021 15:34:07 +0800 Subject: [PATCH] set file_num in one shard (#35835) * set file_num in one shard * format --- paddle/fluid/framework/fleet/fleet_wrapper.cc | 14 ++++++++++++++ paddle/fluid/framework/fleet/fleet_wrapper.h | 1 + paddle/fluid/pybind/fleet_wrapper_py.cc | 2 ++ .../fleet/parameter_server/pslib/__init__.py | 15 +++++++++++++++ 4 files changed, 32 insertions(+) diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index dc5e24ef5d..4346c144fa 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -1347,6 +1347,20 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) { #endif } +void FleetWrapper::SetFileNumOneShard(const uint64_t table_id, int file_num) { +#ifdef PADDLE_WITH_PSLIB + auto ret = + pslib_ptr_->_worker_ptr->set_file_num_one_shard(table_id, file_num); + ret.wait(); + int32_t err_code = ret.get(); + if (err_code == -1) { + LOG(ERROR) << "set_file_num_one_shard failed"; + } +#else + VLOG(0) << "FleetWrapper::SetFileNumOneShard does nothing when no pslib"; +#endif +} + double FleetWrapper::GetCacheThreshold(int table_id) { #ifdef PADDLE_WITH_PSLIB double cache_threshold = 0.0; diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index c1db06a298..d368b421ff 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -266,6 +266,7 @@ class FleetWrapper { bool load_combine); void PrintTableStat(const uint64_t table_id); + void SetFileNumOneShard(const uint64_t table_id, int file_num); // mode = 0, load all feature // mode = 1, load delta feature, which means load diff void LoadModel(const std::string& path, const int mode); diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 873476629c..d8142f717b 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -76,6 +76,8 @@ void BindFleetWrapper(py::module* m) { .def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable) .def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable) .def("print_table_stat", &framework::FleetWrapper::PrintTableStat) + .def("set_file_num_one_shard", + &framework::FleetWrapper::SetFileNumOneShard) .def("client_flush", &framework::FleetWrapper::ClientFlush) .def("load_from_paddle_model", &framework::FleetWrapper::LoadFromPaddleModel) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 39cf3ebeb3..e8d9cc3b77 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -327,6 +327,21 @@ class PSLib(Fleet): self._fleet_ptr.print_table_stat(table_id) self._role_maker._barrier_worker() + def set_file_num_one_shard(self, table_id, file_num): + """ + set file_num in one shard + Args: + table_id(int): the id of table + file_num(int): file num in one shard + Example: + .. code-block:: python + fleet.set_file_num_one_shard(0, 5) + """ + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.set_file_num_one_shard(table_id, file_num) + self._role_maker._barrier_worker() + def save_persistables(self, executor, dirname, main_program=None, **kwargs): """ save presistable parameters, -- GitLab