diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc
index 8cee46cbb2d6a1002864916e250fb7ab30f91430..95f4738b4ff50852d9591719133ca650533bf848 100644
--- a/paddle/fluid/operators/detail/grpc_server.cc
+++ b/paddle/fluid/operators/detail/grpc_server.cc
@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
 class RequestBase {
  public:
   explicit RequestBase(GrpcService::AsyncService* service,
-                       ::grpc::ServerCompletionQueue* cq,
+                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
                        const platform::DeviceContext* dev_ctx)
-      : service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
+      : service_(service),
+        cq_(cq),
+        sync_mode_(sync_mode),
+        status_(PROCESS),
+        dev_ctx_(dev_ctx) {
     PADDLE_ENFORCE(cq_);
   }
   virtual ~RequestBase() {}
@@ -49,6 +53,7 @@ class RequestBase {
   ::grpc::ServerContext ctx_;
   GrpcService::AsyncService* service_;
   ::grpc::ServerCompletionQueue* cq_;
+  const bool sync_mode_;
   CallStatus status_;
   const platform::DeviceContext* dev_ctx_;
 };
@@ -56,11 +61,17 @@ class RequestBase {
 class RequestSend final : public RequestBase {
  public:
   explicit RequestSend(GrpcService::AsyncService* service,
-                       ::grpc::ServerCompletionQueue* cq,
+                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
                        framework::Scope* scope, ReceivedQueue* queue,
                        const platform::DeviceContext* dev_ctx)
-      : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
-    request_.reset(new VariableResponse(scope, dev_ctx_));
+      : RequestBase(service, cq, sync_mode, dev_ctx),
+        queue_(queue),
+        responder_(&ctx_) {
+    if (sync_mode_) {
+      request_.reset(new VariableResponse(scope, dev_ctx_, false));
+    } else {
+      request_.reset(new VariableResponse(scope, dev_ctx_, true));
+    }
     int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
     service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
                                 cq_, cq_, this);
@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
 class RequestGet final : public RequestBase {
  public:
   explicit RequestGet(GrpcService::AsyncService* service,
-                      ::grpc::ServerCompletionQueue* cq,
+                      ::grpc::ServerCompletionQueue* cq, bool sync_mode,
                       framework::Scope* scope,
                       const platform::DeviceContext* dev_ctx,
                       framework::BlockingQueue<MessageWithName>* queue)
-      : RequestBase(service, cq, dev_ctx),
+      : RequestBase(service, cq, sync_mode, dev_ctx),
         responder_(&ctx_),
         scope_(scope),
         queue_(queue) {
@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
 class RequestPrefetch final : public RequestBase {
  public:
   explicit RequestPrefetch(GrpcService::AsyncService* service,
-                           ::grpc::ServerCompletionQueue* cq,
+                           ::grpc::ServerCompletionQueue* cq, bool sync_mode,
                            framework::Scope* scope,
                            const platform::DeviceContext* dev_ctx,
                            framework::Executor* executor,
                            framework::ProgramDesc* program,
                            framework::ExecutorPrepareContext* prefetch_ctx)
-      : RequestBase(service, cq, dev_ctx),
+      : RequestBase(service, cq, sync_mode, dev_ctx),
        responder_(&ctx_),
        scope_(scope),
        executor_(executor),
        program_(program),
        prefetch_ctx_(prefetch_ctx) {
-    request_.reset(new VariableResponse(scope, dev_ctx_));
+    if (sync_mode_) {
+      request_.reset(new VariableResponse(scope, dev_ctx_, false));
+    } else {
+      request_.reset(new VariableResponse(scope, dev_ctx_, true));
+    }
     int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
     service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
                                 cq_, cq_, this);
@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
   framework::Executor* executor_;
   framework::ProgramDesc* program_;
   framework::ExecutorPrepareContext* prefetch_ctx_;
-  int blkid_;
 };
 
 void AsyncGRPCServer::WaitClientGet(int count) {
@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
     VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
     return;
   }
-  RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
-                                      &var_recv_queue_, dev_ctx_);
+  RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
+                                      scope_, &var_recv_queue_, dev_ctx_);
   VLOG(4) << "Create RequestSend status:" << send->Status();
 }
 
@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
     VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
     return;
   }
-  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
-                                   &var_get_queue_);
+  RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
+                                   dev_ctx_, &var_get_queue_);
   VLOG(4) << "Create RequestGet status:" << get->Status();
 }
 
@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
     return;
   }
   RequestPrefetch* prefetch =
-      new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
-                          executor_, program_, prefetch_ctx_);
+      new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_,
+                          dev_ctx_, executor_, program_, prefetch_ctx_);
   VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
 }
 
@@ -301,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
     VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
     PADDLE_ENFORCE(tag);
 
-    // FIXME(typhoonzero): de-couple the barriers with recv_op
-    if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
-    if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
+    if (sync_mode_) {
+      // FIXME(typhoonzero): de-couple the barriers with recv_op
+      if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
+      if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
+    }
 
     RequestBase* base = reinterpret_cast<RequestBase*>(tag);
     // reference:
@@ -320,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
     switch (base->Status()) {
       case PROCESS: {
-        VLOG(4) << cq_name << " status:" << base->Status();
+        VLOG(4) << cq_name << " PROCESS status:" << base->Status();
         TryToRegisterNewOne();
         base->Process();
         break;
       }
       case FINISH: {
-        VLOG(4) << cq_name << " status:" << base->Status();
+        VLOG(4) << cq_name << " FINISH status:" << base->Status();
         delete base;
         break;
       }
     }
diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h
index a15c93b7830265a2bb22334b5bb5a0f8ee2f28f4..99b87b8c6cb3e597778b88c395e4abf400d82c39 100644
--- a/paddle/fluid/operators/detail/grpc_server.h
+++ b/paddle/fluid/operators/detail/grpc_server.h
@@ -44,7 +44,8 @@ class RequestBase;
 
 class AsyncGRPCServer final {
  public:
-  explicit AsyncGRPCServer(const std::string &address) : address_(address) {}
+  explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
+      : address_(address), sync_mode_(sync_mode) {}
 
   void RunSyncUpdate();
 
@@ -95,6 +96,7 @@ class AsyncGRPCServer final {
   std::unique_ptr<::grpc::Server> server_;
 
   std::string address_;
+  const bool sync_mode_;
   framework::Scope *scope_;
   const platform::DeviceContext *dev_ctx_;
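The server-side changes above are mostly mechanical: AsyncGRPCServer now takes a sync_mode flag at construction and forwards it to every request handler, and in async mode each incoming VariableResponse deserializes into its own local scope (the new third constructor argument). Below is a minimal standalone sketch of that pattern, assuming simplified Scope and Response stand-ins rather than the real Paddle types; note the two-branch construction in the patch is equivalent to passing !sync_mode_.

// Illustrative sketch only -- not part of the patch above.
#include <iostream>
#include <memory>

struct Scope {};  // stand-in for framework::Scope

struct Response {  // stand-in for VariableResponse
  Response(Scope* scope, bool create_local_scope)
      : scope_(scope), create_local_scope_(create_local_scope) {}
  Scope* scope_;
  bool create_local_scope_;
};

// Mirrors RequestBase: the sync/async decision is fixed at construction time.
class RequestBase {
 public:
  explicit RequestBase(bool sync_mode) : sync_mode_(sync_mode) {}
  virtual ~RequestBase() = default;

 protected:
  const bool sync_mode_;
};

// Mirrors RequestSend: async mode deserializes into a fresh local scope so
// gradients arriving concurrently from different trainers do not overwrite
// each other in the global scope.
class RequestSend : public RequestBase {
 public:
  RequestSend(bool sync_mode, Scope* scope) : RequestBase(sync_mode) {
    request_.reset(new Response(scope, /*create_local_scope=*/!sync_mode_));
  }

 private:
  std::unique_ptr<Response> request_;
};

int main() {
  Scope global_scope;
  RequestSend sync_handler(/*sync_mode=*/true, &global_scope);
  RequestSend async_handler(/*sync_mode=*/false, &global_scope);
  std::cout << "constructed sync and async handlers\n";
  return 0;
}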
diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc
index c51933718f4ca78e87c77e007c485642000d247d..25b95d608d10d6e456d5f563ce9fbe35d812cb0f 100644
--- a/paddle/fluid/operators/detail/grpc_server_test.cc
+++ b/paddle/fluid/operators/detail/grpc_server_test.cc
@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
 }
 
 void StartServer(const std::string& endpoint) {
-  rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
+  rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
   framework::ProgramDesc program;
   framework::Scope scope;
   platform::CPUPlace place;
diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h
index 3018a5c4af876828380ff4c1cbfdaafa8a2057e1..bf624da2a6c26472e47711b3c6409f78afba0a64 100644
--- a/paddle/fluid/operators/detail/variable_response.h
+++ b/paddle/fluid/operators/detail/variable_response.h
@@ -46,7 +46,9 @@ class VariableResponse {
   }
 
   virtual ~VariableResponse() {
-    if (create_scope_) scope_->DeleteScope(local_scope_);
+    if (create_scope_) {
+      scope_->DeleteScope(local_scope_);
+    }
   }
 
   // return:
@@ -63,6 +65,8 @@ class VariableResponse {
 
   const framework::Scope& GetLocalScope() const { return *local_scope_; }
 
+  framework::Scope* GetMutableLocalScope() const { return local_scope_; }
+
   inline std::string Varname() { return meta_.varname(); }
   inline std::string OutVarname() { return meta_.out_varname(); }
diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index af235fb6a029a71ee275bebfbbd75aaa0b7d546d..57cff680ab89f2df7e71af4056ee06cdf330bbab 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -27,6 +27,38 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   VLOG(4) << "RunServer thread end";
 }
 
+static void split(const std::string &str, char sep,
+                  std::vector<std::string> *pieces) {
+  pieces->clear();
+  if (str.empty()) {
+    return;
+  }
+  size_t pos = 0;
+  size_t next = str.find(sep, pos);
+  while (next != std::string::npos) {
+    pieces->push_back(str.substr(pos, next - pos));
+    pos = next + 1;
+    next = str.find(sep, pos);
+  }
+  if (!str.substr(pos).empty()) {
+    pieces->push_back(str.substr(pos));
+  }
+}
+
+static void AsyncExecuteBlock(framework::Executor *executor,
+                              framework::ExecutorPrepareContext *prepared,
+                              framework::Scope *scope) {
+  std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
+    try {
+      executor->RunPreparedContext(prepared, scope, false, false);
+    } catch (std::exception &e) {
+      LOG(ERROR) << "run sub program error " << e.what();
+    }
+  });
+  // TODO(qiao) maybe we can remove this
+  future.wait();
+}
+
 static void ParallelExecuteBlocks(
     const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
     const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -169,15 +201,82 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
   }  // while(true)
 }
 
+void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
+                                   framework::ProgramDesc *program,
+                                   framework::Scope *recv_scope,
+                                   framework::BlockDesc *prefetch_block) const {
+  VLOG(3) << "RunAsyncLoop in";
+  // grad name to block id
+  std::unordered_map<std::string, int> grad_to_block_id;
+  std::unordered_map<int, std::string> id_to_grad;
+
+  auto grad_to_block_id_str =
+      Attr<std::vector<std::string>>("grad_to_block_id");
+  for (auto &grad_and_id : grad_to_block_id_str) {
+    std::vector<std::string> pieces;
+    split(grad_and_id, ':', &pieces);
+    VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
+    PADDLE_ENFORCE_EQ(pieces.size(), 2);
+    PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
+    int block_id = std::stoi(pieces[1]);
+    grad_to_block_id[pieces[0]] = block_id;
+    id_to_grad[block_id] = pieces[0];
+  }
+  size_t num_blocks = program->Size();
blocks"); + + std::vector block_list; + for (size_t blkid = 1; blkid < num_blocks; ++blkid) { + block_list.push_back(blkid); + } + auto optimize_prepared = executor->Prepare(*program, block_list); + std::unordered_map> + grad_to_prepared_ctx; + for (size_t i = 0; i < block_list.size(); ++i) { + grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i]; + } + + VLOG(3) << "RunAsyncLoop into while"; + bool exit_flag = false; + while (!exit_flag) { + const detail::ReceivedMessage v = rpc_service_->Get(); + auto recv_var_name = v.first; + if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { + LOG(INFO) << "received terminate message and exit"; + exit_flag = true; + break; + } else { + VLOG(3) << "received grad: " << recv_var_name; + auto var = v.second->GetVar(); + if (var == nullptr) { + LOG(ERROR) << "Can not find server side var: " << recv_var_name; + PADDLE_THROW("Can not find server side var"); + } + AsyncExecuteBlock(executor, grad_to_prepared_ctx[recv_var_name].get(), + v.second->GetMutableLocalScope()); + } + + if (exit_flag) { + rpc_service_->ShutDown(); + break; + } + } // while(true) +} + void ListenAndServOp::RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); framework::Scope &recv_scope = scope.NewScope(); + bool sync_mode = Attr("sync_mode"); + PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode)); auto *optimize_block = Attr(kOptimizeBlock); auto *prefetch_block = Attr(kPrefetchBlock); @@ -202,7 +301,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, sleep(5); // Write to a file of server selected port for python use. SavePort(rpc_service_); - RunSyncLoop(&executor, program, &recv_scope, prefetch_block); + if (sync_mode) { + RunSyncLoop(&executor, program, &recv_scope, prefetch_block); + } else { + RunAsyncLoop(&executor, program, &recv_scope, prefetch_block); + } } class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { @@ -221,6 +324,12 @@ from send_op and send back variables to recv_op. 
"IP address to listen on.") .SetDefault("127.0.0.1:6164") .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); + AddAttr>( + "grad_to_block_id", + "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] " + "a map from grad name to it's optimize block id") + .SetDefault({}); + AddAttr("sync_mode", "if works at sync_mode or not").SetDefault(true); AddAttr(kOptimizeBlock, "BlockID to run on server side."); AddAttr(kPrefetchBlock, diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index dfb7c77c8e36d9af79d8b1713d0c0c59c81b1ca6..3cc0f3047733bea94daa310cd39cb0a4f44bef85 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase { framework::Scope* recv_scope, framework::BlockDesc* prefetch_block) const; + void RunAsyncLoop(framework::Executor* executor, + framework::ProgramDesc* program, + framework::Scope* recv_scope, + framework::BlockDesc* prefetch_block) const; + void Stop() override; void RunImpl(const framework::Scope& scope, diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 82ff087d0a7a4b482aef842e618f593b17dca171..e4386b640a298cd216bb60104653f20c4a96e7dc 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase { std::vector endpoints = Attr>("endpoints"); + bool sync_mode = Attr("sync_mode"); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); @@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase { } PADDLE_ENFORCE(rpc_client->Wait()); - for (auto& ep : endpoints) { - VLOG(3) << "batch barrier, ep: " << ep; - rpc_client->AsyncSendBatchBarrier(ep); + if (sync_mode) { + for (auto& ep : endpoints) { + VLOG(3) << "batch barrier, ep: " << ep; + rpc_client->AsyncSendBatchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); } - PADDLE_ENFORCE(rpc_client->Wait()); if (outs.size() > 0) { for (size_t i = 0; i < outs.size(); i++) { @@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server. 
"Server endpoints in the order of input " "variables for mapping") .SetDefault({}); + AddAttr("sync_mode", "work in sync_mode or not").SetDefault(true); } }; diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc index 81350fee38df058d1b63eb5a8cd0b770e0626ae4..d2e1f3cb2ff9c8254cd4815a0f8750966a6e161c 100644 --- a/paddle/fluid/operators/send_recv_op_test.cc +++ b/paddle/fluid/operators/send_recv_op_test.cc @@ -137,6 +137,8 @@ void StartServerNet(bool is_sparse) { attrs.insert({"GradList", std::vector({"x1"})}); attrs.insert({"OptimizeBlock", optimize_block}); attrs.insert({"PrefetchBlock", prefetch_block}); + attrs.insert({"grad_to_block_id", std::vector({""})}); + attrs.insert({"sync_mode", true}); listen_and_serv_op = f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); listen_and_serv_op->Run(scope, place); diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index d07e0f696e79cfb98efc09a9f40d7961678b6af4..d17475cd28b3ae57032d3be811542fc89246e299 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -143,7 +143,8 @@ class DistributeTranspiler: program=None, pservers="127.0.0.1:6174", trainers=1, - split_method=splitter.round_robin): + split_method=splitter.round_robin, + sync_mode=True): """ Transpile the program to distributed data-parallelism programs. The main_program will be transformed to use a remote parameter server @@ -184,6 +185,9 @@ class DistributeTranspiler: :param split_method: A function to determin how to split variables to different servers equally. :type split_method: function + :type sync_mode: boolean default True + :param sync_mode: if sync_mode is set True, it means that dist transpiler + will transpile the program into sync_mode pserver and trainer program. """ assert (callable(split_method)) if program is None: @@ -191,6 +195,7 @@ class DistributeTranspiler: self.origin_program = program self.trainer_num = trainers self.optimize_ops = optimize_ops + self.sync_mode = sync_mode # TODO(typhoonzero): currently trainer_id is fetched from cluster system # like Kubernetes, we should port this to use etcd later when developing # fluid distributed training with fault-tolerance. @@ -295,8 +300,11 @@ class DistributeTranspiler: inputs={"X": send_inputs}, outputs={"Out": send_outputs, "RPCClient": rpc_client_var}, - attrs={"endpoints": pserver_endpoints, - "epmap": eplist}) + attrs={ + "endpoints": pserver_endpoints, + "epmap": eplist, + "sync_mode": self.sync_mode + }) # step4: Concat the parameters splits together after recv. 
diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc
index 81350fee38df058d1b63eb5a8cd0b770e0626ae4..d2e1f3cb2ff9c8254cd4815a0f8750966a6e161c 100644
--- a/paddle/fluid/operators/send_recv_op_test.cc
+++ b/paddle/fluid/operators/send_recv_op_test.cc
@@ -137,6 +137,8 @@ void StartServerNet(bool is_sparse) {
   attrs.insert({"GradList", std::vector<std::string>({"x1"})});
   attrs.insert({"OptimizeBlock", optimize_block});
   attrs.insert({"PrefetchBlock", prefetch_block});
+  attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
+  attrs.insert({"sync_mode", true});
   listen_and_serv_op =
       f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
   listen_and_serv_op->Run(scope, place);
diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index d07e0f696e79cfb98efc09a9f40d7961678b6af4..d17475cd28b3ae57032d3be811542fc89246e299 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -143,7 +143,8 @@ class DistributeTranspiler:
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  split_method=splitter.round_robin):
+                  split_method=splitter.round_robin,
+                  sync_mode=True):
        """
        Transpile the program to distributed data-parallelism programs.
        The main_program will be transformed to use a remote parameter server
@@ -184,6 +185,9 @@ class DistributeTranspiler:
        :param split_method: A function to determin how to split variables
           to different servers equally.
        :type split_method: function
+        :type sync_mode: boolean default True
+        :param sync_mode: if sync_mode is set True, it means that dist transpiler
+        will transpile the program into sync_mode pserver and trainer program.
        """
        assert (callable(split_method))
        if program is None:
@@ -191,6 +195,7 @@ class DistributeTranspiler:
        self.origin_program = program
        self.trainer_num = trainers
        self.optimize_ops = optimize_ops
+        self.sync_mode = sync_mode
        # TODO(typhoonzero): currently trainer_id is fetched from cluster system
        # like Kubernetes, we should port this to use etcd later when developing
        # fluid distributed training with fault-tolerance.
@@ -295,8 +300,11 @@ class DistributeTranspiler:
            inputs={"X": send_inputs},
            outputs={"Out": send_outputs,
                     "RPCClient": rpc_client_var},
-            attrs={"endpoints": pserver_endpoints,
-                   "epmap": eplist})
+            attrs={
+                "endpoints": pserver_endpoints,
+                "epmap": eplist,
+                "sync_mode": self.sync_mode
+            })
        # step4: Concat the parameters splits together after recv.
        for varname, splited_var in param_var_mapping.iteritems():
            if len(splited_var) <= 1:
@@ -356,7 +364,7 @@ class DistributeTranspiler:
                type=v.type,
                dtype=v.dtype,
                shape=v.shape)
-            if self.trainer_num > 1:
+            if self.sync_mode and self.trainer_num > 1:
                for trainer_id in xrange(self.trainer_num):
                    var = pserver_program.global_block().create_var(
                        name="%s.trainer_%d" % (orig_var_name, trainer_id),
@@ -402,13 +410,13 @@ class DistributeTranspiler:
        for op in self.optimize_ops:
            if op.type == "scale":
                for in_name in op.input_arg_names:
-                    if in_name.startswith("beta1_pow_acc") or\
-                        in_name.startswith("beta2_pow_acc"):
+                    if in_name.startswith("beta1_pow_acc") or \
+                            in_name.startswith("beta2_pow_acc"):
                        global_ops.append(op)
 
-        def __append_optimize_op__(op, block):
+        def __append_optimize_op__(op, block, grad_to_block_id):
            if self._is_opt_op(op):
-                self._append_pserver_ops(block, op, endpoint,
+                self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                         default_main_program())
            else:
                self._append_pserver_non_opt_ops(block, op)
@@ -422,16 +430,16 @@ class DistributeTranspiler:
                self._append_pserver_non_opt_ops(lr_decay_block, op)
 
        # append op to the current block
+        grad_to_block_id = []
        pre_block_idx = pserver_program.num_blocks - 1
        for idx, opt_op in enumerate(opt_op_on_pserver):
            per_opt_block = pserver_program.create_block(pre_block_idx)
            for _, op in enumerate(self.optimize_ops):
                # optimizer is connected to itself
                if ufind.is_connected(op, opt_op) and op not in global_ops:
-                    __append_optimize_op__(op, per_opt_block)
+                    __append_optimize_op__(op, per_opt_block, grad_to_block_id)
 
        # append global ops
-        opt_state_block = None
        if global_ops:
            opt_state_block = pserver_program.create_block(
                pserver_program.num_blocks - 1)
@@ -472,7 +480,9 @@ class DistributeTranspiler:
                "OptimizeBlock": pserver_program.block(1),
                "endpoint": endpoint,
                "Fanin": self.trainer_num,
-                "PrefetchBlock": prefetch_block
+                "PrefetchBlock": prefetch_block,
+                "sync_mode": self.sync_mode,
+                "grad_to_block_id": grad_to_block_id
            })
 
        pserver_program.sync_with_cpp()
@@ -683,17 +693,6 @@ class DistributeTranspiler:
                self.table_name)],
            persistable=False)
 
-        # create grad vars in pserver program
-        table_grad_var = self.table_param_grad[1]
-        table_grad_list = [
-            pserver_program.global_block().create_var(
-                name="%s.trainer_%d.pserver_%d" %
-                (table_grad_var.name, index, pserver_index),
-                type=table_grad_var.type,
-                shape=table_grad_var.shape,
-                dtype=table_grad_var.dtype) for index in range(self.trainer_num)
-        ]
-
        # create table optimize block in pserver program
        table_opt_op = [
            op for op in self.optimize_ops
@@ -703,11 +702,24 @@
        # only support sgd now
        assert table_opt_op.type == "sgd"
 
-        # append sum op for table_grad_list
-        table_opt_block.append_op(
-            type="sum",
-            inputs={"X": table_grad_list},
-            outputs={"Out": [grad_var]})
+        if self.sync_mode:
+            # create grad vars in pserver program
+            table_grad_var = self.table_param_grad[1]
+            table_grad_list = [
+                pserver_program.global_block().create_var(
+                    name="%s.trainer_%d.pserver_%d" %
+                    (table_grad_var.name, index, pserver_index),
+                    type=table_grad_var.type,
+                    shape=table_grad_var.shape,
+                    dtype=table_grad_var.dtype)
+                for index in range(self.trainer_num)
+            ]
+
+            # append sum op for table_grad_list
+            table_opt_block.append_op(
+                type="sum",
+                inputs={"X": table_grad_list},
+                outputs={"Out": [grad_var]})
 
        lr_var = pserver_program.global_block().vars[table_opt_op.input(
            "LearningRate")[0]]
@@ -746,7 +758,7 @@ class DistributeTranspiler:
        for varname, splited in block_map.iteritems():
            orig_var = program.global_block().var(varname)
            if len(splited) == 1:
-                if add_trainer_suffix:
+                if self.sync_mode and add_trainer_suffix:
                    new_var_name = "%s.trainer_%d" % \
                        (orig_var.name, self.trainer_id)
                    program.global_block().rename_var(varname, new_var_name)
@@ -770,7 +782,7 @@
            if len(orig_shape) >= 2:
                splited_shape.extend(orig_shape[1:])
            new_var_name = ""
-            if add_trainer_suffix:
+            if self.sync_mode and add_trainer_suffix:
                new_var_name = "%s.block%d.trainer_%d" % \
                    (varname, i, self.trainer_id)
            else:
@@ -879,7 +891,7 @@ class DistributeTranspiler:
        return orig_var_name
 
    def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
-                            origin_program):
+                            grad_to_block_id, origin_program):
        program = optimize_block.program
        pserver_block = program.global_block()
        new_inputs = dict()
@@ -900,7 +912,9 @@ class DistributeTranspiler:
                    return
                merged_var = \
                    pserver_block.vars[self._orig_varname(grad_block.name)]
-                if self.trainer_num > 1:
+                grad_to_block_id.append(merged_var.name + ":" + str(
+                    optimize_block.idx))
+                if self.sync_mode and self.trainer_num > 1:
                    vars2merge = []
                    for i in xrange(self.trainer_num):
                        per_trainer_name = "%s.trainer_%d" % \
@@ -918,6 +932,7 @@ class DistributeTranspiler:
                        inputs={"X": merged_var},
                        outputs={"Out": merged_var},
                        attrs={"scale": 1.0 / float(self.trainer_num)})
+                new_inputs[key] = merged_var
            elif key == "Param":
                # param is already created on global program