From 1edc0423d2f2a96a342acdd8750e3608aa7b8ce9 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 24 Jan 2019 19:26:07 +0800 Subject: [PATCH] update send_op --- .../operators/distributed_ops/send_op.cc | 59 ++++++++++++------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index f8b9a1d15a8..21366701030 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed/parameter_send.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/platform/profiler.h" @@ -37,30 +38,46 @@ class SendOp : public framework::OperatorBase { const platform::Place& place) const override { auto ins = Inputs("X"); - std::vector epmap = Attr>("epmap"); + auto epmap = Attr>("epmap"); int sync_send = Attr("sync_mode"); - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - - distributed::RPCClient* rpc_client = - distributed::RPCClient::GetInstance( - Attr("trainer_id")); - - std::vector rets; - for (size_t i = 0; i < ins.size(); i++) { - if (NeedSend(scope, ins[i])) { - VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; - rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); - } else { - VLOG(3) << "don't send no-initialied variable: " << ins[i]; + auto send_varnames = Attr>("send_varnames"); + auto height_sections = Attr>("height_sections"); + + if (send_varnames.size() > 0) { + PADDLE_ENFORCE_EQ(ins.size(), 1, ""); + framework::RuntimeContext ctx(Inputs(), Outputs(), scope); + platform::DeviceContextPool& pool = + platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place); + auto exe_ctx = framework::ExecutionContext(*this, scope, *dev_ctx, ctx); + distributed::send(ins[0], send_varnames, epmap, height_sections, + exe_ctx, scope, static_cast(sync_send)); + } else { + platform::DeviceContextPool& pool = + platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance( + Attr("trainer_id")); + + std::vector rets; + for (size_t i = 0; i < ins.size(); i++) { + if (NeedSend(scope, ins[i])) { + VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; + rets.push_back( + rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); + } else { + VLOG(3) << "don't send no-initialied variable: " << ins[i]; + } } - } - if (sync_send) { - for (size_t i = 0; i < rets.size(); i++) { - VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i]; - PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); - VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i]; + if (sync_send) { + for (size_t i = 0; i < rets.size(); i++) { + VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i]; + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i]; + } } } } -- GitLab