diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc
index b8d3b77ae41e5846ec752ce8513dbbf17dde7735..00956d8e6d961f10c5a709900f3c0c13066382aa 100644
--- a/paddle/fluid/operators/distributed/parameter_recv.cc
+++ b/paddle/fluid/operators/distributed/parameter_recv.cc
@@ -39,9 +39,7 @@ using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
 
 template <typename T>
-void ParameterRecv<T>::operator()(const std::string &var_name,
-                                  const std::vector<std::string> &recv_varnames,
-                                  const std::vector<std::string> &epmap,
+void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
                                   const framework::ExecutionContext &ctx,
                                   const framework::Scope &scope) {
   framework::Scope *local_scope = scope.NewTmpScope();
@@ -53,21 +51,22 @@ void ParameterRecv<T>::operator()(const std::string &var_name,
       distributed::RPCClient::GetInstance<RPCCLIENT_T>(
           ctx.Attr<int>("trainer_id"));
 
-  auto *recv_var = scope.FindVar(var_name);
+  auto *recv_var = scope.FindVar(rpc_ctx.var_name);
 
   std::vector<framework::Tensor *> recved_tensors;
 
   // recv all vars to local scope
   if (recv_var->IsType<framework::LoDTensor>()) {
     std::vector<distributed::VarHandlePtr> rets;
-    for (size_t i = 0; i < recv_varnames.size(); i++) {
-      auto &recv_var_name = recv_varnames[i];
+    for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
+      auto &recv_var_name = rpc_ctx.splited_var_names[i];
       framework::Tensor *t =
           local_scope->Var(recv_var_name)->GetMutable<framework::LoDTensor>();
       recved_tensors.push_back(t);
-      VLOG(3) << "recv " << recv_var_name << " from " << epmap[i];
-      rets.push_back(rpc_client->AsyncGetVar(epmap[i], cpu_ctx, *local_scope,
-                                             recv_var_name, recv_var_name));
+      VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
+      rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
+                                             *local_scope, recv_var_name,
+                                             recv_var_name));
     }
     for (size_t i = 0; i < rets.size(); i++) {
       PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
diff --git a/paddle/fluid/operators/distributed/parameter_recv.h b/paddle/fluid/operators/distributed/parameter_recv.h
index bc6f5f5adf26a399b9adc001fb767c20fbd95f5e..e25594024af98ca86f92a918c49b3823faf08acf 100644
--- a/paddle/fluid/operators/distributed/parameter_recv.h
+++ b/paddle/fluid/operators/distributed/parameter_recv.h
@@ -18,6 +18,7 @@
 #include <vector>
 
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/distributed/rpc_common.h"
 
 namespace paddle {
 namespace operators {
@@ -25,9 +26,7 @@ namespace distributed {
 
 template <typename T>
 struct ParameterRecv {
-  void operator()(const std::string &var_name,
-                  const std::vector<std::string> &recv_varnames,
-                  const std::vector<std::string> &epmap,
+  void operator()(const RpcContext &rpc_ctx,
                   const framework::ExecutionContext &context,
                   const framework::Scope &scope);
 };
diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc
index fd97926623ddcc9083645903c857d353edc1019c..eaa1c3ae8e8a1528e692f52ddb2b4d3b834ca689 100644
--- a/paddle/fluid/operators/distributed/parameter_send.cc
+++ b/paddle/fluid/operators/distributed/parameter_send.cc
@@ -38,10 +38,7 @@ using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
 
 template <typename T>
-void ParameterSend<T>::operator()(const std::string &var_name,
-                                  const std::vector<std::string> &send_varnames,
-                                  const std::vector<std::string> &epmap,
-                                  const std::vector<int64_t> &height_sections,
+void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
                                   const framework::ExecutionContext &ctx,
                                   const framework::Scope &scope, bool sync) {
   framework::Scope *local_scope = scope.NewTmpScope();
@@ -53,8 +50,8 @@ void ParameterSend<T>::operator()(const std::string &var_name,
       distributed::RPCClient::GetInstance<RPCCLIENT_T>(
           ctx.Attr<int>("trainer_id"));
 
-  auto *send_var = scope.FindVar(var_name);
-  size_t out_num = send_varnames.size();
+  auto *send_var = scope.FindVar(rpc_ctx.var_name);
+  size_t out_num = rpc_ctx.splited_var_names.size();
   if (send_var->IsType<framework::LoDTensor>()) {
     if (out_num > 1) {
       auto &send_tensor = send_var->Get<framework::LoDTensor>();
@@ -63,19 +60,19 @@ void ParameterSend<T>::operator()(const std::string &var_name,
       outs_dims.reserve(out_num);
 
       // infer output shape
-      PADDLE_ENFORCE_EQ(height_sections.size(), out_num,
+      PADDLE_ENFORCE_EQ(rpc_ctx.height_sections.size(), out_num,
                         "tensor split sections size"
                         "should be equal to output size.");
       for (size_t i = 0; i < out_num; ++i) {
         auto dim = send_tensor_dims;
-        dim[0] = height_sections[i];
+        dim[0] = rpc_ctx.height_sections[i];
         outs_dims.push_back(dim);
       }
 
       // create output var in local scope
       size_t row_offset = 0;
       for (auto i = 0; i < out_num; ++i) {
-        framework::Tensor *out = local_scope->Var(send_varnames[i])
+        framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[i])
                                      ->GetMutable<framework::LoDTensor>();
         *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
         row_offset += outs_dims[i][0];
@@ -83,7 +80,7 @@ void ParameterSend<T>::operator()(const std::string &var_name,
     }
   } else if (send_var->IsType<framework::SelectedRows>()) {
     auto &send_slr = send_var->Get<framework::SelectedRows>();
-    auto abs_sections = ToAbsoluteSection(height_sections);
+    auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
 
     auto send_rows = send_slr.rows();
     std::vector<std::vector<size_t>> outs_rows_idx;
@@ -97,7 +94,7 @@ void ParameterSend<T>::operator()(const std::string &var_name,
 
     // create output var in local scope
     std::vector<framework::SelectedRows *> outs;
-    for (auto &name : send_varnames) {
+    for (auto &name : rpc_ctx.splited_var_names) {
      auto *out = local_scope->Var(name)->GetMutable<framework::SelectedRows>();
       outs.push_back(out);
     }
@@ -112,7 +109,7 @@ void ParameterSend<T>::operator()(const std::string &var_name,
 
     for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
       auto rows_idx = outs_rows_idx[i];
-      outs[i]->set_height(height_sections[i]);
+      outs[i]->set_height(rpc_ctx.height_sections[i]);
       auto dims = send_slr.GetCompleteDims();
       dims[0] = rows_idx.size();
       outs[i]->mutable_value()->mutable_data<T>(dims, send_slr.place());
@@ -149,15 +146,16 @@ void ParameterSend<T>::operator()(const std::string &var_name,
   }
 
   std::vector<distributed::VarHandlePtr> rets;
-  for (size_t i = 0; i < send_varnames.size(); i++) {
-    auto &send_var_name = send_varnames[i];
-    auto &endpoint = epmap[i];
+  for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
+    auto &send_var_name = rpc_ctx.splited_var_names[i];
+    auto &endpoint = rpc_ctx.epmap[i];
     if (NeedSend(*local_scope, send_var_name)) {
       VLOG(3) << "sending " << send_var_name << " to " << endpoint;
       rets.push_back(rpc_client->AsyncSendVar(endpoint, cpu_ctx, *local_scope,
                                               send_var_name));
     } else {
-      VLOG(3) << "don't send non-initialized variable: " << send_varnames[i];
+      VLOG(3) << "don't send non-initialized variable: "
+              << rpc_ctx.splited_var_names[i];
     }
   }
 
diff --git a/paddle/fluid/operators/distributed/parameter_send.h b/paddle/fluid/operators/distributed/parameter_send.h
index 1746377228d9befb1b9d9a62f30f13cf98ca3f37..4500497163fcc77e148a33cb6020e8de6213b748 100644
--- a/paddle/fluid/operators/distributed/parameter_send.h
+++ b/paddle/fluid/operators/distributed/parameter_send.h
@@ -18,6 +18,7 @@
 #include <vector>
 
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/distributed/rpc_common.h"
 
 namespace paddle {
 namespace operators {
@@ -25,10 +26,7 @@ namespace distributed {
 
 template <typename T>
 struct ParameterSend {
-  void operator()(const std::string &var_name,
-                  const std::vector<std::string> &send_varnames,
-                  const std::vector<std::string> &epmap,
-                  const std::vector<int64_t> &height_sections,
+  void operator()(const RpcContext &rpc_ctx,
                   const framework::ExecutionContext &context,
                   const framework::Scope &scope, bool sync);
 };
diff --git a/paddle/fluid/operators/distributed/rpc_common.h b/paddle/fluid/operators/distributed/rpc_common.h
index dc50414b9af5132a282c81f85046fa321e6d94a3..7dede07b5ad623f85935f994d66bd7c91172755a 100644
--- a/paddle/fluid/operators/distributed/rpc_common.h
+++ b/paddle/fluid/operators/distributed/rpc_common.h
@@ -22,6 +22,13 @@ namespace operators {
 namespace distributed {
 
 struct RpcContext {
+  RpcContext(const std::string& name, const std::vector<std::string>& names,
+             const std::vector<std::string>& emap,
+             const std::vector<int64_t>& sections)
+      : var_name(name),
+        splited_var_names(names),
+        epmap(emap),
+        height_sections(sections) {}
   std::string var_name;
   std::vector<std::string> splited_var_names;
   std::vector<std::string> epmap;
diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc
index bcb16ff2e576fc26f9bcd35df02be301d79da379..a4a5ab89a7b169e227f6ca275c8b1ec70cb88e19 100644
--- a/paddle/fluid/operators/distributed_ops/recv_op.cc
+++ b/paddle/fluid/operators/distributed_ops/recv_op.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/distributed/distributed.h"
 #include "paddle/fluid/operators/distributed/parameter_recv.h"
+#include "paddle/fluid/operators/distributed/rpc_common.h"
 #include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
@@ -57,9 +58,11 @@ class RecvOp : public framework::OperatorBase {
       platform::DeviceContextPool &pool =
           platform::DeviceContextPool::Instance();
       auto *dev_ctx = pool.Get(place);
-      auto exe_ctx = framework::ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr);
+      auto exe_ctx =
+          framework::ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr);
       auto recv_functor = distributed::ParameterRecv<float>();
-      recv_functor(outs[0], recv_varnames, epmap, exe_ctx, scope);
+      auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {});
+      recv_functor(rpc_ctx, exe_ctx, scope);
     } else {
       if (with_barrier) {
         std::vector<distributed::VarHandlePtr> rets;
diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc
index 801909e2c063e6d6e032fe82536fe9b920dc70b3..1823d89897f7653f7a705f25ad3671c55cf4b1f5 100644
--- a/paddle/fluid/operators/distributed_ops/send_op.cc
+++ b/paddle/fluid/operators/distributed_ops/send_op.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/distributed/distributed.h"
 #include "paddle/fluid/operators/distributed/parameter_send.h"
+#include "paddle/fluid/operators/distributed/rpc_common.h"
 #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
 #include "paddle/fluid/platform/profiler.h"
 
@@ -50,10 +51,12 @@ class SendOp : public framework::OperatorBase {
       platform::DeviceContextPool& pool =
           platform::DeviceContextPool::Instance();
       auto* dev_ctx = pool.Get(place);
-      auto exe_ctx = framework::ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr);
+      auto exe_ctx =
+          framework::ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr);
       auto send_functor = distributed::ParameterSend<float>();
-      send_functor(ins[0], send_varnames, epmap, height_sections, exe_ctx,
-                   scope, static_cast<bool>(sync_send));
+      auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
+                                             height_sections);
+      send_functor(rpc_ctx, exe_ctx, scope, static_cast<bool>(sync_send));
     } else {
       platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
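
Note on the new call path: this patch bundles the four loose arguments (variable name, split variable names, endpoint map, and, for send, height sections) into a single distributed::RpcContext, so ParameterSend/ParameterRecv take one context object. Below is a minimal usage sketch modeled on the send_op.cc and recv_op.cc hunks above; it assumes an exe_ctx and scope already built as in those op bodies, and the parameter names, endpoints, and section sizes are illustrative placeholders, not values from this patch:

    // Bundle all RPC metadata for one split parameter into a single context.
    auto rpc_ctx = distributed::RpcContext(
        "w@GRAD",                               // var_name: variable to send
        {"w@GRAD.block0", "w@GRAD.block1"},     // splited_var_names: one slice per server
        {"127.0.0.1:6164", "127.0.0.1:6165"},   // epmap: endpoint for each slice
        {4096, 4096});                          // height_sections: rows in each slice

    // Send side: slicing, dispatch, and the optional wait are all driven by rpc_ctx.
    auto send_functor = distributed::ParameterSend<float>();
    send_functor(rpc_ctx, exe_ctx, scope, /*sync=*/true);

    // Recv side passes an empty height_sections, as recv_op.cc does above.
    auto recv_ctx = distributed::RpcContext(
        "w", {"w.block0", "w.block1"},
        {"127.0.0.1:6164", "127.0.0.1:6165"}, {});
    auto recv_functor = distributed::ParameterRecv<float>();
    recv_functor(recv_ctx, exe_ctx, scope);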