未验证 提交 d98084e7 编写于 作者: X xujiaqi01 提交者: GitHub

add save with prefix (#23449)

* add save with prefix
* test=develop
上级 588eb8e2
......@@ -917,6 +917,38 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#endif
}
// Saves a single table through the pslib worker.
//
// table_id: id of the table to save.
// path:     destination passed through unchanged to the worker.
// mode:     save mode; forwarded to the worker as a decimal string.
//
// The call blocks until the worker reports a result. A non-zero result is
// only logged as an error; nothing is returned or thrown to the caller.
// When built without PADDLE_WITH_PSLIB this only emits a VLOG(0) notice.
void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
                                     const std::string& path,
                                     const int mode) {
#ifndef PADDLE_WITH_PSLIB
  VLOG(0) << "FleetWrapper::SaveModelOneTable does nothing when no pslib";
#else
  auto save_result =
      pslib_ptr_->_worker_ptr->save(table_id, path, std::to_string(mode));
  // Block until the save request has completed before reading its status.
  save_result.wait();
  const bool save_failed = (save_result.get() != 0);
  if (save_failed) {
    LOG(ERROR) << "save model of table id: " << table_id
               << ", to path: " << path << " failed";
  }
#endif
}
// Saves a single table through the pslib worker, forwarding an extra
// `prefix` argument to the worker's save call.
//
// table_id: id of the table to save.
// path:     destination passed through unchanged to the worker.
// mode:     save mode; forwarded to the worker as a decimal string.
// prefix:   prefix string handed to the worker as the fourth argument.
//
// The call blocks until the worker reports a result. A non-zero result is
// only logged as an error; nothing is returned or thrown to the caller.
// When built without PADDLE_WITH_PSLIB this only emits a VLOG(0) notice.
void FleetWrapper::SaveModelOneTablePrefix(const uint64_t table_id,
                                           const std::string& path,
                                           const int mode,
                                           const std::string& prefix) {
#ifndef PADDLE_WITH_PSLIB
  VLOG(0) << "FleetWrapper::SaveModelOneTablePrefix does nothing when no pslib";
#else
  auto save_result = pslib_ptr_->_worker_ptr->save(
      table_id, path, std::to_string(mode), prefix);
  // Block until the save request has completed before reading its status.
  save_result.wait();
  const bool save_failed = (save_result.get() != 0);
  if (save_failed) {
    LOG(ERROR) << "save model (with prefix) of table id: " << table_id
               << ", to path: " << path << " failed";
  }
#endif
}
void FleetWrapper::PrintTableStat(const uint64_t table_id) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->print_table_stat(table_id);
......
......@@ -221,15 +221,22 @@ class FleetWrapper {
void PrintTableStat(const uint64_t table_id);
// mode = 0, load all feature
// mode = 1, laod delta feature, which means load diff
// mode = 1, load delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
// mode = 0, load all feature
// mode = 1, laod delta feature, which means load diff
// mode = 1, load delta feature, which means load diff
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// save model with prefix
void SaveModelOneTablePrefix(const uint64_t table_id, const std::string& path,
const int mode, const std::string& prefix);
// get save cache threshold
double GetCacheThreshold(int table_id);
// shuffle cache model between servers
......
......@@ -61,5 +61,7 @@ TEST(TEST_FLEET, fleet_1) {
#ifdef PADDLE_WITH_PSLIB
#else
fleet->RunServer("", 0);
fleet->SaveModelOneTable(0, "", 0);
fleet->SaveModelOneTablePrefix(0, "", 0, "");
#endif
}
......@@ -78,6 +78,9 @@ void BindFleetWrapper(py::module* m) {
&framework::FleetWrapper::SetClient2ClientConfig)
.def("set_pull_local_thread_num",
&framework::FleetWrapper::SetPullLocalThreadNum)
.def("save_model_one_table", &framework::FleetWrapper::SaveModelOneTable)
.def("save_model_one_table_with_prefix",
&framework::FleetWrapper::SaveModelOneTablePrefix)
.def("copy_table", &framework::FleetWrapper::CopyTable)
.def("copy_table_by_feasign",
&framework::FleetWrapper::CopyTableByFeasign);
......
......@@ -567,6 +567,88 @@ class PSLib(Fleet):
model_proto_file, table_var_names, load_combine)
self._role_maker._barrier_worker()
def load_model(self, model_dir=None, **kwargs):
    """
    Load a pslib model.

    There are at least 4 modes; the same modes apply to load one table,
    save model and save one table:
    0: load checkpoint model
    1: load delta model (delta means diff, it's usually for online predict)
    2: load base model (base model filters some feasigns in checkpoint,
       it's usually for online predict)
    3: load batch model (do some statistic works in checkpoint, such as
       calculate unseen days of each feasign)

    All workers synchronize on a barrier before and after the load; only
    the first worker issues the actual load call.

    Args:
        model_dir(str): if you use hdfs, model_dir should start with
                        'hdfs:', otherwise it means a local dir
        kwargs(dict): user-defined properties.
                      mode(int): one of the modes listed above, default 0

    Examples:
        .. code-block:: python

          fleet.load_model("afs:/user/path/")

    """
    load_mode = kwargs.get("mode", 0)
    role_maker = self._role_maker
    role_maker._barrier_worker()
    # Only one worker triggers the load; the others just wait at the barrier.
    if role_maker.is_first_worker():
        self._fleet_ptr.load_model(model_dir, load_mode)
    role_maker._barrier_worker()
def save_model(self, model_dir=None, **kwargs):
    """
    Save a pslib model; the modes are the same as for load_model.

    All workers synchronize on a barrier before and after the save; only
    the first worker issues the actual save call.

    Args:
        model_dir(str): if you use hdfs, model_dir should start with
                        'hdfs:', otherwise it means a local dir
        kwargs(dict): user-defined properties.
                      mode(int): save mode, default 0

    Examples:
        .. code-block:: python

          fleet.save_model("afs:/user/path/")

    """
    # NOTE: a "prefix" entry in kwargs is not used here; only
    # save_one_table forwards a prefix to the underlying fleet pointer.
    mode = kwargs.get("mode", 0)
    self._role_maker._barrier_worker()
    # Only one worker triggers the save; the others just wait at the barrier.
    if self._role_maker.is_first_worker():
        self._fleet_ptr.save_model(model_dir, mode)
    self._role_maker._barrier_worker()
def save_one_table(self, table_id, model_dir, **kwargs):
    """
    Save one table of a pslib model; the modes are the same as for
    load_model.

    All workers synchronize on a barrier before and after the save; only
    the first worker issues the actual save call. When a prefix is given
    (i.e. it is not None) the prefix-aware save is used.

    Args:
        table_id(int): table id
        model_dir(str): if you use hdfs, model_dir should start with
                        'hdfs:', otherwise it means a local dir
        kwargs(dict): user-defined properties.
                      mode(int): save mode, default 0
                      prefix(str): the saved parts can have a prefix,
                                   for example, part-prefix-000-00000

    Examples:
        .. code-block:: python

          fleet.save_one_table(0, "afs:/user/path/")

    """
    save_mode = kwargs.get("mode", 0)
    part_prefix = kwargs.get("prefix", None)
    role_maker = self._role_maker
    role_maker._barrier_worker()
    # Only one worker triggers the save; the others just wait at the barrier.
    if role_maker.is_first_worker():
        if part_prefix is None:
            self._fleet_ptr.save_model_one_table(table_id, model_dir,
                                                 save_mode)
        else:
            self._fleet_ptr.save_model_one_table_with_prefix(
                table_id, model_dir, save_mode, part_prefix)
    role_maker._barrier_worker()
def _set_opt_info(self, opt_info):
"""
this function saves the result from DistributedOptimizer.minimize()
......
......@@ -21,7 +21,8 @@ import paddle.fluid.incubate.fleet.base.role_maker as role_maker
class TestFleet1(unittest.TestCase):
"""
Test cases for fleet minimize.
Test cases for fleet minimize,
and some other fleet api tests.
"""
def setUp(self):
......@@ -80,6 +81,25 @@ class TestFleet1(unittest.TestCase):
except:
print("do not support pslib test, skip")
return
try:
# worker should call these methods instead of server
# the following is only for test when with_pslib=off
def test_func():
"""
it is only a test function
"""
return True
fleet._role_maker.is_first_worker = test_func
fleet._role_maker._barrier_worker = test_func
fleet.save_model("./model_000")
fleet.save_one_table(0, "./model_001")
fleet.save_one_table(0, "./model_002", prefix="hahaha")
fleet.load_model("./model_0003")
fleet.load_one_table(0, "./model_004")
except:
print("do not support pslib test, skip")
return
if __name__ == "__main__":
......
......@@ -221,6 +221,11 @@ class TestFleet2(unittest.TestCase):
emb1 = fluid.embedding(input=show, size=[1, 1], \
is_sparse=True, is_distributed=True, \
param_attr=fluid.ParamAttr(name="embedding"))
fleet.save_model("./tmodel_000")
fleet.save_one_table(0, "./tmodel_001")
fleet.save_one_table(0, "./tmodel_002", prefix="thahaha")
fleet.load_model("./tmodel_0003")
fleet.load_one_table(0, "./tmodel_004")
if __name__ == "__main__":
......
......@@ -80,6 +80,25 @@ class TestFleet1(unittest.TestCase):
except:
print("do not support pslib test, skip")
return
try:
# worker should call these methods instead of server
# the following is only for test when with_pslib=off
def test_func():
"""
it is only a test function
"""
return True
fleet._role_maker.is_first_worker = test_func
fleet._role_maker._barrier_worker = test_func
fleet.save_model("./model_000")
fleet.save_one_table(0, "./model_001")
fleet.save_one_table(0, "./model_002", prefix="hahaha")
fleet.load_model("./model_0003")
fleet.load_one_table(0, "./model_004")
except:
print("do not support pslib test, skip")
return
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册