diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 335cbc382c178b1a14949764f2908dc402298868..cdf210d661c73e69e125c0ebfa85cc852360e352 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -1170,6 +1170,21 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id, #endif } +void FleetWrapper::LoadWithWhitelist(const uint64_t table_id, + const std::string& path, const int mode) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->load_with_whitelist(table_id, path, + std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "load model of table id: " << table_id + << ", from path: " << path << " failed"; + } +#else + VLOG(0) << "FleetWrapper::LoadWhitelist does nothing when no pslib"; +#endif +} + void FleetWrapper::SaveModel(const std::string& path, const int mode) { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode)); @@ -1285,6 +1300,26 @@ int32_t FleetWrapper::SaveCache(int table_id, const std::string& path, #endif } +int32_t FleetWrapper::SaveWithWhitelist(int table_id, const std::string& path, + const int mode, + const std::string& whitelist_path) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->save_with_whitelist( + table_id, path, std::to_string(mode), whitelist_path); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "table save cache failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } + return feasign_cnt; +#else + VLOG(0) << "FleetWrapper::SaveCache does nothing when no pslib"; + return -1; +#endif +} + void FleetWrapper::ShrinkSparseTable(int table_id) { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->shrink(table_id); diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 
92f3a625a755bba4989033c0cd41d9b25591c960..cc13a50160a94c63345bcbd5633f2d3f8555ae0c 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -273,6 +273,11 @@ class FleetWrapper { // save cache model // cache model can speed up online predict int32_t SaveCache(int table_id, const std::string& path, const int mode); + // save sparse table filtered by user-defined whitelist + int32_t SaveWithWhitelist(int table_id, const std::string& path, + const int mode, const std::string& whitelist_path); + void LoadWithWhitelist(const uint64_t table_id, const std::string& path, + const int mode); // copy feasign key/value from src_table_id to dest_table_id int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id); // copy feasign key/value from src_table_id to dest_table_id diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 4b72b09adddf24f63814e8e4872af289b38bcb44..1e70bd9381b9d683af82f77959db9ad680f06bd3 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -57,7 +57,11 @@ void BindFleetWrapper(py::module* m) { .def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold) .def("cache_shuffle", &framework::FleetWrapper::CacheShuffle) .def("save_cache", &framework::FleetWrapper::SaveCache) + .def("save_model_with_whitelist", + &framework::FleetWrapper::SaveWithWhitelist) .def("load_model", &framework::FleetWrapper::LoadModel) + .def("load_table_with_whitelist", + &framework::FleetWrapper::LoadWithWhitelist) .def("clear_model", &framework::FleetWrapper::ClearModel) .def("clear_one_table", &framework::FleetWrapper::ClearOneTable) .def("stop_server", &framework::FleetWrapper::StopServer) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 
def save_model_with_whitelist(self,
                              executor,
                              dirname,
                              whitelist_path,
                              main_program=None,
                              **kwargs):
    """
    Save one sparse table, keeping only the feasigns listed in a
    user-defined whitelist file. Only the first worker triggers the
    save; all workers synchronize at barriers before and after.

    Args:
        executor(Executor): fluid executor (currently unused, kept for
            interface consistency with other save_* methods)
        dirname(str): save path, can be hdfs/afs path or local path
        whitelist_path(str): whitelist file path, can be hdfs/afs path
            or local path
        main_program(Program): fluid program, default None (unused)
        kwargs: user-defined properties, currently supported:
            mode(int): 0 means save all pserver model,
                1 means save delta pserver model (save diff),
                2 means save xbox base,
                3 means save batch model. default is 0
            table_id(int): id of the sparse table to save, default is 0

    Example:
        .. code-block:: python

            fleet.save_model_with_whitelist(
                executor=exe,
                dirname="/you/path/to/model",
                whitelist_path="/you/path/to/whitelist",
                mode=0)

    """
    mode = kwargs.get("mode", 0)
    table_id = kwargs.get("table_id", 0)
    # Flush pending pushes so the saved table reflects all updates.
    self._fleet_ptr.client_flush()
    self._role_maker._barrier_worker()
    # Only one worker performs the save; the others wait at the barrier.
    if self._role_maker.is_first_worker():
        self._fleet_ptr.save_model_with_whitelist(table_id, dirname, mode,
                                                  whitelist_path)
    self._role_maker._barrier_worker()


def load_pslib_whitelist(self, table_id, model_path, **kwargs):
    """
    Load one pslib sparse table together with its whitelist. Only the
    first worker triggers the load; all workers synchronize at barriers
    before and after.

    Args:
        table_id(int): id of the table to load
        model_path(str): load model path, can be local or hdfs/afs path
        kwargs(dict): user defined params, currently supported:
            mode(int): load model mode. 0 is for load whole model,
                1 is for load delta model (load diff), default is 0.

    Examples:
        .. code-block:: python

            fleet.load_pslib_whitelist(0, "hdfs:/my_fleet_model/20190714/0/")
            fleet.load_pslib_whitelist(1, "hdfs:/xx/xxx", mode=0)

    """
    self._role_maker._barrier_worker()
    mode = kwargs.get("mode", 0)
    if self._role_maker.is_first_worker():
        self._fleet_ptr.load_table_with_whitelist(table_id, model_path,
                                                  mode)
    self._role_maker._barrier_worker()