diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index dc5e24ef5de42f0718ae29c7f721e487b672ab07..4346c144fab7f2e516bd36ab5c964b0b8677528c 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -1347,6 +1347,20 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) { #endif } +void FleetWrapper::SetFileNumOneShard(const uint64_t table_id, int file_num) { +#ifdef PADDLE_WITH_PSLIB + auto ret = + pslib_ptr_->_worker_ptr->set_file_num_one_shard(table_id, file_num); + ret.wait(); + int32_t err_code = ret.get(); + if (err_code == -1) { + LOG(ERROR) << "set_file_num_one_shard failed"; + } +#else + VLOG(0) << "FleetWrapper::SetFileNumOneShard does nothing when no pslib"; +#endif +} + double FleetWrapper::GetCacheThreshold(int table_id) { #ifdef PADDLE_WITH_PSLIB double cache_threshold = 0.0; diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index c1db06a298c86196b40c749b5552826b4a071b95..d368b421ff2a055aa91b0b920ee8c2f64daa8784 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -266,6 +266,7 @@ class FleetWrapper { bool load_combine); void PrintTableStat(const uint64_t table_id); + void SetFileNumOneShard(const uint64_t table_id, int file_num); // mode = 0, load all feature // mode = 1, load delta feature, which means load diff void LoadModel(const std::string& path, const int mode); diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 873476629cb78faba141cdd430c68a00ed2c21c7..d8142f717baed8e0bea3858027aa33899627f55c 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -76,6 +76,8 @@ void BindFleetWrapper(py::module* m) { .def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable) .def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable) .def("print_table_stat", &framework::FleetWrapper::PrintTableStat) + .def("set_file_num_one_shard", + &framework::FleetWrapper::SetFileNumOneShard) .def("client_flush", &framework::FleetWrapper::ClientFlush) .def("load_from_paddle_model", &framework::FleetWrapper::LoadFromPaddleModel) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 39cf3ebeb32a95fb0cf8d885fb2f8bbed2f05c65..e8d9cc3b77b6a8902bebc6a18fcb783cbd368da2 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -327,6 +327,21 @@ class PSLib(Fleet): self._fleet_ptr.print_table_stat(table_id) self._role_maker._barrier_worker() + def set_file_num_one_shard(self, table_id, file_num): + """ + set file_num in one shard + Args: + table_id(int): the id of table + file_num(int): file num in one shard + Example: + .. code-block:: python + fleet.set_file_num_one_shard(0, 5) + """ + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.set_file_num_one_shard(table_id, file_num) + self._role_maker._barrier_worker() + def save_persistables(self, executor, dirname, main_program=None, **kwargs): """ save presistable parameters,