From be0c4823048cc88927c4fffb151785b3b1940e60 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 25 Mar 2019 09:30:23 +0800 Subject: [PATCH] update trainer_id --- .../fluid/framework/details/async_ssa_graph_executor.cc | 9 +++++++-- paddle/fluid/operators/distributed/parameter_recv.cc | 4 ++-- paddle/fluid/operators/distributed/parameter_send.cc | 2 +- paddle/fluid/operators/distributed/rpc_common.h | 7 +++++-- paddle/fluid/operators/distributed_ops/recv_op.cc | 7 ++++--- paddle/fluid/operators/distributed_ops/send_op.cc | 6 +++--- 6 files changed, 22 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 12822c64e9f..49c1c0a296b 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -60,9 +60,12 @@ void ProcessGraph(std::vector graphs, Scope *scope) { node->Op()->GetNullableAttr("epmap")); auto height_section = boost::get>( node->Op()->GetNullableAttr("sections")); + auto trainer_id = + boost::get(node->Op()->GetNullableAttr("trainer_id")); send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(send_var_name, send_varnames, - epmap, height_section); + epmap, height_section, + trainer_id); VLOG(3) << "find and init an send op: " << send_varname_to_ctx[send_var_name]; } else if (node->Name() == "recv") { @@ -71,9 +74,11 @@ void ProcessGraph(std::vector graphs, Scope *scope) { node->Op()->GetNullableAttr("recv_varnames")); auto epmap = boost::get>( node->Op()->GetNullableAttr("epmap")); + auto trainer_id = + boost::get(node->Op()->GetNullableAttr("trainer_id")); recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(recv_var_name, recv_varnames, - epmap, {}); + epmap, {}, trainer_id); nodes_to_delete.push_back(node); VLOG(3) << "find and remove an recv op: " << recv_varname_to_ctx[recv_var_name]; diff --git 
a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc index 7e44bfc82ee..27908aa4683 100644 --- a/paddle/fluid/operators/distributed/parameter_recv.cc +++ b/paddle/fluid/operators/distributed/parameter_recv.cc @@ -48,7 +48,7 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, auto &cpu_ctx = *pool.Get(platform::CPUPlace()); distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(0); + distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); auto *recv_var = scope.FindVar(rpc_ctx.var_name); @@ -112,7 +112,7 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, // FIXME(qiao): use a trick to avoid the bug of recv an selected rows for (auto i = 1; i < recv_slr.rows().size(); ++i) { auto row_id = recv_slr.rows()[i] + row_offset; - PADDLE_ENFORCE_LT(row_id, recv_dims[1]); + PADDLE_ENFORCE_LT(row_id, recv_dims[0]); memcpy(recv_tensor->data() + row_id * width, recv_slr.value().data() + i * width, sizeof(T) * width); } diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc index ec2884c2529..a8cebca8d9c 100644 --- a/paddle/fluid/operators/distributed/parameter_send.cc +++ b/paddle/fluid/operators/distributed/parameter_send.cc @@ -46,7 +46,7 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, auto &cpu_ctx = *pool.Get(platform::CPUPlace()); distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(0); + distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); auto *send_var = scope.FindVar(rpc_ctx.var_name); size_t out_num = rpc_ctx.splited_var_names.size(); diff --git a/paddle/fluid/operators/distributed/rpc_common.h b/paddle/fluid/operators/distributed/rpc_common.h index 3de89c2ae89..eb127bf4ad5 100644 --- a/paddle/fluid/operators/distributed/rpc_common.h +++ b/paddle/fluid/operators/distributed/rpc_common.h @@ -27,23 +27,26 @@ struct RpcContext { RpcContext(const std::string &name, const 
std::vector &names, const std::vector &emap, - const std::vector &sections) + const std::vector &sections, int id) : var_name(name), splited_var_names(names), epmap(emap), - height_sections(sections) {} + height_sections(sections), + trainer_id(id) {} RpcContext(const RpcContext &ctx) { var_name = ctx.var_name; splited_var_names = ctx.splited_var_names; epmap = ctx.epmap; height_sections = ctx.height_sections; + trainer_id = ctx.trainer_id; } std::string var_name; std::vector splited_var_names; std::vector epmap; std::vector height_sections; + int trainer_id; }; inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) { diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc index 3fd0700a077..8e9846b1fc8 100644 --- a/paddle/fluid/operators/distributed_ops/recv_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_op.cc @@ -50,17 +50,18 @@ class RecvOp : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &ctx = *pool.Get(place); + auto trainer_id = Attr("trainer_id"); distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance( - Attr("trainer_id")); + distributed::RPCClient::GetInstance(trainer_id); std::vector recv_varnames = Attr>("recv_varnames"); if (recv_varnames.size() > 0) { auto recv_functor = distributed::ParameterRecv(); - auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {}); + auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {}, + trainer_id); recv_functor(rpc_ctx, scope); } else { if (with_barrier) { diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index 67de7b4185b..1d5a9b1b22e 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase { auto epmap = Attr>("epmap"); int
sync_send = Attr("sync_mode"); + auto trainer_id = Attr("trainer_id"); auto send_varnames = Attr>("send_varnames"); auto height_sections = Attr>("sections"); @@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase { /* auto send_functor = distributed::ParameterSend(); auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, - height_sections); + height_sections, trainer_id); send_functor(rpc_ctx, scope, static_cast(sync_send)); */ VLOG(3) << "send " << ins[0]; @@ -63,8 +64,7 @@ class SendOp : public framework::OperatorBase { auto& ctx = *pool.Get(place); distributed::RPCClient* rpc_client = - distributed::RPCClient::GetInstance( - Attr("trainer_id")); + distributed::RPCClient::GetInstance(trainer_id); std::vector rets; for (size_t i = 0; i < ins.size(); i++) { -- GitLab