未验证 提交 a47d92d8 编写于 作者: Y yaoxuefeng 提交者: GitHub

fleet add save with whitelist test=develop (#23376)

上级 f7fb4c22
...@@ -1170,6 +1170,21 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id, ...@@ -1170,6 +1170,21 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
#endif #endif
} }
void FleetWrapper::LoadWithWhitelist(const uint64_t table_id,
const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->load_with_whitelist(table_id, path,
std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model of table id: " << table_id
<< ", from path: " << path << " failed";
}
#else
VLOG(0) << "FleetWrapper::LoadWhitelist does nothing when no pslib";
#endif
}
void FleetWrapper::SaveModel(const std::string& path, const int mode) { void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode)); auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode));
...@@ -1285,6 +1300,26 @@ int32_t FleetWrapper::SaveCache(int table_id, const std::string& path, ...@@ -1285,6 +1300,26 @@ int32_t FleetWrapper::SaveCache(int table_id, const std::string& path,
#endif #endif
} }
int32_t FleetWrapper::SaveWithWhitelist(int table_id, const std::string& path,
const int mode,
const std::string& whitelist_path) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save_with_whitelist(
table_id, path, std::to_string(mode), whitelist_path);
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "table save cache failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return feasign_cnt;
#else
VLOG(0) << "FleetWrapper::SaveCache does nothing when no pslib";
return -1;
#endif
}
void FleetWrapper::ShrinkSparseTable(int table_id) { void FleetWrapper::ShrinkSparseTable(int table_id) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->shrink(table_id); auto ret = pslib_ptr_->_worker_ptr->shrink(table_id);
......
...@@ -273,6 +273,11 @@ class FleetWrapper { ...@@ -273,6 +273,11 @@ class FleetWrapper {
// save cache model // save cache model
// cache model can speed up online predict // cache model can speed up online predict
int32_t SaveCache(int table_id, const std::string& path, const int mode); int32_t SaveCache(int table_id, const std::string& path, const int mode);
// save sparse table filtered by user-defined whitelist
int32_t SaveWithWhitelist(int table_id, const std::string& path,
const int mode, const std::string& whitelist_path);
void LoadWithWhitelist(const uint64_t table_id, const std::string& path,
const int mode);
// copy feasign key/value from src_table_id to dest_table_id // copy feasign key/value from src_table_id to dest_table_id
int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id); int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id);
// copy feasign key/value from src_table_id to dest_table_id // copy feasign key/value from src_table_id to dest_table_id
......
...@@ -57,7 +57,11 @@ void BindFleetWrapper(py::module* m) { ...@@ -57,7 +57,11 @@ void BindFleetWrapper(py::module* m) {
.def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold) .def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &framework::FleetWrapper::CacheShuffle) .def("cache_shuffle", &framework::FleetWrapper::CacheShuffle)
.def("save_cache", &framework::FleetWrapper::SaveCache) .def("save_cache", &framework::FleetWrapper::SaveCache)
.def("save_model_with_whitelist",
&framework::FleetWrapper::SaveWithWhitelist)
.def("load_model", &framework::FleetWrapper::LoadModel) .def("load_model", &framework::FleetWrapper::LoadModel)
.def("load_table_with_whitelist",
&framework::FleetWrapper::LoadWithWhitelist)
.def("clear_model", &framework::FleetWrapper::ClearModel) .def("clear_model", &framework::FleetWrapper::ClearModel)
.def("clear_one_table", &framework::FleetWrapper::ClearOneTable) .def("clear_one_table", &framework::FleetWrapper::ClearOneTable)
.def("stop_server", &framework::FleetWrapper::StopServer) .def("stop_server", &framework::FleetWrapper::StopServer)
......
...@@ -348,6 +348,41 @@ class PSLib(Fleet): ...@@ -348,6 +348,41 @@ class PSLib(Fleet):
self._fleet_ptr.save_model(dirname, mode) self._fleet_ptr.save_model(dirname, mode)
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
def save_model_with_whitelist(self,
executor,
dirname,
whitelist_path,
main_program=None,
**kwargs):
"""
save whitelist, mode is consistent with fleet.save_persistables,
when using fleet, it will save sparse and dense feature
Args:
executor(Executor): fluid executor
dirname(str): save path. It can be hdfs/afs path or local path
main_program(Program): fluid program, default None
kwargs: use define property, current support following
mode(int): 0 means save all pserver model,
1 means save delta pserver model (save diff),
2 means save xbox base,
3 means save batch model.
Example:
.. code-block:: python
fleet.save_persistables(dirname="/you/path/to/model", mode = 0)
"""
mode = kwargs.get("mode", 0)
table_id = kwargs.get("table_id", 0)
self._fleet_ptr.client_flush()
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
self._fleet_ptr.save_model_with_whitelist(table_id, dirname, mode,
whitelist_path)
self._role_maker._barrier_worker()
def save_cache_model(self, executor, dirname, main_program=None, **kwargs): def save_cache_model(self, executor, dirname, main_program=None, **kwargs):
""" """
save sparse cache table, save sparse cache table,
...@@ -480,6 +515,51 @@ class PSLib(Fleet): ...@@ -480,6 +515,51 @@ class PSLib(Fleet):
self._fleet_ptr.clear_model() self._fleet_ptr.clear_model()
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
def load_pslib_whitelist(self, table_id, model_path, **kwargs):
"""
load pslib model for one table with whitelist
Args:
table_id(int): load table id
model_path(str): load model path, can be local or hdfs/afs path
kwargs(dict): user defined params, currently support following:
only for load pslib model for one table:
mode(int): load model mode. 0 is for load whole model, 1 is
for load delta model (load diff), default is 0.
only for load params from paddle model:
scope(Scope): Scope object
model_proto_file(str): path of program desc proto binary
file, can be local or hdfs/afs file
var_names(list): var name list
load_combine(bool): load from a file or split param files
default False.
Examples:
.. code-block:: python
# load pslib model for one table
fleet.load_one_table(0, "hdfs:/my_fleet_model/20190714/0/")
fleet.load_one_table(1, "hdfs:/xx/xxx", mode = 0)
# load params from paddle model
fleet.load_one_table(2, "hdfs:/my_paddle_model/",
scope = my_scope,
model_proto_file = "./my_program.bin",
load_combine = False)
# below is how to save proto binary file
with open("my_program.bin", "wb") as fout:
my_program = fluid.default_main_program()
fout.write(my_program.desc.serialize_to_string())
"""
self._role_maker._barrier_worker()
mode = kwargs.get("mode", 0)
if self._role_maker.is_first_worker():
self._fleet_ptr.load_table_with_whitelist(table_id, model_path,
mode)
self._role_maker._barrier_worker()
def load_one_table(self, table_id, model_path, **kwargs): def load_one_table(self, table_id, model_path, **kwargs):
""" """
load pslib model for one table or load params from paddle model load pslib model for one table or load params from paddle model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册