未验证 提交 991dc67d 编写于 作者: T Thunderbrook 提交者: GitHub

set file_num in one shard (#35835)

* set file_num in one shard

* format
上级 49c8253f
...@@ -1347,6 +1347,20 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) { ...@@ -1347,6 +1347,20 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) {
#endif #endif
} }
void FleetWrapper::SetFileNumOneShard(const uint64_t table_id, int file_num) {
#ifdef PADDLE_WITH_PSLIB
auto ret =
pslib_ptr_->_worker_ptr->set_file_num_one_shard(table_id, file_num);
ret.wait();
int32_t err_code = ret.get();
if (err_code == -1) {
LOG(ERROR) << "set_file_num_one_shard failed";
}
#else
VLOG(0) << "FleetWrapper::SetFileNumOneShard does nothing when no pslib";
#endif
}
double FleetWrapper::GetCacheThreshold(int table_id) { double FleetWrapper::GetCacheThreshold(int table_id) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
double cache_threshold = 0.0; double cache_threshold = 0.0;
......
...@@ -266,6 +266,7 @@ class FleetWrapper { ...@@ -266,6 +266,7 @@ class FleetWrapper {
bool load_combine); bool load_combine);
void PrintTableStat(const uint64_t table_id); void PrintTableStat(const uint64_t table_id);
void SetFileNumOneShard(const uint64_t table_id, int file_num);
// mode = 0, load all feature // mode = 0, load all feature
// mode = 1, load delta feature, which means load diff // mode = 1, load delta feature, which means load diff
void LoadModel(const std::string& path, const int mode); void LoadModel(const std::string& path, const int mode);
......
...@@ -76,6 +76,8 @@ void BindFleetWrapper(py::module* m) { ...@@ -76,6 +76,8 @@ void BindFleetWrapper(py::module* m) {
.def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable) .def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable)
.def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable) .def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable)
.def("print_table_stat", &framework::FleetWrapper::PrintTableStat) .def("print_table_stat", &framework::FleetWrapper::PrintTableStat)
.def("set_file_num_one_shard",
&framework::FleetWrapper::SetFileNumOneShard)
.def("client_flush", &framework::FleetWrapper::ClientFlush) .def("client_flush", &framework::FleetWrapper::ClientFlush)
.def("load_from_paddle_model", .def("load_from_paddle_model",
&framework::FleetWrapper::LoadFromPaddleModel) &framework::FleetWrapper::LoadFromPaddleModel)
......
...@@ -327,6 +327,21 @@ class PSLib(Fleet): ...@@ -327,6 +327,21 @@ class PSLib(Fleet):
self._fleet_ptr.print_table_stat(table_id) self._fleet_ptr.print_table_stat(table_id)
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
def set_file_num_one_shard(self, table_id, file_num):
"""
set file_num in one shard
Args:
table_id(int): the id of table
file_num(int): file num in one shard
Example:
.. code-block:: python
fleet.set_file_num_one_shard(0, 5)
"""
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
self._fleet_ptr.set_file_num_one_shard(table_id, file_num)
self._role_maker._barrier_worker()
def save_persistables(self, executor, dirname, main_program=None, **kwargs): def save_persistables(self, executor, dirname, main_program=None, **kwargs):
""" """
save presistable parameters, save presistable parameters,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册