diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 207ce748c1b46881c69fb0a73118da276f06aca4..4d632d737611e9c04d8f60d34bbb9518266e2810 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -917,6 +917,38 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { #endif } +void FleetWrapper::SaveModelOneTable(const uint64_t table_id, + const std::string& path, const int mode) { +#ifdef PADDLE_WITH_PSLIB + auto ret = + pslib_ptr_->_worker_ptr->save(table_id, path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "save model of table id: " << table_id + << ", to path: " << path << " failed"; + } +#else + VLOG(0) << "FleetWrapper::SaveModelOneTable does nothing when no pslib"; +#endif +} + +void FleetWrapper::SaveModelOneTablePrefix(const uint64_t table_id, + const std::string& path, + const int mode, + const std::string& prefix) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->save(table_id, path, std::to_string(mode), + prefix); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "save model (with prefix) of table id: " << table_id + << ", to path: " << path << " failed"; + } +#else + VLOG(0) << "FleetWrapper::SaveModelOneTablePrefix does nothing when no pslib"; +#endif +} + void FleetWrapper::PrintTableStat(const uint64_t table_id) { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->print_table_stat(table_id); diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index afc97e01eaebd84f0167cf9ecd3d10850b4c7ef5..933b0a8bd852cfa2b1fb1b485cfd3ef2d012dd11 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -221,15 +221,22 @@ class FleetWrapper { void PrintTableStat(const uint64_t table_id); // mode = 0, load all feature - // mode = 1, laod delta feature, which means load diff + // mode = 1, load delta feature, which means load diff void LoadModel(const std::string& path, const int mode); // mode = 0, load all feature - // mode = 1, laod delta feature, which means load diff + // mode = 1, load delta feature, which means load diff void LoadModelOneTable(const uint64_t table_id, const std::string& path, const int mode); // mode = 0, save all feature // mode = 1, save delta feature, which means save diff void SaveModel(const std::string& path, const int mode); + // mode = 0, save all feature + // mode = 1, save delta feature, which means save diff + void SaveModelOneTable(const uint64_t table_id, const std::string& path, + const int mode); + // save model with prefix + void SaveModelOneTablePrefix(const uint64_t table_id, const std::string& path, + const int mode, const std::string& prefix); // get save cache threshold double GetCacheThreshold(int table_id); // shuffle cache model between servers diff --git a/paddle/fluid/framework/fleet/test_fleet.cc b/paddle/fluid/framework/fleet/test_fleet.cc index 5a3fd132d7e2e86d1736ca055ff3ec0612dfde51..bf9928789cafeb1455bb241b76482bf78a737871 100644 --- a/paddle/fluid/framework/fleet/test_fleet.cc +++ b/paddle/fluid/framework/fleet/test_fleet.cc @@ -61,5 +61,7 @@ TEST(TEST_FLEET, fleet_1) { #ifdef PADDLE_WITH_PSLIB #else fleet->RunServer("", 0); + fleet->SaveModelOneTable(0, "", 0); + fleet->SaveModelOneTablePrefix(0, "", 0, ""); #endif } diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 3b4505c611b283648f4da1d36f0200bb3e439d8a..3ae4eef44910031dd9f93c3286d404c51c11c1fe 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -78,6 +78,9 @@ void BindFleetWrapper(py::module* m) { &framework::FleetWrapper::SetClient2ClientConfig) .def("set_pull_local_thread_num", &framework::FleetWrapper::SetPullLocalThreadNum) + .def("save_model_one_table", &framework::FleetWrapper::SaveModelOneTable) + .def("save_model_one_table_with_prefix", + &framework::FleetWrapper::SaveModelOneTablePrefix) .def("copy_table", &framework::FleetWrapper::CopyTable) .def("copy_table_by_feasign", &framework::FleetWrapper::CopyTableByFeasign); diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 7dfe8f7e7d7178d9f89145453daa6aa78d97c5b4..210640f64c12c220304d31a5083de4e64756fd0d 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -567,6 +567,88 @@ class PSLib(Fleet): model_proto_file, table_var_names, load_combine) self._role_maker._barrier_worker() + def load_model(self, model_dir=None, **kwargs): + """ + load pslib model, there are at least 4 modes, these modes are the same + in load one table/save model/save one table: + 0: load checkpoint model + 1: load delta model (delta means diff, it's usually for online predict) + 2: load base model (base model filters some feasigns in checkpoint, it's + usually for online predict) + 3: load batch model (do some statistic works in checkpoint, such as + calculate unseen days of each feasign) + + Args: + model_dir(str): if you use hdfs, model_dir should starts with + 'hdfs:', otherwise means local dir + kwargs(dict): user-defined properties. + mode(int): the modes illustrated above, default 0 + + Examples: + .. code-block:: python + + fleet.load_model("afs:/user/path/") + + """ + mode = kwargs.get("mode", 0) + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.load_model(model_dir, mode) + self._role_maker._barrier_worker() + + def save_model(self, model_dir=None, **kwargs): + """ + save pslib model, the modes are same with load model. + + Args: + model_dir(str): if you use hdfs, model_dir should starts with + 'hdfs:', otherwise means local dir + kwargs(dict): user-defined properties. + mode(int): the modes illustrated above, default 0 + + Examples: + .. code-block:: python + + fleet.save_model("afs:/user/path/") + + """ + mode = kwargs.get("mode", 0) + prefix = kwargs.get("prefix", None) + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.save_model(model_dir, mode) + self._role_maker._barrier_worker() + + def save_one_table(self, table_id, model_dir, **kwargs): + """ + save pslib model's one table, the modes are same with load model. + + Args: + table_id(int): table id + model_dir(str): if you use hdfs, model_dir should starts with + 'hdfs:', otherwise means local dir + kwargs(dict): user-defined properties. + mode(int): the modes illustrated above, default 0 + prefix(str): the parts to save can have prefix, + for example, part-prefix-000-00000 + + Examples: + .. code-block:: python + + fleet.save_one_table("afs:/user/path/") + + """ + mode = kwargs.get("mode", 0) + prefix = kwargs.get("prefix", None) + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + if prefix is not None: + self._fleet_ptr.save_model_one_table_with_prefix( + table_id, model_dir, mode, prefix) + else: + self._fleet_ptr.save_model_one_table(table_id, model_dir, mode) + self._role_maker._barrier_worker() + def _set_opt_info(self, opt_info): """ this function saves the result from DistributedOptimizer.minimize() diff --git a/python/paddle/fluid/tests/unittests/test_fleet.py b/python/paddle/fluid/tests/unittests/test_fleet.py index 5f508917ef51ba1f81a6aae528530c0421205585..6657f5a1202de222e8a0b572f2d566b6765fbefe 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet.py +++ b/python/paddle/fluid/tests/unittests/test_fleet.py @@ -21,7 +21,8 @@ import paddle.fluid.incubate.fleet.base.role_maker as role_maker class TestFleet1(unittest.TestCase): """ - Test cases for fleet minimize. + Test cases for fleet minimize, + and some other fleet apu tests. """ def setUp(self): @@ -80,6 +81,25 @@ class TestFleet1(unittest.TestCase): except: print("do not support pslib test, skip") return + try: + # worker should call these methods instead of server + # the following is only for test when with_pslib=off + def test_func(): + """ + it is only a test function + """ + return True + + fleet._role_maker.is_first_worker = test_func + fleet._role_maker._barrier_worker = test_func + fleet.save_model("./model_000") + fleet.save_one_table(0, "./model_001") + fleet.save_one_table(0, "./model_002", prefix="hahaha") + fleet.load_model("./model_0003") + fleet.load_one_table(0, "./model_004") + except: + print("do not support pslib test, skip") + return if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_fleet_1.py b/python/paddle/fluid/tests/unittests/test_fleet_1.py index 7f221494d6773db4fb0808ed94a4f899e80c4714..eaca009dd4a13f26473a9a445d20946d43441b3e 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_1.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_1.py @@ -221,6 +221,11 @@ class TestFleet2(unittest.TestCase): emb1 = fluid.embedding(input=show, size=[1, 1], \ is_sparse=True, is_distributed=True, \ param_attr=fluid.ParamAttr(name="embedding")) + fleet.save_model("./tmodel_000") + fleet.save_one_table(0, "./tmodel_001") + fleet.save_one_table(0, "./tmodel_002", prefix="thahaha") + fleet.load_model("./tmodel_0003") + fleet.load_one_table(0, "./tmodel_004") if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py b/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py index ef655d1999a87ba3a80ff1318e7697bd02217de9..5aa28673437b06de9376e1944a06d72f2668c385 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_nocvm_1.py @@ -80,6 +80,25 @@ class TestFleet1(unittest.TestCase): except: print("do not support pslib test, skip") return + try: + # worker should call these methods instead of server + # the following is only for test when with_pslib=off + def test_func(): + """ + it is only a test function + """ + return True + + fleet._role_maker.is_first_worker = test_func + fleet._role_maker._barrier_worker = test_func + fleet.save_model("./model_000") + fleet.save_one_table(0, "./model_001") + fleet.save_one_table(0, "./model_002", prefix="hahaha") + fleet.load_model("./model_0003") + fleet.load_one_table(0, "./model_004") + except: + print("do not support pslib test, skip") + return if __name__ == "__main__":