From e7842ba6670824efa5484b7ccfe9b364949a6fb7 Mon Sep 17 00:00:00 2001
From: wangguanqun
Date: Thu, 28 Oct 2021 14:37:29 +0800
Subject: [PATCH] save/load in ps runtime(the_one_ps) (#36097)

* add trainer desc config to distributed strategy

* code style modified

* data_feed set lod

* fix bug

* code style

* fix bug

* save load

* save load

* save unittest

* add unittest of the_one_ps

* unittest

* add todo in communicator sendsparse
---
 .../fluid/distributed/service/communicator.cc | 23 ++++++
 .../fluid/distributed/service/communicator.h  |  2 +
 .../distributed/table/common_sparse_table.cc  | 17 +++--
 paddle/fluid/pybind/fleet_py.cc               |  3 +-
 .../distributed/fleet/runtime/the_one_ps.py   | 72 +++++++++++++++++--
 python/paddle/fluid/communicator.py           |  3 +
 .../tests/unittests/test_fleet_base_2.py      |  8 +++
 7 files changed, 116 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/distributed/service/communicator.cc b/paddle/fluid/distributed/service/communicator.cc
index 3d5ab8e16d..30529d73fa 100644
--- a/paddle/fluid/distributed/service/communicator.cc
+++ b/paddle/fluid/distributed/service/communicator.cc
@@ -283,6 +283,18 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
     push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
   }
 
+  // TODO(wangguanqun): padding_idx is not ignored, this is a bug.
+  // if padding_idx == padding in datareader, the server will core.
+  /*
+  for (size_t i = 0; i < tensor->rows().size(); ++i) {
+    uint64_t real_id = static_cast<uint64_t>(tensor->rows()[i]);
+    if (real_id != 0) {
+      sparse_push_keys.push_back(real_id);
+      push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
+    }
+  }
+  */
+
   ++_async_call_num;
   DownpourBrpcClosure *closure = new DownpourBrpcClosure(
       request_call_num, [this, request_call_num](void *done) {
@@ -353,6 +365,17 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
   return;
 }
 
+void Communicator::PullDense(const RecvCtxMap &recv_varname_to_ctx) {
+  for (auto &iter : recv_varname_to_ctx) {
+    auto &table_id = iter.first;
+    auto &varnames = iter.second;
+    RpcRecvDense(varnames, table_id, recv_scope_);
+    VLOG(1) << "pull dense param to table " << table_id
+            << " from 0' trainer done";
+  }
+  return;
+}
+
 void Communicator::RpcProfilerControl() {
   if (trainer_id_ == 0) {
     if (!do_server_profiler_ && platform::IsProfileEnabled()) {
diff --git a/paddle/fluid/distributed/service/communicator.h b/paddle/fluid/distributed/service/communicator.h
index c6d37defbd..01ec3c617d 100644
--- a/paddle/fluid/distributed/service/communicator.h
+++ b/paddle/fluid/distributed/service/communicator.h
@@ -271,6 +271,8 @@ class Communicator {
 
   virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);
 
+  virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);
+
   virtual void Start() = 0;
 
   virtual void Stop() = 0;
diff --git a/paddle/fluid/distributed/table/common_sparse_table.cc b/paddle/fluid/distributed/table/common_sparse_table.cc
index 8b79b1c02f..e124160e71 100644
--- a/paddle/fluid/distributed/table/common_sparse_table.cc
+++ b/paddle/fluid/distributed/table/common_sparse_table.cc
@@ -279,18 +279,25 @@ int32_t CommonSparseTable::set_global_lr(float* lr) {
   return 0;
 }
 
-int32_t CommonSparseTable::load(const std::string& path,
+int32_t CommonSparseTable::load(const std::string& dirname,
                                 const std::string& param) {
   auto begin = GetCurrentUS();
   rwlock_->WRLock();
-  LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_,
+  auto varname = _config.common().table_name();
+  std::string var_store =
+      string::Sprintf("%s/%s%s", dirname, varname, PSERVER_SAVE_SUFFIX);
+  std::string shard_var_pre =
+      string::Sprintf("%s.block%d", varname, _shard_idx);
+  std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre);
+  std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre);
+
+  LoadFromText(value_, meta_, _shard_idx, _shard_num, task_pool_size_,
                &shard_values_);
   rwlock_->UNLock();
   auto end = GetCurrentUS();
 
-  auto varname = _config.common().table_name();
-  VLOG(0) << "load " << varname << " with value: " << path
-          << " , meta: " << param
+  VLOG(0) << "load " << varname << " with value: " << value_
+          << " , meta: " << meta_
           << " using: " << std::to_string((end - begin) / 1e+6) << " seconds";
 
   return 0;
diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc
index ea9faf57ac..0a39f52938 100644
--- a/paddle/fluid/pybind/fleet_py.cc
+++ b/paddle/fluid/pybind/fleet_py.cc
@@ -158,7 +158,8 @@ void BindDistCommunicator(py::module* m) {
       .def("start", &Communicator::Start)
       .def("push_sparse_param", &Communicator::RpcSendSparseParam)
       .def("is_running", &Communicator::IsRunning)
-      .def("init_params", &Communicator::InitParams);
+      .def("init_params", &Communicator::InitParams)
+      .def("pull_dense", &Communicator::PullDense);
   // .def("recv", &Communicator::RecvNoBarrier);
 }
 
diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py
index 642d0e427f..0b874b8c61 100644
--- a/python/paddle/distributed/fleet/runtime/the_one_ps.py
+++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py
@@ -868,11 +868,11 @@ class TheOnePSRuntime(RuntimeBase):
 
         for var_name in load_varnames:
             table_id = sparse_table_maps[var_name]
-            path = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
-                                "{}.block{}.txt".format(var_name, pserver_id))
-            meta = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
-                                "{}.block{}.meta".format(var_name, pserver_id))
-            self._server.load_sparse(path, meta, table_id)
+            # path = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
+            #                     "{}.block{}.txt".format(var_name, pserver_id))
+            # meta = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
+            #                     "{}.block{}.meta".format(var_name, pserver_id))
+            self._server.load_sparse(dirname, "0", table_id)
 
     def _run_server(self):
         if self.role_maker._is_heter_worker():
@@ -967,8 +967,12 @@ class TheOnePSRuntime(RuntimeBase):
                 TheOnePSRuntime.__exclude_vars(saved_varnames),
                 main_program.list_vars()))
 
+        self._communicator.pull_dense(denses)
+
         import paddle
         for var in remaining_vars:
+            if var.name not in recv_dense_varnames:
+                continue
             tensor = var.get_value()
             paddle.save(
                 tensor, os.path.join(dirname, var.name), use_binary_format=True)
@@ -1063,8 +1067,64 @@ class TheOnePSRuntime(RuntimeBase):
     def _save_persistables(self, *args, **kwargs):
         self._ps_inference_save_persistables(*args, **kwargs)
 
+    def _load_sparse_params(self, dirname, context, main_program, mode):
+        from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
+        distributed_varnames = get_sparse_tablenames(
+            self.compiled_strategy.origin_main_program, True)
+        values = []
+        for id, names in context.items():
+            if names[0] not in distributed_varnames:
+                # TODO: only load sparse param from local
+                warnings.warn("varname is not in distributed_varnames, pass")
+            # load sparse & distributed param on server
+            self._worker.load_one_table(id, dirname, mode)
+            values.extend(names)
+        return values
+
+    def _load_distributed_persistables(self, dirname, main_program=None,
+                                       mode=0):
+        if main_program is None:
+            main_program = self.compiled_strategy.get_origin_ps_main_program()
+
+        if isinstance(main_program, CompiledProgram):
+            raise TypeError(
+                "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
+            )
+
+        denses = self.compiled_strategy.get_the_one_recv_context(
+            is_dense=True,
+            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
+            use_origin_program=True)
+        sparses = self.compiled_strategy.get_the_one_recv_context(
+            is_dense=False,
+            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
+            use_origin_program=True)
+
+        sparse_varnames = self._load_sparse_params(dirname, sparses,
+                                                   main_program, mode)
+
+        recv_dense_varnames = []
+        for id, names in denses.items():
+            recv_dense_varnames.extend(names)
+
+        loaded_varnames = sparse_varnames
+
+        remaining_vars = list(
+            filter(
+                TheOnePSRuntime.__exclude_vars(loaded_varnames),
+                main_program.list_vars()))
+
+        import paddle
+        for var in remaining_vars:
+            if var.name not in recv_dense_varnames:
+                continue
+            tensor = paddle.load(os.path.join(dirname, var.name))
+            var.set_value(tensor)
+
+        self._communicator.init_params(denses)
+
     def load_model(self, path, mode):
-        self._worker.load_model(path, mode)
+        self._load_distributed_persistables(path, mode=mode)
 
     def _shrink(self, threshold):
         import paddle.distributed.fleet as fleet
diff --git a/python/paddle/fluid/communicator.py b/python/paddle/fluid/communicator.py
index fa497f5c28..9a75ef8c58 100644
--- a/python/paddle/fluid/communicator.py
+++ b/python/paddle/fluid/communicator.py
@@ -161,6 +161,9 @@ class Communicator(object):
     def init_params(self, context):
         self.communicator_.init_params(context)
 
+    def pull_dense(self, context):
+        self.communicator_.pull_dense(context)
+
     def push_sparse_param(self, var_name, table_id=-1, scope=global_scope()):
         if not self.is_running():
             raise ValueError(
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base_2.py b/python/paddle/fluid/tests/unittests/test_fleet_base_2.py
index 7ca08bcb9d..64b8744472 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_base_2.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_base_2.py
@@ -36,8 +36,13 @@ class TestFleetBase(unittest.TestCase):
 
         input_x = paddle.fluid.layers.data(
             name="x", shape=[32], dtype='float32')
+        input_slot = paddle.fluid.layers.data(
+            name="slot", shape=[1], dtype='int64')
         input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
 
+        emb = paddle.fluid.layers.embedding(
+            input=input_slot, size=[10, 9], is_sparse=True)
+        input_x = paddle.concat(x=[input_x, emb], axis=1)
         fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
         fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
         prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
@@ -63,11 +68,14 @@ class TestFleetBase(unittest.TestCase):
         compiled_prog = fluid.compiler.CompiledProgram(
             fluid.default_main_program())
 
+        fleet.init_worker()
         fleet.fleet.save(dirname="/tmp", feed=['x', 'y'], fetch=[avg_cost])
         fleet.fleet.save(
             dirname="/tmp", feed=[input_x, input_y], fetch=[avg_cost])
         fleet.fleet.save(dirname="/tmp")
 
+        fleet.load_model(path="/tmp", mode=0)
+
         self.assertRaises(
             Exception,
             fleet.save_inference_model,
-- 
GitLab
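
Usage sketch (illustration only, not part of the patch): the save/load round trip exercised by test_fleet_base_2.py above, written out as a worker-side script. It assumes a parameter-server fleet job has already been set up as in that test (fleet.init with a PS role, a distributed optimizer, pservers and trainers launched); "/tmp" and mode=0 are simply the placeholder values the test uses.

    import paddle.distributed.fleet as fleet

    # Bring up the communicator on the trainer so that save() can pull the
    # latest dense parameters from the pservers (Communicator::PullDense).
    fleet.init_worker()

    # Save: dense vars are fetched via communicator.pull_dense() and written
    # with paddle.save(); sparse tables are written out on the pserver side.
    fleet.fleet.save(dirname="/tmp")

    # Load: fleet.load_model() now routes through
    # TheOnePSRuntime._load_distributed_persistables, which reloads the sparse
    # tables on the servers (load_one_table), restores dense vars on the
    # trainer with paddle.load(), and pushes them back with init_params().
    fleet.load_model(path="/tmp", mode=0)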
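
The directory layout that CommonSparseTable::load now resolves on its own (the string::Sprintf calls above) can be summarized by the small helper below. This is an illustration only: sparse_shard_files is a hypothetical helper, and pserver_save_suffix stands in for the PSERVER_SAVE_SUFFIX constant defined elsewhere in the Paddle sources; its value is not shown in this patch.

    import os

    def sparse_shard_files(dirname, table_name, shard_idx, pserver_save_suffix):
        # <dirname>/<table_name><PSERVER_SAVE_SUFFIX>/ is the per-table store.
        var_store = os.path.join(dirname, table_name + pserver_save_suffix)
        # Each shard reads <table_name>.block<shard_idx>.txt for the values
        # and <table_name>.block<shard_idx>.meta for the metadata.
        shard_prefix = "{}.block{}".format(table_name, shard_idx)
        value_path = os.path.join(var_store, shard_prefix + ".txt")
        meta_path = os.path.join(var_store, shard_prefix + ".meta")
        return value_path, meta_path

This is why the Python side now passes only the save directory to self._server.load_sparse() and self._worker.load_one_table(): each table shard derives its own value and meta file names from its table name and shard index.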