diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 09ce4b1569d58c16fb5aff1b0f0975828fd50640..f80e735bdc7cc2892b67de8890fd6e2411498d2f 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -917,6 +917,24 @@ void FleetWrapper::LoadWithWhitelist(const uint64_t table_id, #endif } +void FleetWrapper::SaveMultiTableOnePath(const std::vector& table_ids, + const std::string& path, + const int mode) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->save_multi_table_one_path( + table_ids, path, std::to_string(mode)); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "save model failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +#else + VLOG(0) << "FleetWrapper::SaveMultiTableOnePath does nothing when no pslib"; +#endif +} + void FleetWrapper::SaveModel(const std::string& path, const int mode) { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode)); diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 88c2906a1bb6edda2849a98cada7d976ed491ce2..9d60beb7fd839f75f30c449cfa195431e583260c 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -230,6 +230,8 @@ class FleetWrapper { // mode = 0, save all feature // mode = 1, save delta feature, which means save diff void SaveModel(const std::string& path, const int mode); + void SaveMultiTableOnePath(const std::vector& table_ids, + const std::string& path, const int mode); // mode = 0, save all feature // mode = 1, save delta feature, which means save diff void SaveModelOneTable(const uint64_t table_id, const std::string& path, diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 1e70bd9381b9d683af82f77959db9ad680f06bd3..873476629cb78faba141cdd430c68a00ed2c21c7 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -57,6 +57,8 @@ void BindFleetWrapper(py::module* m) { .def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold) .def("cache_shuffle", &framework::FleetWrapper::CacheShuffle) .def("save_cache", &framework::FleetWrapper::SaveCache) + .def("save_multi_table_one_path", + &framework::FleetWrapper::SaveMultiTableOnePath) .def("save_model_with_whitelist", &framework::FleetWrapper::SaveWithWhitelist) .def("load_model", &framework::FleetWrapper::LoadModel) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index ad6843e7e2dce2ae46488202c4d9b94a8562a59f..02f3ac44a892919931fde3dfcffd92a7fc669165 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -333,6 +333,28 @@ class PSLib(Fleet): whitelist_path) self._role_maker._barrier_worker() + def save_multi_table_one_path(self, table_ids, model_dir, **kwargs): + """ + save pslib multi sparse table in one path. + Args: + table_ids(list): table ids + model_dir(str): if you use hdfs, model_dir should starts with + 'hdfs:', otherwise means local dir + kwargs(dict): user-defined properties. + mode(int): the modes illustrated above, default 0 + prefix(str): the parts to save can have prefix, + for example, part-prefix-000-00000 + Examples: + .. code-block:: python + fleet.save_multi_table_one_path("[0, 1]", "afs:/user/path/") + """ + mode = kwargs.get("mode", 0) + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.save_multi_table_one_path(table_ids, model_dir, + mode) + self._role_maker._barrier_worker() + def save_cache_model(self, executor, dirname, main_program=None, **kwargs): """ save sparse cache table,