From 7be79231e17b677f0925397e5a0663bcdd1bfe6e Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Mon, 18 Dec 2017 20:49:00 +0800
Subject: [PATCH] wip multi-trainer

---
 paddle/operators/detail/send_impl.cc          |  6 +++
 paddle/operators/detail/send_recv_impl.h      |  1 +
 paddle/operators/recv_op.cc                   |  5 ++-
 paddle/operators/send_op.cc                   | 42 ++++++++++---------
 .../paddle/v2/fluid/distribute_transpiler.py  | 22 ++++++----
 5 files changed, 47 insertions(+), 29 deletions(-)

diff --git a/paddle/operators/detail/send_impl.cc b/paddle/operators/detail/send_impl.cc
index 7555cc63fb2..d7165e13db9 100644
--- a/paddle/operators/detail/send_impl.cc
+++ b/paddle/operators/detail/send_impl.cc
@@ -66,6 +66,12 @@ bool RPCClient::GetVariable(const framework::Scope& scope,
   return true;
 }
 
+void RPCClient::Wait() {
+  ClientContext context;
+  VoidMessage call_msg, ret_msg;
+  stub_->Wait(&context, call_msg, &ret_msg);
+}
+
 }  // namespace detail
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/detail/send_recv_impl.h b/paddle/operators/detail/send_recv_impl.h
index 6edbb2d8348..82ab3ab6892 100644
--- a/paddle/operators/detail/send_recv_impl.h
+++ b/paddle/operators/detail/send_recv_impl.h
@@ -81,6 +81,7 @@ class RPCClient {
 
   bool SendVariable(const framework::Scope &scope, const std::string &inname);
   bool GetVariable(const framework::Scope &scope, const std::string &outname);
+  void Wait();
 
  private:
   std::unique_ptr<SendRecvService::Stub> stub_;
diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc
index 9af8d311d92..6fcb544b5b3 100644
--- a/paddle/operators/recv_op.cc
+++ b/paddle/operators/recv_op.cc
@@ -76,14 +76,14 @@ class RecvOp : public framework::OperatorBase {
            const platform::DeviceContext &dev_ctx) const override {
     // FIXME(typhoonzero): no new scopes for every run.
     framework::Scope &recv_scope = scope.NewScope();
-    rpc_service_.SetScope(&recv_scope);
+    rpc_service_->SetScope(&recv_scope);
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
     auto trainer_count = Attr<int>("Trainers");
     size_t param_count = param_list.size();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     while (true) {
-      rpc_service_.Start();
+      rpc_service_->Start();
       // Get from multiple trainers, we don't care about order in which
       // the gradient arrives, just add suffix 0~n then average the gradient.
       for (size_t i = 0; i < param_count * trainer_count; ++i) {
@@ -126,6 +126,7 @@ class RecvOp : public framework::OperatorBase {
     } catch (std::exception &e) {
       LOG(ERROR) << "run sub program error " << e.what();
     }
+    rpc_service_->Done();
 
     // for (size_t i = 0; i < param_count; ++i) {
     //   auto *out_var = recv_scope.FindVar(param_list[i]);
diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc
index 3fcd2144f96..e94209ec44f 100644
--- a/paddle/operators/send_op.cc
+++ b/paddle/operators/send_op.cc
@@ -34,34 +34,36 @@ class SendOp : public framework::OperatorBase {
          const framework::AttributeMap &attrs)
       : OperatorBase(type, inputs, outputs, attrs) {
     // init client when the operator is created at runtime.
-    if (!client_) {
-      std::string endpoint = Attr<std::string>("endpoint");
-      client_.reset(new detail::RPCClient(
-          grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials())));
-      // TODO(typhoonzero): how to call InitVariables
+    std::vector<std::string> endpoints =
+        Attr<std::vector<std::string>>("endpoints");
+    for (auto ep : endpoints) {
+      client_map_[ep].reset(new detail::RPCClient(
+          grpc::CreateChannel(ep, grpc::InsecureChannelCredentials())));
     }
   }
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
     auto ins = Inputs("X");
-    // TODO(typhoonzero): currently it's non-blocking,
-    // should block until server responds.
-    for (auto in : ins) {
-      bool ret = client_->SendVariable(scope, in);
+    std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
+    // TODO(typhoonzero): use async calls to send multiple variable asyncly.
+    for (size_t i = 0; i < ins.size(); ++i) {
+      bool ret = client_map_[epmap[i]]->SendVariable(scope, ins[i]);
       if (!ret) {
-        LOG(ERROR) << "send variable error";
+        LOG(ERROR) << "send variable error: " << ins[i];
       }
     }
-    for (auto in : ins) {
-      bool ret = client_->GetVariable(scope);
+    client_map_[0]->Wait();  // TODO(typhoonzero): support async optimization
+    for (size_t i = 0; i < ins.size(); ++i) {
+      bool ret = client_map_[epmap[i]]->GetVariable(scope, ins[i]);
       if (!ret) {
-        LOG(ERROR) << "GetVariable error";
+        LOG(ERROR) << "GetVariable error: " << ins[i];
       }
     }
   }
 
  protected:
-  std::shared_ptr<detail::RPCClient> client_{nullptr};
+  mutable std::unordered_map<std::string, std::shared_ptr<detail::RPCClient>>
+      client_map_;
 };
 
 class SendOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -74,11 +76,13 @@
 Recv operator
 
 This operator will recv tensor from send_op
 )DOC");
-    AddAttr<std::string>("endpoint",
-                         "(string, default 127.0.0.1:6164)"
-                         "IP address to listen on.")
-        .SetDefault("127.0.0.1:6164")
-        .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
+    AddAttr<std::vector<std::string>>("endpoints",
+                                      "(string vector, default 127.0.0.1:6164)"
+                                      "Server endpoints to send variables to.");
+    AddAttr<std::vector<std::string>>("epmap",
+                                      "(string vector, default 127.0.0.1:6164)"
+                                      "Server endpoints in the order of input "
+                                      "variables for mapping");
   }
 };
diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 13006bfd137..e40cdc92b5c 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -145,14 +145,20 @@ class DistributeTranspiler:
         pserver_endpoints = kwargs["pservers"].split(",")
         self.param_grad_map = split_method(params_and_grads, pserver_endpoints)
 
-        for ep in pserver_endpoints:
-            # FIXME(typhoonzero): send to different servers can run in parrallel.
-            send_op = program.global_block().append_op(
-                type="send",
-                inputs={"X": self.param_grad_map[ep]["grads"]
-                        },  # inputs is a list of tensors to be send
-                outputs={},
-                attrs={"endpoint": ep})
+        send_op_ordered_inputs = []
+        epmap = []
+        for ep, v in self.param_grad_map.iteritems():
+            send_op_ordered_inputs.extend(v["grads"])
+            for i in v:
+                epmap.append(ep)
+
+        send_op = program.global_block().append_op(
+            type="send",
+            inputs={"X": send_op_ordered_inputs
+                    },  # inputs is a list of tensors to be send
+            outputs={},
+            attrs={"endpoints": pserver_endpoints,
+                   "epmap": epmap})
 
     def _create_var_for_trainers(self, block, var, trainers):
         var_list = []
--
GitLab