diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc index 09fce06b5a8d227bec2f011f6a485f4d13ff14d9..38b64c3fcd1aaf6d6baecb0e2e08f379351d0c20 100644 --- a/paddle/fluid/operators/distributed/parameter_send.cc +++ b/paddle/fluid/operators/distributed/parameter_send.cc @@ -38,27 +38,27 @@ using SelectedRows = framework::SelectedRows; using DDim = framework::DDim; template -void send(const std::string& var_name, - const std::vector& send_varnames, - const std::vector& epmap, - const std::vector& height_sections, - const framework::ExecutionContext& ctx, const framework::Scope& scope, - bool sync) { - framework::Scope* local_scope = scope.NewTmpScope(); - - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& cpu_ctx = *pool.Get(platform::CPUPlace()); - auto& actual_ctx = *pool.Get(ctx.GetPlace()); - - distributed::RPCClient* rpc_client = +void ParameterSend::operator()(const std::string &var_name, + const std::vector &send_varnames, + const std::vector &epmap, + const std::vector &height_sections, + const framework::ExecutionContext &ctx, + const framework::Scope &scope, bool sync) { + framework::Scope *local_scope = scope.NewTmpScope(); + + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &cpu_ctx = *pool.Get(platform::CPUPlace()); + auto &actual_ctx = *pool.Get(ctx.GetPlace()); + + distributed::RPCClient *rpc_client = distributed::RPCClient::GetInstance( ctx.Attr("trainer_id")); - auto* send_var = scope.FindVar(var_name); + auto *send_var = scope.FindVar(var_name); size_t out_num = send_varnames.size(); if (send_var->IsType()) { - auto& send_tensor = send_var->Get(); - auto& send_tensor_dims = send_tensor.dims(); + auto &send_tensor = send_var->Get(); + auto &send_tensor_dims = send_tensor.dims(); std::vector outs_dims; outs_dims.reserve(out_num); @@ -89,13 +89,13 @@ void send(const std::string& var_name, // create output var in local scope size_t row_offset = 0; for (auto i = 0; i < out_num; ++i) { - auto* out = + auto *out = local_scope->Var(send_varnames[i])->GetMutable(); *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]); row_offset += outs_dims[i][0]; } } else if (send_var->IsType()) { - auto& send_slr = send_var->Get(); + auto &send_slr = send_var->Get(); auto abs_sections = ToAbsoluteSection(height_sections); auto send_rows = send_slr.rows(); @@ -109,9 +109,9 @@ void send(const std::string& var_name, auto src = send_slr.value().data(); // create output var in local scope - std::vector outs; - for (auto& name : send_varnames) { - auto* out = local_scope->Var(name)->GetMutable(); + std::vector outs; + for (auto &name : send_varnames) { + auto *out = local_scope->Var(name)->GetMutable(); outs.push_back(out); } @@ -163,8 +163,8 @@ void send(const std::string& var_name, std::vector rets; for (size_t i = 0; i < send_varnames.size(); i++) { - auto& send_var_name = send_varnames[i]; - auto& endpoint = epmap[i]; + auto &send_var_name = send_varnames[i]; + auto &endpoint = epmap[i]; if (NeedSend(*local_scope, send_var_name)) { VLOG(3) << "sending " << send_var_name << " to " << endpoint; rets.push_back(rpc_client->AsyncSendVar(endpoint, cpu_ctx, *local_scope, @@ -183,6 +183,8 @@ void send(const std::string& var_name, delete local_scope; } +template struct ParameterSend; + }; // namespace distributed }; // namespace operators }; // namespace paddle diff --git a/paddle/fluid/operators/distributed/parameter_send.h b/paddle/fluid/operators/distributed/parameter_send.h index 6272cc5d2558902a1914c37ac223749791353111..1746377228d9befb1b9d9a62f30f13cf98ca3f37 100644 --- a/paddle/fluid/operators/distributed/parameter_send.h +++ b/paddle/fluid/operators/distributed/parameter_send.h @@ -24,12 +24,14 @@ namespace operators { namespace distributed { template -void send(const std::string& var_name, - const std::vector& send_varnames, - const std::vector& epmap, - const std::vector& height_sections, - const framework::ExecutionContext& context, - const framework::Scope& scope, bool sync); +struct ParameterSend { + void operator()(const std::string &var_name, + const std::vector &send_varnames, + const std::vector &epmap, + const std::vector &height_sections, + const framework::ExecutionContext &context, + const framework::Scope &scope, bool sync); +}; }; // namespace distributed }; // namespace operators diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index a8bb597cbd59290df1347c164d37104c6ac431e9..0eb30ce695a0364e4e4b759622a6516e5e80b885 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -2,9 +2,9 @@ include(operators) set(DISTRIBUTE_DEPS "") if(WITH_GRPC) - set(DISTRIBUTE_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) else() - set(DISTRIBUTE_DEPS sendrecvop_rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send brpc leveldb snappystream snappy protobuf ssl crypto zlib node) if(WITH_BRPC_RDMA) find_library(IBVERBS_LIBRARY NAMES ibverbs) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index 213667010303671da0deaa0da475eff10cfa5e2d..e7ccaa83dea006030ab1c9f3ed0fbd5c2e03012b 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -51,8 +51,9 @@ class SendOp : public framework::OperatorBase { platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); auto exe_ctx = framework::ExecutionContext(*this, scope, *dev_ctx, ctx); - distributed::send(ins[0], send_varnames, epmap, height_sections, - exe_ctx, scope, static_cast(sync_send)); + auto send_functor = distributed::ParameterSend(); + send_functor(ins[0], send_varnames, epmap, height_sections, exe_ctx, + scope, static_cast(sync_send)); } else { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();