From fb6cc3a1bd40378b3a9d560bd975ab22b730eb2d Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 1 Apr 2019 09:06:33 +0800 Subject: [PATCH] follow commnet, optimize code and add comment test=develop --- .../framework/details/multi_devices_graph_pass.h | 3 +++ paddle/fluid/framework/scope.h | 4 ++++ .../fluid/operators/distributed/parameter_send.cc | 13 ++++++------- paddle/fluid/operators/distributed_ops/send_op.cc | 2 +- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 26fc8dc1986..7cc68dd2d5a 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -132,8 +132,11 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { node->Op()->Flush(); } else if (node->Name() == "lookup_table" || node->Name() == "nce" || node->Name() == "hierarchical_sigmoid") { + // in async_mode, we do not need remote prefetch, because communicator + // will do async parameter recv. VLOG(1) << "set " << node->Name() << " op remote_prefetch to false"; node->Op()->SetAttr("remote_prefetch", false); + node->Op()->Flush(); } return false; } diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index cd752077d66..6665458d4c8 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -52,6 +52,10 @@ class Scope { /// Mark it to const because that new kid scope cannot change parent scope. Scope& NewScope() const; + /// Create a sub-scope for current scope but do not record it in the kids to + /// avoid performance problems. + /// Note!!! You should delete the result pointer yourself to avoid memory + /// leak! Scope* NewTmpScope() const; /// Create a variable with given name if it doesn't exist. diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc index ec2884c2529..4858dbe84e0 100644 --- a/paddle/fluid/operators/distributed/parameter_send.cc +++ b/paddle/fluid/operators/distributed/parameter_send.cc @@ -81,8 +81,8 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections); auto &send_rows = send_slr.rows(); - std::vector> outs_rows_idx; - std::vector> outs_dense_idx; + std::vector> outs_rows_idx; + std::vector> outs_dense_idx; outs_rows_idx.resize(out_num); outs_dense_idx.resize(out_num); @@ -99,7 +99,7 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, // split rows index into output sparse vars for (size_t i = 0; i < send_rows.size(); ++i) { - int out_idx = GetSectionIndex(send_rows[i], abs_sections); + size_t out_idx = GetSectionIndex(send_rows[i], abs_sections); outs_rows_idx[out_idx].push_back(send_rows[i]); outs_dense_idx[out_idx].push_back(i); } @@ -160,10 +160,9 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, } } - // note!! only support sync send now - if (true || sync) { - for (size_t i = 0; i < rets.size(); i++) { - PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + if (sync) { + for (auto &handle : rets) { + PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient"); } } diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index 47688d0ad45..b08cd0942f8 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -52,7 +52,7 @@ class SendOp : public framework::OperatorBase { auto send_functor = distributed::ParameterSend(); auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, height_sections); - send_functor(rpc_ctx, scope, static_cast(sync_send)); + send_functor(rpc_ctx, scope, true); } else { distributed::Communicator::GetInstance()->Send(ins[0], scope); } -- GitLab