未验证 提交 565354f6 编写于 作者: T Thunderbrook 提交者: GitHub

support save multi sparse table in one path (#31108)

* save multi table one path

* format
上级 50967135
...@@ -1231,6 +1231,24 @@ void FleetWrapper::LoadWithWhitelist(const uint64_t table_id, ...@@ -1231,6 +1231,24 @@ void FleetWrapper::LoadWithWhitelist(const uint64_t table_id,
#endif #endif
} }
void FleetWrapper::SaveMultiTableOnePath(const std::vector<int>& table_ids,
const std::string& path,
const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save_multi_table_one_path(
table_ids, path, std::to_string(mode));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "save model failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
#else
VLOG(0) << "FleetWrapper::SaveMultiTableOnePath does nothing when no pslib";
#endif
}
void FleetWrapper::SaveModel(const std::string& path, const int mode) { void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode)); auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode));
......
...@@ -272,6 +272,8 @@ class FleetWrapper { ...@@ -272,6 +272,8 @@ class FleetWrapper {
// mode = 0, save all feature // mode = 0, save all feature
// mode = 1, save delta feature, which means save diff // mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode); void SaveModel(const std::string& path, const int mode);
void SaveMultiTableOnePath(const std::vector<int>& table_ids,
const std::string& path, const int mode);
// mode = 0, save all feature // mode = 0, save all feature
// mode = 1, save delta feature, which means save diff // mode = 1, save delta feature, which means save diff
void SaveModelOneTable(const uint64_t table_id, const std::string& path, void SaveModelOneTable(const uint64_t table_id, const std::string& path,
......
...@@ -57,6 +57,8 @@ void BindFleetWrapper(py::module* m) { ...@@ -57,6 +57,8 @@ void BindFleetWrapper(py::module* m) {
.def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold) .def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &framework::FleetWrapper::CacheShuffle) .def("cache_shuffle", &framework::FleetWrapper::CacheShuffle)
.def("save_cache", &framework::FleetWrapper::SaveCache) .def("save_cache", &framework::FleetWrapper::SaveCache)
.def("save_multi_table_one_path",
&framework::FleetWrapper::SaveMultiTableOnePath)
.def("save_model_with_whitelist", .def("save_model_with_whitelist",
&framework::FleetWrapper::SaveWithWhitelist) &framework::FleetWrapper::SaveWithWhitelist)
.def("load_model", &framework::FleetWrapper::LoadModel) .def("load_model", &framework::FleetWrapper::LoadModel)
......
...@@ -385,6 +385,28 @@ class PSLib(Fleet): ...@@ -385,6 +385,28 @@ class PSLib(Fleet):
whitelist_path) whitelist_path)
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
def save_multi_table_one_path(self, table_ids, model_dir, **kwargs):
"""
save pslib multi sparse table in one path.
Args:
table_ids(list): table ids
model_dir(str): if you use hdfs, model_dir should starts with
'hdfs:', otherwise means local dir
kwargs(dict): user-defined properties.
mode(int): the modes illustrated above, default 0
prefix(str): the parts to save can have prefix,
for example, part-prefix-000-00000
Examples:
.. code-block:: python
fleet.save_multi_table_one_path("[0, 1]", "afs:/user/path/")
"""
mode = kwargs.get("mode", 0)
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
self._fleet_ptr.save_multi_table_one_path(table_ids, model_dir,
mode)
self._role_maker._barrier_worker()
def save_cache_model(self, executor, dirname, main_program=None, **kwargs): def save_cache_model(self, executor, dirname, main_program=None, **kwargs):
""" """
save sparse cache table, save sparse cache table,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册