diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index c4c64881a82b56be57d85e53b9abde1c8082d7bb..40a8ad66f9c8c24b4d097d07e986d513262f7f6f 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -560,6 +560,19 @@ void FleetWrapper::SaveModel(const std::string& path, const int mode) { #endif } +void FleetWrapper::PrintTableStat(const uint64_t table_id) { +#ifdef PADDLE_WITH_PSLIB + auto ret = pslib_ptr_->_worker_ptr->print_table_stat(table_id); + ret.wait(); + int32_t err_code = ret.get(); + if (err_code == -1) { + LOG(ERROR) << "print table stat failed"; + } +#else + VLOG(0) << "FleetWrapper::PrintTableStat does nothing when no pslib"; +#endif +} + double FleetWrapper::GetCacheThreshold(int table_id) { #ifdef PADDLE_WITH_PSLIB double cache_threshold = 0.0; diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index aa93e8d28bcfc5286ce777e4410e211fc12e719c..f86df13c9aa603049362fde1d7668d630a5bb51d 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -167,6 +167,8 @@ class FleetWrapper { std::string model_path, std::string model_proto_file, std::vector table_var_list, bool load_combine); + + void PrintTableStat(const uint64_t table_id); // mode = 0, load all feature // mode = 1, laod delta feature, which means load diff void LoadModel(const std::string& path, const int mode); diff --git a/paddle/fluid/pybind/fleet_wrapper_py.cc b/paddle/fluid/pybind/fleet_wrapper_py.cc index 679c91e8d8dbd9cb1e9fa5e038c3b4138eec2a6b..780fa99d2c43d8c90384e2f820b562cc0ce2e2c5 100644 --- a/paddle/fluid/pybind/fleet_wrapper_py.cc +++ b/paddle/fluid/pybind/fleet_wrapper_py.cc @@ -63,6 +63,7 @@ void BindFleetWrapper(py::module* m) { &framework::FleetWrapper::CreateClient2ClientConnection) .def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable) .def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable) + .def("print_table_stat", &framework::FleetWrapper::PrintTableStat) .def("client_flush", &framework::FleetWrapper::ClientFlush) .def("load_from_paddle_model", &framework::FleetWrapper::LoadFromPaddleModel) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index acebbe8251620140e9c6037accf73b4fb311eaeb..84bae58a7a3eb500b67010f81cdcdb66a7d38dc6 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -238,6 +238,25 @@ class PSLib(Fleet): """ self._fleet_ptr.save_model(dirname) + def print_table_stat(self, table_id): + """ + print stat info of table_id, + format: tableid, feasign size, mf size + + Args: + table_id(int): the id of table + + Example: + .. code-block:: python + + fleet.print_table_stat(0) + + """ + self._role_maker._barrier_worker() + if self._role_maker.is_first_worker(): + self._fleet_ptr.print_table_stat(table_id) + self._role_maker._barrier_worker() + def save_persistables(self, executor, dirname, main_program=None, **kwargs): """ save presistable parameters,