diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 055c8347ecf15c5d314def73be7e1d966b4e8cfa..425c8a9f2a72a9f3d103392c1281c284c12d2073 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -1231,6 +1231,24 @@ void FleetWrapper::LoadWithWhitelist(const uint64_t table_id, #endif } +void FleetWrapper::SaveMultiTableOnePath(const std::vector& table_ids, + const std::string& path, + const int mode) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->save_multi_table_one_path( + table_ids, path, std::to_string(mode)); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "save model failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +#else + VLOG(0) << "FleetWrapper::SaveMultiTableOnePath does nothing when no pslib"; +#endif +} + void FleetWrapper::SaveModel(const std::string& path, const int mode) { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode)); diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index c2f89e336a41a8df75c097164f69a9fe31bb25dc..aa0da8286269fea1818eef0de770256a016d1e56 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -272,6 +272,8 @@ class FleetWrapper { // mode = 0, save all feature // mode = 1, save delta feature, which means save diff void SaveModel(const std::string& path, const int mode); + void SaveMultiTableOnePath(const std::vector& table_ids, + const std::string& path, const int mode); // mode = 0, save all feature // mode = 1, save delta feature, which means save diff void SaveModelOneTable(const uint64_t table_id, const std::string& path, diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 1e70bd9381b9d683af82f77959db9ad680f06bd3..873476629cb78faba141cdd430c68a00ed2c21c7 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -57,6 +57,8 @@ void BindFleetWrapper(py::module* m) { .def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold) .def("cache_shuffle", &framework::FleetWrapper::CacheShuffle) .def("save_cache", &framework::FleetWrapper::SaveCache) + .def("save_multi_table_one_path", + &framework::FleetWrapper::SaveMultiTableOnePath) .def("save_model_with_whitelist", &framework::FleetWrapper::SaveWithWhitelist) .def("load_model", &framework::FleetWrapper::LoadModel) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 2bfc19b013708d14bf74a185adf7ea55119d0236..49c262607498c6b86572acd5d5b520bd185b686e 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -385,6 +385,28 @@ class PSLib(Fleet): whitelist_path) self._role_maker._barrier_worker() + def save_multi_table_one_path(self, table_ids, model_dir, **kwargs): + """ + save pslib multi sparse table in one path. + Args: + table_ids(list): table ids + model_dir(str): if you use hdfs, model_dir should starts with + 'hdfs:', otherwise means local dir + kwargs(dict): user-defined properties. + mode(int): the modes illustrated above, default 0 + prefix(str): the parts to save can have prefix, + for example, part-prefix-000-00000 + Examples: + .. code-block:: python + fleet.save_multi_table_one_path("[0, 1]", "afs:/user/path/") + """ + mode = kwargs.get("mode", 0) + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.save_multi_table_one_path(table_ids, model_dir, + mode) + self._role_maker._barrier_worker() + def save_cache_model(self, executor, dirname, main_program=None, **kwargs): """ save sparse cache table,