diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc
index 7268bcbd23411994a3c6bdd6572e8f826a5bd9de..b1aeaca353e65ba7206c65bcde9bc28ec4b06416 100644
--- a/paddle/fluid/distributed/fleet.cc
+++ b/paddle/fluid/distributed/fleet.cc
@@ -459,6 +459,16 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id,
   }
 }
 
+void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
+                                    const std::string& path) {
+  auto* communicator = Communicator::GetInstance();
+  auto ret = communicator->_worker_ptr->recv_and_save_table(table_id, path);
+  if (ret != 0) {
+    LOG(ERROR) << "save model of table id: " << table_id
+               << ", to path: " << path << " failed";
+  }
+}
+
 void FleetWrapper::PrintTableStat(const uint64_t table_id) {
   auto* communicator = Communicator::GetInstance();
   auto ret = communicator->_worker_ptr->print_table_stat(table_id);
diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h
index 28ecedebf2c1e1bc4da5676abd83ec4f9f7a11ca..5de278e067ecd307bd0e0a26a2ba7c0c4f72fb6e 100644
--- a/paddle/fluid/distributed/fleet.h
+++ b/paddle/fluid/distributed/fleet.h
@@ -198,6 +198,10 @@ class FleetWrapper {
   // mode = 1, save delta feature, which means save diff
   void SaveModelOneTable(const uint64_t table_id, const std::string& path,
                          const int mode);
+
+  // recv table from server and save it in LoDTensor
+  void RecvAndSaveTable(const uint64_t table_id, const std::string& path);
+
   // clear all models, release their memory
   void ClearModel();
   // clear one table
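For orientation, here is a minimal sketch of how this new entry point is reached from Python once the `recv_and_save_model` pybind binding (added in `fleet_py.cc` below) is in place. The `worker` handle, table id, and path are illustrative, not part of this patch:

```python
# Hypothetical handle: the bound FleetWrapper instance held by the PS runtime
# (the_one_ps.py below obtains it as self._worker).
table_id = 0                    # illustrative table id
dirname = "/tmp/sparse_model"   # illustrative output directory

# Pulls the whole sparse table from the pservers and writes it locally
# as a serialized LoDTensor named after the table.
worker.recv_and_save_model(table_id, dirname)
```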
diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc
index f4e11818561fcdbf3538754fd01a56a2c9e0cc1a..6f932d55e9a194785bc2e950a75db3c1857d5561 100644
--- a/paddle/fluid/distributed/service/brpc_ps_client.cc
+++ b/paddle/fluid/distributed/service/brpc_ps_client.cc
@@ -14,6 +14,7 @@
 #include 
 #include 
+#include 
 #include 
 #include 
@@ -21,6 +22,7 @@
 #include "paddle/fluid/distributed/service/brpc_ps_client.h"
 #include "paddle/fluid/distributed/table/table.h"
 #include "paddle/fluid/framework/archive.h"
+#include "paddle/fluid/string/string_helper.h"
 
 const static int max_port = 65535;
@@ -55,6 +57,16 @@ DEFINE_int32(pserver_connect_timeout_ms, 10000,
 DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
 
+namespace paddle {
+namespace framework {
+class Scope;
+class Variable;
+}  // namespace framework
+namespace platform {
+class DeviceContext;
+}  // namespace platform
+}  // namespace paddle
+
 namespace paddle {
 namespace distributed {
@@ -903,5 +915,72 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
   return fut;
 }
 
+int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
+                                          const std::string &path) {
+  // get var information
+  std::string var_name = "";
+  int64_t var_num = 0;
+  int64_t var_shape = 0;
+  const auto &worker_param = _config.worker_param().downpour_worker_param();
+  for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
+    if (worker_param.downpour_table_param(i).table_id() == table_id) {
+      var_name = worker_param.downpour_table_param(i).common().table_name();
+      var_num = worker_param.downpour_table_param(i).accessor().fea_dim();
+      var_shape = worker_param.downpour_table_param(i).accessor().embedx_dim();
+      break;
+    }
+  }
+
+  PADDLE_ENFORCE_NE(
+      var_name, "",
+      platform::errors::InvalidArgument(
+          "Cannot find table id %d to save variables.", table_id));
+
+  std::string var_store = string::Sprintf("%s", path);
+  MkDirRecursively(var_store.c_str());
+
+  // pull sparse from server
+  std::vector<float> save_huge_vec(var_num * var_shape);
+  std::vector<uint64_t> save_key(var_num);
+  std::vector<float *> save_vec;
+  for (size_t i = 0; i < save_key.size(); ++i) {
+    save_key[i] = i;
+    save_vec.push_back(save_huge_vec.data() + i * var_shape);
+  }
+
+  auto status = pull_sparse((float **)save_vec.data(), table_id,
+                            save_key.data(), save_key.size());
+  status.wait();
+
+  // create lod tensor
+  std::shared_ptr<framework::Scope> scope;
+  scope.reset(new framework::Scope());
+  auto place = platform::CPUPlace();
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(place);
+
+  framework::Variable *var = scope->Var(var_name);
+  framework::LoDTensor *var_tensor = var->GetMutable<framework::LoDTensor>();
+
+  std::vector<int64_t> vec_dim = {var_num, var_shape};
+  var_tensor->Resize(framework::make_ddim(vec_dim));
+
+  // copy and save
+  float *tensor_data = var_tensor->mutable_data<float>(place);
+  memcpy(tensor_data, save_huge_vec.data(),
+         var_num * var_shape * sizeof(float));
+
+  std::string file_name = string::Sprintf("%s/%s", var_store, var_name);
+  std::ofstream fout(file_name, std::ios::binary);
+  PADDLE_ENFORCE_EQ(static_cast<bool>(fout), true,
+                    platform::errors::Unavailable(
+                        "Cannot open %s to save variables.", file_name));
+
+  framework::SerializeToStream(fout, *var_tensor, dev_ctx);
+  fout.close();
+
+  return 0;
+}
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/service/brpc_ps_client.h b/paddle/fluid/distributed/service/brpc_ps_client.h
index ed4310f016441d18a0a95bfb29fad85072d7530d..50faf7c9771c58dde24384d179d01b734f38eabf 100644
--- a/paddle/fluid/distributed/service/brpc_ps_client.h
+++ b/paddle/fluid/distributed/service/brpc_ps_client.h
@@ -22,6 +22,9 @@
 #include "brpc/controller.h"
 #include "brpc/server.h"
 #include "paddle/fluid/distributed/service/ps_client.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/tensor_util.h"
 
 namespace paddle {
 namespace distributed {
@@ -148,6 +151,10 @@ class BrpcPsClient : public PSClient {
   virtual std::future<int32_t> send_client2client_msg(
       int msg_type, int to_client_id, const std::string &msg) override;
 
+  // for local save sparse
+  virtual int32_t recv_and_save_table(const uint64_t table_id,
+                                      const std::string &path);
+
  private:
   virtual int32_t initialize() override;
diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h
index d549d09778c580b45586a930a5f3d372e9fdfc42..9d2309faef152c0b1793467eabda44fe2f44d1fa 100644
--- a/paddle/fluid/distributed/service/ps_client.h
+++ b/paddle/fluid/distributed/service/ps_client.h
@@ -134,6 +134,11 @@ class PSClient {
   virtual std::future<int32_t> push_global_step(int table_id,
                                                 int64_t *total_send_data,
                                                 void *done) = 0;
+
+  // recv table from server and save it in LoDTensor
+  virtual int32_t recv_and_save_table(const uint64_t table_id,
+                                      const std::string &path) = 0;
+
   virtual void finalize_worker() = 0;
   // client to client, message sending
   virtual std::future<int32_t> send_client2client_msg(int msg_type,
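Because `recv_and_save_table` writes the pulled table with `framework::SerializeToStream`, the on-disk file is in the standard LoDTensor stream format, saved as `path/<table_name>`. A hedged round-trip sketch on the Python side, assuming a table named `embedding` with `fea_dim=10000` and `embedx_dim=8` (all names and shapes are illustrative):

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    # Shape mirrors the tensor written by recv_and_save_table:
    # [fea_dim, embedx_dim]; the name must match the table's table_name.
    emb = fluid.layers.create_parameter(
        shape=[10000, 8], dtype="float32", name="embedding")
exe.run(startup_prog)

# load_vars reads dirname/<var name> using the same LoDTensor stream format.
fluid.io.load_vars(
    exe, dirname="/tmp/sparse_model", main_program=main_prog, vars=[emb])
```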
diff --git a/paddle/fluid/distributed/table/common_sparse_table.cc b/paddle/fluid/distributed/table/common_sparse_table.cc
index 5c03b3f501880a021942d70066bc64836c31bbc9..fffe5eac1d8c199f44aa89de77429e10a09d29a8 100644
--- a/paddle/fluid/distributed/table/common_sparse_table.cc
+++ b/paddle/fluid/distributed/table/common_sparse_table.cc
@@ -21,6 +21,7 @@
 #include "paddle/fluid/string/printf.h"
 #include "paddle/fluid/string/string_helper.h"
 
+#define PSERVER_SAVE_SUFFIX "_txt"
 namespace paddle {
 namespace distributed {
@@ -290,7 +291,8 @@ int32_t CommonSparseTable::save(const std::string& dirname,
   VLOG(0) << "sparse table save: " << dirname << " mode: " << mode;
 
   auto varname = _config.common().table_name();
-  std::string var_store = string::Sprintf("%s/%s", dirname, varname);
+  std::string var_store =
+      string::Sprintf("%s/%s%s", dirname, varname, PSERVER_SAVE_SUFFIX);
   MkDirRecursively(var_store.c_str());
 
   VLOG(3) << "save " << varname << " in dir: " << var_store << " begin";
diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc
index 4dd43175a1162189fea2a143a547ae4c696b4a93..4777951d82c5e635e40fc2784d32721718067df5 100644
--- a/paddle/fluid/pybind/fleet_py.cc
+++ b/paddle/fluid/pybind/fleet_py.cc
@@ -58,6 +58,7 @@ void BindDistFleetWrapper(py::module* m) {
       .def("pull_dense_params", &FleetWrapper::PullDenseVarsSync)
       .def("save_all_model", &FleetWrapper::SaveModel)
       .def("save_one_model", &FleetWrapper::SaveModelOneTable)
+      .def("recv_and_save_model", &FleetWrapper::RecvAndSaveTable)
       .def("sparse_table_stat", &FleetWrapper::PrintTableStat)
       .def("stop_server", &FleetWrapper::StopServer)
       .def("stop_worker", &FleetWrapper::FinalizeWorker)
diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index cd6238c1125ed8e4f5b1a142c1330e8095e42962..a45cdd6f38f7c329df68794a5169662931895300 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -545,7 +545,7 @@ class Fleet(object):
             executor, dirname, feeded_var_names, target_vars, main_program,
             export_for_deployment)
 
-    def save_persistables(self, executor, dirname, main_program=None, mode=1):
+    def save_persistables(self, executor, dirname, main_program=None, mode=0):
         """
 
         saves all persistable tensors from :code:`main_program` to
diff --git a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py
index 8fd172b5227492fc3496d279a42ba99d93b814ec..dd13f9bc5d4e759a7b4352474b6be37369a380fc 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py
@@ -64,12 +64,12 @@ class ParameterServerOptimizer(MetaOptimizerBase):
         _main = compiled_config.origin_main_program.clone()
         _startup = compiled_config.origin_startup_program.clone()
 
-        from paddle.fluid.incubate.fleet.parameter_server.ir.public import _add_lr_decay_table_pass
-        _add_lr_decay_table_pass(
-            _main, compiled_config,
-            self.user_defined_strategy.a_sync_configs["lr_decay_steps"])
-
         if not compiled_config.is_geo_mode():
+            from paddle.fluid.incubate.fleet.parameter_server.ir.public import _add_lr_decay_table_pass
+            _add_lr_decay_table_pass(
+                _main, compiled_config,
+                self.user_defined_strategy.a_sync_configs["lr_decay_steps"])
+
             # for main program
             _main = worker.delete_optimizer_pass(_main, compiled_config)
             _main = worker.distributed_ops_pass(_main, compiled_config)
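The relocation above means the lr-decay table pass now runs only when the strategy is not in GEO mode. A sketch of the strategy knobs that select between the two branches (values are illustrative):

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.a_sync = True

# GEO mode (k_steps > 0): the lr-decay table pass is now skipped.
strategy.a_sync_configs = {"k_steps": 800}

# A non-GEO async configuration would instead hit the relocated pass, e.g.:
# strategy.a_sync_configs = {"k_steps": 0, "lr_decay_steps": 80}
```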
diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py
index 3b17be1aa075871d23d48a4d3746028667eb8bb6..74a961eff0297fc8a4c320a9795158871e85d7c8 100644
--- a/python/paddle/distributed/fleet/runtime/the_one_ps.py
+++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py
@@ -851,15 +851,26 @@ class TheOnePSRuntime(RuntimeBase):
         return is_valid
 
-    def _save_sparse_params(self, executor, dirname, context, main_program):
+    def _save_sparse_params(self, executor, dirname, context, main_program,
+                            mode):
+        from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
+        distributed_varnames = get_sparse_tablenames(
+            self.compiled_strategy.origin_main_program, True)
         values = []
         for id, names in context.items():
+            if names not in distributed_varnames:
+                # only save sparse param to local
+                self._worker.recv_and_save_model(id, dirname)
+            # save sparse & distributed param on server
+            self._worker.save_one_model(id, dirname, mode)
             values.extend(names)
-            self._worker.save_one_model(id, dirname, 0)
         return values
 
-    def _save_distributed_persistables(self, executor, dirname, main_program,
-                                       mode):
+    def _save_distributed_persistables(self,
+                                       executor,
+                                       dirname,
+                                       main_program,
+                                       mode=0):
 
         denses = self.compiled_strategy.get_the_one_recv_context(
             is_dense=True,
@@ -870,14 +881,14 @@
             split_dense_table=self.role_maker._is_heter_parameter_server_mode,
             use_origin_program=True)
 
-        recv_sparse_varnames = self._save_sparse_params(executor, dirname,
-                                                        sparses, main_program)
+        sparse_varnames = self._save_sparse_params(executor, dirname, sparses,
+                                                   main_program, mode)
 
         recv_dense_varnames = []
         for id, names in denses.items():
             recv_dense_varnames.extend(names)
 
-        saved_varnames = recv_sparse_varnames
+        saved_varnames = sparse_varnames
 
         remaining_vars = list(
             filter(
@@ -925,6 +936,7 @@
                 "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
             )
 
+        # Todo(MrChengmo): Save optimizer status
         self._save_distributed_persistables(executor, dirname, main_program,
                                             mode)
@@ -971,8 +983,7 @@
         program = Program.parse_from_string(program_desc_str)
         program._copy_dist_param_info_from(fluid.default_main_program())
 
-        self._ps_inference_save_persistables(
-            executor, dirname, program, mode=0)
+        self._ps_inference_save_persistables(executor, dirname, program)
 
     def _save_inference_model(self, *args, **kwargs):
         self._ps_inference_save_inference_model(*args, **kwargs)
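Taken together, a sketch of the updated save flow from user code (model building and the training loop are elided; the directory name is illustrative):

```python
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init()
exe = paddle.static.Executor(paddle.CPUPlace())
# ... build the model, wrap the optimizer with fleet.distributed_optimizer(),
# and run training on the workers ...

if fleet.is_first_worker():
    # mode now defaults to 0 (save all features). Sparse tables that are not
    # distributed varnames are additionally pulled to the worker and saved
    # locally as LoDTensors via recv_and_save_model; server-side sparse saves
    # now land in "<table_name>_txt" directories.
    fleet.save_persistables(executor=exe, dirname="./model_dir")

fleet.stop_worker()
```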