From d98084e7ec550ff44819ee18ab3c9da920178345 Mon Sep 17 00:00:00 2001 From: xujiaqi01 <173596896@qq.com> Date: Sat, 11 Apr 2020 18:38:08 +0800 Subject: [PATCH] add save with prefix (#23449) * add save with prefix * test=develop --- paddle/fluid/framework/fleet/fleet_wrapper.cc | 32 ++++++++ paddle/fluid/framework/fleet/fleet_wrapper.h | 11 ++- paddle/fluid/framework/fleet/test_fleet.cc | 2 + paddle/fluid/pybind/fleet_wrapper_py.cc | 3 + .../fleet/parameter_server/pslib/__init__.py | 82 +++++++++++++++++++ .../fluid/tests/unittests/test_fleet.py | 22 ++++- .../fluid/tests/unittests/test_fleet_1.py | 5 ++ .../tests/unittests/test_fleet_nocvm_1.py | 19 +++++ 8 files changed, 173 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 207ce748c1..4d632d7376 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -917,6 +917,38 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { #endif } +void FleetWrapper::SaveModelOneTable(const uint64_t table_id, + const std::string& path, const int mode) { +#ifdef PADDLE_WITH_PSLIB + auto ret = + pslib_ptr_->_worker_ptr->save(table_id, path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "save model of table id: " << table_id + << ", to path: " << path << " failed"; + } +#else + VLOG(0) << "FleetWrapper::SaveModelOneTable does nothing when no pslib"; +#endif +} + +void FleetWrapper::SaveModelOneTablePrefix(const uint64_t table_id, + const std::string& path, + const int mode, + const std::string& prefix) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->save(table_id, path, std::to_string(mode), + prefix); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "save model (with prefix) of table id: " << table_id + << ", to path: " << path << " failed"; + } +#else + VLOG(0) << "FleetWrapper::SaveModelOneTablePrefix does 
nothing when no pslib"; +#endif +} + void FleetWrapper::PrintTableStat(const uint64_t table_id) { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->print_table_stat(table_id); diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index afc97e01ea..933b0a8bd8 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -221,15 +221,22 @@ class FleetWrapper { void PrintTableStat(const uint64_t table_id); // mode = 0, load all feature - // mode = 1, laod delta feature, which means load diff + // mode = 1, load delta feature, which means load diff void LoadModel(const std::string& path, const int mode); // mode = 0, load all feature - // mode = 1, laod delta feature, which means load diff + // mode = 1, load delta feature, which means load diff void LoadModelOneTable(const uint64_t table_id, const std::string& path, const int mode); // mode = 0, save all feature // mode = 1, save delta feature, which means save diff void SaveModel(const std::string& path, const int mode); + // mode = 0, save all feature + // mode = 1, save delta feature, which means save diff + void SaveModelOneTable(const uint64_t table_id, const std::string& path, + const int mode); + // save model with prefix + void SaveModelOneTablePrefix(const uint64_t table_id, const std::string& path, + const int mode, const std::string& prefix); // get save cache threshold double GetCacheThreshold(int table_id); // shuffle cache model between servers diff --git a/paddle/fluid/framework/fleet/test_fleet.cc b/paddle/fluid/framework/fleet/test_fleet.cc index 5a3fd132d7..bf9928789c 100644 --- a/paddle/fluid/framework/fleet/test_fleet.cc +++ b/paddle/fluid/framework/fleet/test_fleet.cc @@ -61,5 +61,7 @@ TEST(TEST_FLEET, fleet_1) { #ifdef PADDLE_WITH_PSLIB #else fleet->RunServer("", 0); + fleet->SaveModelOneTable(0, "", 0); + fleet->SaveModelOneTablePrefix(0, "", 0, ""); #endif } diff --git 
a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 3b4505c611..3ae4eef449 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -78,6 +78,9 @@ void BindFleetWrapper(py::module* m) { &framework::FleetWrapper::SetClient2ClientConfig) .def("set_pull_local_thread_num", &framework::FleetWrapper::SetPullLocalThreadNum) + .def("save_model_one_table", &framework::FleetWrapper::SaveModelOneTable) + .def("save_model_one_table_with_prefix", + &framework::FleetWrapper::SaveModelOneTablePrefix) .def("copy_table", &framework::FleetWrapper::CopyTable) .def("copy_table_by_feasign", &framework::FleetWrapper::CopyTableByFeasign); diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 7dfe8f7e7d..210640f64c 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -567,6 +567,88 @@ class PSLib(Fleet): model_proto_file, table_var_names, load_combine) self._role_maker._barrier_worker() + def load_model(self, model_dir=None, **kwargs): + """ + load pslib model, there are at least 4 modes, these modes are the same + in load one table/save model/save one table: + 0: load checkpoint model + 1: load delta model (delta means diff, it's usually for online predict) + 2: load base model (base model filters some feasigns in checkpoint, it's + usually for online predict) + 3: load batch model (do some statistic works in checkpoint, such as + calculate unseen days of each feasign) + + Args: + model_dir(str): if you use hdfs, model_dir should starts with + 'hdfs:', otherwise means local dir + kwargs(dict): user-defined properties. + mode(int): the modes illustrated above, default 0 + + Examples: + .. 
code-block:: python + + fleet.load_model("afs:/user/path/") + + """ + mode = kwargs.get("mode", 0) + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.load_model(model_dir, mode) + self._role_maker._barrier_worker() + + def save_model(self, model_dir=None, **kwargs): + """ + save pslib model, the modes are same with load model. + + Args: + model_dir(str): if you use hdfs, model_dir should starts with + 'hdfs:', otherwise means local dir + kwargs(dict): user-defined properties. + mode(int): the modes illustrated above, default 0 + + Examples: + .. code-block:: python + + fleet.save_model("afs:/user/path/") + + """ + mode = kwargs.get("mode", 0) + prefix = kwargs.get("prefix", None) + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.save_model(model_dir, mode) + self._role_maker._barrier_worker() + + def save_one_table(self, table_id, model_dir, **kwargs): + """ + save pslib model's one table, the modes are same with load model. + + Args: + table_id(int): table id + model_dir(str): if you use hdfs, model_dir should starts with + 'hdfs:', otherwise means local dir + kwargs(dict): user-defined properties. + mode(int): the modes illustrated above, default 0 + prefix(str): the parts to save can have prefix, + for example, part-prefix-000-00000 + + Examples: + .. 
code-block:: python + + fleet.save_one_table(0, "afs:/user/path/") + + """ + mode = kwargs.get("mode", 0) + prefix = kwargs.get("prefix", None) + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + if prefix is not None: + self._fleet_ptr.save_model_one_table_with_prefix( + table_id, model_dir, mode, prefix) + else: + self._fleet_ptr.save_model_one_table(table_id, model_dir, mode) + self._role_maker._barrier_worker() + def _set_opt_info(self, opt_info): """ this function saves the result from DistributedOptimizer.minimize() diff --git a/python/paddle/fluid/tests/unittests/test_fleet.py b/python/paddle/fluid/tests/unittests/test_fleet.py index 5f508917ef..6657f5a120 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet.py +++ b/python/paddle/fluid/tests/unittests/test_fleet.py @@ -21,7 +21,8 @@ import paddle.fluid.incubate.fleet.base.role_maker as role_maker class TestFleet1(unittest.TestCase): """ - Test cases for fleet minimize. + Test cases for fleet minimize, + and some other fleet api tests. 
""" def setUp(self): @@ -80,6 +81,25 @@ class TestFleet1(unittest.TestCase): except: print("do not support pslib test, skip") return + try: + # worker should call these methods instead of server + # the following is only for test when with_pslib=off + def test_func(): + """ + it is only a test function + """ + return True + + fleet._role_maker.is_first_worker = test_func + fleet._role_maker._barrier_worker = test_func + fleet.save_model("./model_000") + fleet.save_one_table(0, "./model_001") + fleet.save_one_table(0, "./model_002", prefix="hahaha") + fleet.load_model("./model_0003") + fleet.load_one_table(0, "./model_004") + except: + print("do not support pslib test, skip") + return if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_fleet_1.py b/python/paddle/fluid/tests/unittests/test_fleet_1.py index 7f221494d6..eaca009dd4 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_1.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_1.py @@ -221,6 +221,11 @@ class TestFleet2(unittest.TestCase): emb1 = fluid.embedding(input=show, size=[1, 1], \ is_sparse=True, is_distributed=True, \ param_attr=fluid.ParamAttr(name="embedding")) + fleet.save_model("./tmodel_000") + fleet.save_one_table(0, "./tmodel_001") + fleet.save_one_table(0, "./tmodel_002", prefix="thahaha") + fleet.load_model("./tmodel_0003") + fleet.load_one_table(0, "./tmodel_004") if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py b/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py index ef655d1999..5aa2867343 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py @@ -80,6 +80,25 @@ class TestFleet1(unittest.TestCase): except: print("do not support pslib test, skip") return + try: + # worker should call these methods instead of server + # the following is only for test when with_pslib=off + def test_func(): + """ + it is only a test 
function + """ + return True + + fleet._role_maker.is_first_worker = test_func + fleet._role_maker._barrier_worker = test_func + fleet.save_model("./model_000") + fleet.save_one_table(0, "./model_001") + fleet.save_one_table(0, "./model_002", prefix="hahaha") + fleet.load_model("./model_0003") + fleet.load_one_table(0, "./model_004") + except: + print("do not support pslib test, skip") + return if __name__ == "__main__": -- GitLab