From 40d0fff2e55b795690ef93cb539e8c3a029b7b16 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Thu, 14 Dec 2017 12:24:25 +0800
Subject: [PATCH] single pserver workable version

---
 paddle/operators/recv_op.cc        | 72 ++++++++++++++++--------------
 python/paddle/v2/fluid/executor.py |  2 +-
 2 files changed, 39 insertions(+), 35 deletions(-)

diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc
index 754338ec6bd..a0c25a25eb1 100644
--- a/paddle/operators/recv_op.cc
+++ b/paddle/operators/recv_op.cc
@@ -69,43 +69,47 @@ class RecvOp : public framework::OperatorBase {
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
     size_t param_count = param_list.size();
-    for (size_t i = 0; i < param_count; ++i) {
-      // blocking get one var from client.
-      const detail::TensorWithName &v = rpc_service_->Get();
-      auto grad_var_name = v.first;
-      auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
-      std::string param_var_name;
-      if (it != grad_list.end()) {
-        param_var_name = param_list[it - grad_list.begin()];
+    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
+    while (true) {
+      // TODO(typhoonzero): get from multiple trainers.
+      for (size_t i = 0; i < param_count; ++i) {
+        // blocking get one var from client.
+        const detail::TensorWithName &v = rpc_service_->Get();
+        auto grad_var_name = v.first;
+        auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
+        std::string param_var_name;
+        if (it != grad_list.end()) {
+          param_var_name = param_list[it - grad_list.begin()];
+        }
+        VLOG(10) << "recved grad: " << grad_var_name
+                 << " updating param: " << param_var_name;
+        auto *var = recv_scope.Var(grad_var_name);
+        auto *tensor = var->GetMutable<framework::LoDTensor>();
+        // FIXME(typhoonzero): do not copy
+        framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
       }
-      VLOG(10) << "recved grad: " << grad_var_name
-               << " updating param: " << param_var_name;
-      auto *var = recv_scope.Var(grad_var_name);
-      auto *tensor = var->GetMutable<framework::LoDTensor>();
-      // FIXME(typhoonzero): do not copy
-      framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
-    }
 
-    std::string program_str = Attr<std::string>("OptimizeProgram");
-    framework::ProgramDesc program_desc;
-    program_desc.ParseFromString(program_str);
-    framework::ProgramDescBind program(program_desc);
-    framework::Executor executor(dev_ctx);
-    // Run sub graph to get optimized tensor
-    try {
-      executor.Run(program, &recv_scope, 0, /*global_block*/
-                   false /*create_local_scope*/, false /*create_vars*/);
-    } catch (std::exception &e) {
-      LOG(ERROR) << "run sub program error " << e.what();
-    }
+      std::string program_str = Attr<std::string>("OptimizeProgram");
+      framework::ProgramDesc program_desc;
+      program_desc.ParseFromString(program_str);
+      framework::ProgramDescBind program(program_desc);
+      framework::Executor executor(dev_ctx);
+      // Run sub graph to get optimized tensor
+      try {
+        executor.Run(program, &recv_scope, 0, /*global_block*/
+                     false /*create_local_scope*/, false /*create_vars*/);
+      } catch (std::exception &e) {
+        LOG(ERROR) << "run sub program error " << e.what();
+      }
 
-    for (size_t i = 0; i < param_count; ++i) {
-      auto *out_var = recv_scope.FindVar(param_list[i]);
-      detail::TensorWithName out;
-      out.first = param_list[i];
-      out.second = out_var->Get<framework::LoDTensor>();
-      rpc_service_->Push(out);
-    }
+      for (size_t i = 0; i < param_count; ++i) {
+        auto *out_var = recv_scope.FindVar(param_list[i]);
+        detail::TensorWithName out;
+        out.first = param_list[i];
+        out.second = out_var->Get<framework::LoDTensor>();
+        rpc_service_->Push(out);
+      }
+    }  // while(true)
   }
 
  protected:
diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py
index c8c9a4ef366..4d245250e89 100644
--- a/python/paddle/v2/fluid/executor.py
+++ b/python/paddle/v2/fluid/executor.py
@@ -93,7 +93,7 @@ class Executor(object):
             dtype=var.dtype,
             type=var.type,
             lod_level=var.lod_level,
-            persistable=True)
+            persistable=var.persistable)
 
     def _optimize_distributed(self, optimize_ops, program, params_and_grads,
                               **kwargs):
--
GitLab
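
Note (not part of the patch): the recv_op.cc hunk above keeps the parameter server alive across mini-batches instead of exiting after one pass: block until a gradient arrives, run the attached OptimizeProgram over the received scope, then push the updated parameters back to the trainer. Below is a minimal, self-contained C++ sketch of that receive -> update -> push cycle with the RPC service, scope, and tensors stubbed out; every name in it (StubRpcService, the float "tensors", the SGD step) is illustrative only and not a Paddle API.

// receive_update_push_sketch.cc -- illustrative stand-ins, not Paddle APIs.
#include <algorithm>
#include <condition_variable>
#include <iostream>
#include <map>
#include <mutex>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// Stand-in for detail::TensorWithName: a variable name plus a "tensor" (one float).
using TensorWithName = std::pair<std::string, float>;

// Stand-in for the blocking queue behind rpc_service_->Get() / Push().
class StubRpcService {
 public:
  void Push(const TensorWithName &v) {
    std::lock_guard<std::mutex> lock(mu_);
    queue_.push(v);
    cv_.notify_one();
  }
  TensorWithName Get() {  // blocks until a message is available
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    TensorWithName v = queue_.front();
    queue_.pop();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<TensorWithName> queue_;
};

int main() {
  // ParamList / GradList, as passed to RecvOp via attributes.
  std::vector<std::string> param_list = {"w"};
  std::vector<std::string> grad_list = {"w@GRAD"};
  std::map<std::string, float> params = {{"w", 1.0f}};

  StubRpcService grads_in, params_out;
  grads_in.Push({"w@GRAD", 0.5f});  // pretend one trainer sent a gradient

  // One pass of the while(true) body: receive every gradient, map it to its
  // parameter, "optimize", then push every parameter back to the trainer.
  for (size_t i = 0; i < param_list.size(); ++i) {
    TensorWithName v = grads_in.Get();  // blocking receive
    auto it = std::find(grad_list.begin(), grad_list.end(), v.first);
    std::string param_name = param_list[it - grad_list.begin()];
    params[param_name] -= 0.01f * v.second;  // stand-in for OptimizeProgram (SGD step)
  }
  for (const std::string &name : param_list) {
    params_out.Push({name, params[name]});  // send the updated parameter back
  }
  std::cout << "updated w = " << params["w"] << std::endl;
  return 0;
}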