diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
index 12822c64e9f7ce20ebd9d1ac3c7479396cb7ea2f..49c1c0a296b71bfa4d3ee90a1c5acd0356179fe2 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
@@ -60,9 +60,12 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
               node->Op()->GetNullableAttr("epmap"));
           auto height_section = boost::get<std::vector<int64_t>>(
               node->Op()->GetNullableAttr("sections"));
+          auto trainer_id =
+              boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
           send_varname_to_ctx[send_var_name] =
               operators::distributed::RpcContext(send_var_name, send_varnames,
-                                                 epmap, height_section);
+                                                 epmap, height_section,
+                                                 trainer_id);
           VLOG(3) << "find and init an send op: "
                   << send_varname_to_ctx[send_var_name];
         } else if (node->Name() == "recv") {
@@ -71,9 +74,11 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
               node->Op()->GetNullableAttr("recv_varnames"));
           auto epmap = boost::get<std::vector<std::string>>(
               node->Op()->GetNullableAttr("epmap"));
+          auto trainer_id =
+              boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
           recv_varname_to_ctx[recv_var_name] =
               operators::distributed::RpcContext(recv_var_name, recv_varnames,
-                                                 epmap, {});
+                                                 epmap, {}, trainer_id);
           nodes_to_delete.push_back(node);
           VLOG(3) << "find and remove an recv op: "
                   << recv_varname_to_ctx[recv_var_name];
diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc
index 7e44bfc82eeadee320771387ab518c7345b17acc..27908aa468347d0171e6e46bc5903ad47abcde70 100644
--- a/paddle/fluid/operators/distributed/parameter_recv.cc
+++ b/paddle/fluid/operators/distributed/parameter_recv.cc
@@ -48,7 +48,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
   auto &cpu_ctx = *pool.Get(platform::CPUPlace());
 
   distributed::RPCClient *rpc_client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
 
   auto *recv_var = scope.FindVar(rpc_ctx.var_name);
 
@@ -112,7 +112,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
       // FIXME(qiao): use a trick to avoid the bug of recv an selected rows
       for (auto i = 1; i < recv_slr.rows().size(); ++i) {
         auto row_id = recv_slr.rows()[i] + row_offset;
-        PADDLE_ENFORCE_LT(row_id, recv_dims[1]);
+        PADDLE_ENFORCE_LT(row_id, recv_dims[0]);
         memcpy(recv_tensor->data<T>() + row_id * width,
                recv_slr.value().data<T>() + i * width, sizeof(T) * width);
       }
diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc
index ec2884c25290aa3cfd9818ead61119cc6c6b6feb..a8cebca8d9c53521f4cbda35cd76151b2b48399b 100644
--- a/paddle/fluid/operators/distributed/parameter_send.cc
+++ b/paddle/fluid/operators/distributed/parameter_send.cc
@@ -46,7 +46,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
   auto &cpu_ctx = *pool.Get(platform::CPUPlace());
 
   distributed::RPCClient *rpc_client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
 
   auto *send_var = scope.FindVar(rpc_ctx.var_name);
   size_t out_num = rpc_ctx.splited_var_names.size();
diff --git a/paddle/fluid/operators/distributed/rpc_common.h b/paddle/fluid/operators/distributed/rpc_common.h
index 3de89c2ae89d29edc317ca123882d1c55038b6ca..eb127bf4ad5a5c9a28210e2fbcdb69b07543f4b9 100644
--- a/paddle/fluid/operators/distributed/rpc_common.h
+++ b/paddle/fluid/operators/distributed/rpc_common.h
@@ -27,23 +27,26 @@ struct RpcContext {
   RpcContext(const std::string &name, const std::vector<std::string> &names,
              const std::vector<std::string> &emap,
-             const std::vector<int64_t> &sections)
+             const std::vector<int64_t> &sections, int id)
       : var_name(name),
         splited_var_names(names),
         epmap(emap),
-        height_sections(sections) {}
+        height_sections(sections),
+        trainer_id(id) {}
 
   RpcContext(const RpcContext &ctx) {
     var_name = ctx.var_name;
     splited_var_names = ctx.splited_var_names;
     epmap = ctx.epmap;
     height_sections = ctx.height_sections;
+    trainer_id = ctx.trainer_id;
   }
 
   std::string var_name;
   std::vector<std::string> splited_var_names;
   std::vector<std::string> epmap;
   std::vector<int64_t> height_sections;
+  int trainer_id;
 };
 
 inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc
index 3fd0700a077321d931e87b1d94c3637d167c9eff..8e9846b1fc89953526149be3838103526d5c441b 100644
--- a/paddle/fluid/operators/distributed_ops/recv_op.cc
+++ b/paddle/fluid/operators/distributed_ops/recv_op.cc
@@ -50,17 +50,18 @@ class RecvOp : public framework::OperatorBase {
     platform::DeviceContextPool &pool =
         platform::DeviceContextPool::Instance();
     auto &ctx = *pool.Get(place);
+    auto trainer_id = Attr<int>("trainer_id");
 
     distributed::RPCClient *rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>(
-            Attr<int>("trainer_id"));
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
 
     std::vector<std::string> recv_varnames =
         Attr<std::vector<std::string>>("recv_varnames");
 
     if (recv_varnames.size() > 0) {
       auto recv_functor = distributed::ParameterRecv<float>();
-      auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {});
+      auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {},
+                                             trainer_id);
       recv_functor(rpc_ctx, scope);
     } else {
       if (with_barrier) {
diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc
index 67de7b4185b574412292b98ee6ba182cf118a4e6..1d5a9b1b22eb9e949231838db052800c64f682c6 100644
--- a/paddle/fluid/operators/distributed_ops/send_op.cc
+++ b/paddle/fluid/operators/distributed_ops/send_op.cc
@@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase {
 
     auto epmap = Attr<std::vector<std::string>>("epmap");
     int sync_send = Attr<int>("sync_mode");
+    auto trainer_id = Attr<int>("trainer_id");
 
     auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
     auto height_sections = Attr<std::vector<int64_t>>("sections");
@@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase {
       /*
       auto send_functor = distributed::ParameterSend<float>();
       auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
-                                             height_sections);
+                                             height_sections, trainer_id);
       send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
       */
       VLOG(3) << "send " << ins[0];
@@ -63,8 +64,7 @@ class SendOp : public framework::OperatorBase {
       auto& ctx = *pool.Get(place);
 
      distributed::RPCClient* rpc_client =
-          distributed::RPCClient::GetInstance<RPCCLIENT_T>(
-              Attr<int>("trainer_id"));
+          distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
 
       std::vector<distributed::VarHandlePtr> rets;
       for (size_t i = 0; i < ins.size(); i++) {
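Note: the following is a minimal usage sketch, not part of the patch, showing how the extended RpcContext constructor is meant to be called and how the new trainer_id argument replaces the previously hard-coded client id 0; the variable names, values, and endpoints below are made up for illustration.

    // Hypothetical caller-side snippet (all values are placeholders).
    std::string var_name = "param@GRAD";
    std::vector<std::string> splited_var_names = {"param@GRAD.block0",
                                                  "param@GRAD.block1"};
    std::vector<std::string> epmap = {"127.0.0.1:6174", "127.0.0.1:6175"};
    std::vector<int64_t> height_sections = {64, 64};
    int trainer_id = 2;  // previously ignored; now carried by the context

    auto rpc_ctx = distributed::RpcContext(var_name, splited_var_names, epmap,
                                           height_sections, trainer_id);

    // ParameterSend/ParameterRecv then look up the RPC client for this
    // trainer instead of always using trainer 0:
    //   distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);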