diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc
index 5a4db2d7e686ce84abef620f890be8f3aa82cb73..521760228b5d7737a29151298c20c93429081a53 100644
--- a/paddle/operators/detail/grpc_client.cc
+++ b/paddle/operators/detail/grpc_client.cc
@@ -63,9 +63,6 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   sendrecv::VariableMessage req;
   req.set_varname(var_name);
 
-  auto* var = scope.FindVar(var_name);
-  SerializeToMessage(var_name, var, ctx, &req);
-
   // varhandle
   VarHandle var_h;
   var_h.ep = ep;
@@ -87,7 +84,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   return true;
 }
 
-bool RPCClient::wait() {
+bool RPCClient::Wait() {
   bool ok = true;
 
   while (true) {
diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h
index d27b5ced9ece67f9b9da3b7f87ec231477603580..a62e70a2533ae52d84d010504b19fed5aeb15dc0 100644
--- a/paddle/operators/detail/grpc_client.h
+++ b/paddle/operators/detail/grpc_client.h
@@ -130,7 +130,7 @@ class RPCClient {
                         const framework::Scope& scope,
                         const std::string& var_name,
                         int64_t time_out = 600 * 1000);
-  bool wait();
+  bool Wait();
 
  private:
   bool Proceed();
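Note: wait() becomes Wait() to match the naming of the other public RPCClient
methods (AsyncSendVariable, AsyncGetVariable). The FindVar/SerializeToMessage
pair is dropped from AsyncGetVariable because a get request only needs the
variable name; the server supplies the payload. The client is built to queue
many asynchronous requests and then block once. A minimal usage sketch,
assuming a Paddle build, a scope that already holds the variables, and an
illustrative endpoint address:

    #include "paddle/operators/detail/grpc_client.h"

    // Queue all sends first; each Async* call returns immediately, then
    // Wait() acts as a single barrier for every in-flight RPC.
    void SendAll(detail::RPCClient* client, const framework::Scope& scope,
                 const std::vector<std::string>& var_names) {
      auto ctx = platform::CPUDeviceContext();
      for (const auto& name : var_names) {
        client->AsyncSendVariable("127.0.0.1:6174", ctx, scope, name);
      }
      client->Wait();
    }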
diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc
index 55b33343af43802e1b6b95a32603bfee806c9764..dea7db391cf5635e1148ab40356d0851b67753a7 100644
--- a/paddle/operators/recv_op.cc
+++ b/paddle/operators/recv_op.cc
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/operators/detail/grpc_server.h"
 #include "paddle/operators/detail/sendrecvop_utils.h"
 #include "paddle/operators/detail/simple_block_queue.h"
+#include "paddle/string/printf.h"
 
 #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
 
@@ -77,35 +78,37 @@ class RecvOp : public framework::OperatorBase {
     if (grads_counter_.find(varname) == grads_counter_.end()) {
       grads_counter_[varname] = 0;
     }
-    char ret[256];
-    snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
-             grads_counter_[varname]++);
-    return std::string(ret);
+    return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
   }
 
   void Run(const framework::Scope &scope,
            const platform::Place &dev_place) const override {
-    // FIXME(typhoonzero): no new scopes for every run.
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(dev_place);
     framework::Scope &recv_scope = scope.NewScope();
     rpc_service_->SetScope(&recv_scope);
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
-    auto trainer_count = Attr<int>("Trainers");
+    auto fan_in = Attr<int>("Fanin");
     size_t param_count = param_list.size();
+
+    std::string program_str = Attr<std::string>("OptimizeProgram");
+    framework::proto::ProgramDesc program_desc;
+    program_desc.ParseFromString(program_str);
+    framework::ProgramDesc program(program_desc);
+    framework::Executor executor(dev_place);
+
     rpc_service_->Reset();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
     while (!exit_flag) {
-      // TODO(gognwb): simply this loop.
-      // Get from multiple trainers, we don't care about order in which
-      // the gradient arrives, just add suffix 0~n then average the gradient.
-      for (size_t i = 0; i < param_count * trainer_count; ++i) {
-        // blocking get one var from client.
+      // Get from multiple trainers, we don't care about the order in which
+      // the gradients arrive, just add suffix 0~n and merge the gradient.
+      for (size_t i = 0; i < param_count * fan_in; ++i) {
        const detail::MessageWithName &v = rpc_service_->Get();
        auto grad_var_name = v.first;
        if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
-          VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit";
+          LOG(INFO) << "received terminate message and exit";
          exit_flag = true;
          break;
        }
@@ -114,44 +117,27 @@ class RecvOp : public framework::OperatorBase {
       if (it != grad_list.end()) {
         param_var_name = param_list[it - grad_list.begin()];
       } else {
-        LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name
-                   << "\"";
+        LOG(ERROR) << "grad has no paired param: " << grad_var_name;
       }
       VLOG(3) << "recved grad: " << grad_var_name
               << " updating param: " << param_var_name;
-
-      auto *merged_grad = recv_scope.FindVar(grad_var_name);
-      if (merged_grad == nullptr) {
-        auto *ptr = recv_scope.Var(grad_var_name);
-        CreateTensorFromMessageType(ptr, v.second.type());
-        VLOG(3) << "Create Variable " << grad_var_name
-                << " on recv scope, which pointer is " << ptr << " type is "
-                << v.second.type();
+      // Assume grad_var_name must appear in the global scope.
+      std::string grad_var_name_trainer = grad_var_name;
+      if (fan_in > 1) {
+        grad_var_name_trainer = this->GetGradVarNameForTrainer(grad_var_name);
       }
-
-      if (trainer_count > 1) {
-        grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
+      auto *var = recv_scope.FindVar(grad_var_name_trainer);
+      if (var == nullptr) {
+        LOG(ERROR) << "can not find server side var: "
+                   << grad_var_name_trainer;
+        PADDLE_THROW("can not find server side var");
       }
-
-      auto *var = recv_scope.Var(grad_var_name);
-      platform::DeviceContextPool &pool =
-          platform::DeviceContextPool::Instance();
-      auto &dev_ctx = *pool.Get(dev_place);
       detail::DeserializeFromMessage(v.second, dev_ctx, var);
     }
-
     if (exit_flag) {
       break;
     }
-
-    rpc_service_->Reset();
-
-    std::string program_str = Attr<std::string>("OptimizeProgram");
-    framework::proto::ProgramDesc program_desc;
-    program_desc.ParseFromString(program_str);
-    framework::ProgramDesc program(program_desc);
-    framework::Executor executor(dev_place);
-    // Run sub graph to get optimized tensor
     try {
       executor.Run(program, &recv_scope, 0, /*global_block*/
                    false /*create_local_scope*/, false /*create_vars*/);
@@ -195,7 +181,7 @@ This operator will recv tensor from send_op
         "GradList", "type list of string",
         "grad->param name mapping to find which param to optimize.")
         .SetDefault({});
-    AddAttr<int>("Trainers", "type int",
+    AddAttr<int>("Fanin", "type int",
                  "Number of trainers in the current cluster job")
         .SetDefault(1);
   }
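Note: with Fanin (the number of trainers pushing to one pserver) greater than
1, every arriving gradient is stored under a per-trainer name before merging,
and the transpiler pre-creates the matching variables (see
distribute_transpiler.py below). A standalone sketch of the suffixing scheme,
using snprintf in place of string::Sprintf so it compiles outside the Paddle
tree:

    #include <cstdio>
    #include <map>
    #include <string>

    // Mirrors RecvOp::grads_counter_: one running counter per gradient name.
    std::map<std::string, int> grads_counter;

    std::string GradVarNameForTrainer(const std::string& varname) {
      char ret[256];
      std::snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
                    grads_counter[varname]++);
      return std::string(ret);
    }

    int main() {
      // With Fanin == 2, two arrivals of the same gradient get distinct
      // names, matching the variables the transpiler creates up front:
      std::printf("%s\n", GradVarNameForTrainer("w@GRAD").c_str());  // w@GRAD.trainer_0
      std::printf("%s\n", GradVarNameForTrainer("w@GRAD").c_str());  // w@GRAD.trainer_1
      return 0;
    }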
diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc
index 4d145250bdc73607c8817e20fdb753f4c96e2391..d65153c1fdb5bd67f19b349a5cd9c4d83a9e5f09 100644
--- a/paddle/operators/send_op.cc
+++ b/paddle/operators/send_op.cc
@@ -41,14 +41,16 @@ class SendOp : public framework::OperatorBase {
     // FIXME(gongwb): DeviceContext?
     auto ctx = platform::CPUDeviceContext();
     for (size_t i = 0; i < ins.size(); i++) {
+      VLOG(3) << "sending " << ins[i];
       client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
     }
+    client_.Wait();
 
     for (size_t i = 0; i < outs.size(); i++) {
+      VLOG(3) << "getting " << outs[i];
       client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
     }
-
-    client_.wait();
+    client_.Wait();
   }
 
  private:
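Note: the Wait() between the two loops is the substantive change here:
parameter fetches are no longer issued until every gradient push has
completed, and the final Wait() keeps SendOp from returning before the fetched
parameters are in scope. Keeping Wait() outside the loops preserves the
overlap between in-flight RPCs; a sketch of the contrast, reusing the names
from SendOp::Run above:

    // Anti-pattern: one round trip at a time.
    for (size_t i = 0; i < ins.size(); i++) {
      client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
      client_.Wait();
    }

    // Pattern used here: all sends go out concurrently, then one barrier.
    for (size_t i = 0; i < ins.size(); i++) {
      client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
    }
    client_.Wait();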
diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 00fe3e68c9008603bff3b2e0dd04e3e286d8c4b6..9876296a37ae1a72f482064648e4d9d1d1bd6412 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -452,6 +452,19 @@ class DistributeTranspiler:
         pserver_program = Program()
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
+        for v in self.param_grad_ep_mapping[endpoint]["grads"]:
+            # create vars for each trainer in global scope, so
+            # we don't need to create them when grad arrives.
+            pserver_program.global_block().create_var(
+                name=v.name, persistable=True, dtype=v.dtype, shape=v.shape)
+            for trainer_id in xrange(self.trainers):
+                print("create variable for program: %s.trainer_%d" %
+                      (v.name, trainer_id))
+                pserver_program.global_block().create_var(
+                    name="%s.trainer_%d" % (v.name, trainer_id),
+                    persistable=True,
+                    dtype=v.dtype,
+                    shape=v.shape)
         # step6
         optimize_sub_program = Program()
         for idx, opt_op in enumerate(optimize_ops):
@@ -481,7 +494,7 @@ class DistributeTranspiler:
                     p.name
                     for p in self.param_grad_ep_mapping[endpoint]["grads"]
                 ],
-                "Trainers": self.trainers
+                "Fanin": self.trainers
             })
         pserver_program.sync_with_cpp()
         return pserver_program
diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
index 20b4a8b34cd085ae51e6169f0d4eac58b7f3ffb2..e563e0ddc5d7998e19813065768bd87189a22ba4 100644
--- a/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
@@ -39,26 +39,27 @@ train_reader = paddle.batch(
 place = fluid.CPUPlace()
 exe = fluid.Executor(place)
 
-t = fluid.DistributeTranspiler()
-# all parameter server endpoints list for spliting parameters
-pserver_endpoints = os.getenv("PSERVERS")
-# server endpoint for current node
-current_endpoint = os.getenv("SERVER_ENDPOINT")
-# run as trainer or parameter server
+pserver_endpoints = os.getenv("PSERVERS")  # all pserver endpoints
+trainers = int(os.getenv("TRAINERS"))  # total trainer count
+current_endpoint = os.getenv("SERVER_ENDPOINT")  # current pserver endpoint
 training_role = os.getenv("TRAINING_ROLE",
                           "TRAINER")  # get the training role: trainer/pserver
-t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
+t = fluid.DistributeTranspiler()
+t.transpile(
+    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)
 
 if training_role == "PSERVER":
     if not current_endpoint:
         print("need env SERVER_ENDPOINT")
         exit(1)
     pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
-    exe.run(fluid.default_startup_program())
+    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
+    exe.run(pserver_startup)
     exe.run(pserver_prog)
 elif training_role == "TRAINER":
     trainer_prog = t.get_trainer_program()
     feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
+    # TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
     exe.run(fluid.default_startup_program())
 
     for pass_id in range(PASS_NUM):