From 79415712c2a232e0852422e752f2125abbbd65e3 Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Tue, 18 May 2021 14:35:07 +0800 Subject: [PATCH] support save multi sparse table in one path (#31108) (#31125) * save multi table one path * format --- paddle/fluid/framework/fleet/fleet_wrapper.cc | 18 +++++++++++++++ paddle/fluid/framework/fleet/fleet_wrapper.h | 2 ++ paddle/fluid/pybind/fleet_wrapper_py.cc | 2 ++ .../fleet/parameter_server/pslib/__init__.py | 22 +++++++++++++++++++ 4 files changed, 44 insertions(+) diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 09ce4b1569d..f80e735bdc7 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -917,6 +917,24 @@ void FleetWrapper::LoadWithWhitelist(const uint64_t table_id, #endif } +void FleetWrapper::SaveMultiTableOnePath(const std::vector& table_ids, + const std::string& path, + const int mode) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->save_multi_table_one_path( + table_ids, path, std::to_string(mode)); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "save model failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +#else + VLOG(0) << "FleetWrapper::SaveMultiTableOnePath does nothing when no pslib"; +#endif +} + void FleetWrapper::SaveModel(const std::string& path, const int mode) { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode)); diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 88c2906a1bb..9d60beb7fd8 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -230,6 +230,8 @@ class FleetWrapper { // mode = 0, save all feature // mode = 1, save delta feature, which means save diff void SaveModel(const std::string& path, const int mode); + void SaveMultiTableOnePath(const std::vector& table_ids, + const std::string& path, const int mode); // mode = 0, save all feature // mode = 1, save delta feature, which means save diff void SaveModelOneTable(const uint64_t table_id, const std::string& path, diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 1e70bd9381b..873476629cb 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -57,6 +57,8 @@ void BindFleetWrapper(py::module* m) { .def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold) .def("cache_shuffle", &framework::FleetWrapper::CacheShuffle) .def("save_cache", &framework::FleetWrapper::SaveCache) + .def("save_multi_table_one_path", + &framework::FleetWrapper::SaveMultiTableOnePath) .def("save_model_with_whitelist", &framework::FleetWrapper::SaveWithWhitelist) .def("load_model", &framework::FleetWrapper::LoadModel) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index ad6843e7e2d..02f3ac44a89 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -333,6 +333,28 @@ class PSLib(Fleet): whitelist_path) self._role_maker._barrier_worker() + def save_multi_table_one_path(self, table_ids, model_dir, **kwargs): + """ + save pslib multi sparse table in one path. + Args: + table_ids(list): table ids + model_dir(str): if you use hdfs, model_dir should starts with + 'hdfs:', otherwise means local dir + kwargs(dict): user-defined properties. + mode(int): the modes illustrated above, default 0 + prefix(str): the parts to save can have prefix, + for example, part-prefix-000-00000 + Examples: + .. code-block:: python + fleet.save_multi_table_one_path("[0, 1]", "afs:/user/path/") + """ + mode = kwargs.get("mode", 0) + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.save_multi_table_one_path(table_ids, model_dir, + mode) + self._role_maker._barrier_worker() + def save_cache_model(self, executor, dirname, main_program=None, **kwargs): """ save sparse cache table, -- GitLab