From 065b68b6ca53b3eb140a9f3ebe95b8cdd856fef4 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 14 Mar 2019 23:34:25 +0800 Subject: [PATCH] clean code --- .../fluid/operators/distributed/grpc/grpc_server.cc | 6 ------ paddle/fluid/operators/distributed/parameter_send.cc | 6 +++--- paddle/fluid/operators/distributed/request_handler.h | 6 +----- .../operators/distributed/request_handler_impl.cc | 11 ++--------- .../fluid/operators/distributed/variable_response.h | 11 +++-------- .../fluid/operators/distributed_ops/send_recv_util.h | 1 + 6 files changed, 10 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index f32681738c3..b86f0a53c48 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -106,12 +106,6 @@ class RequestSend final : public RequestBase { auto invar = request_->GetVar(); int trainer_id = request_->GetTrainerId(); framework::Variable* outvar = nullptr; - - /* - if (!request_handler_->sync_mode()) { - request_->ReleaseOwnershipOfLocalScope(); - } - */ request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); Finish(reply_, &responder_); } diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc index 3fe3be193a3..388bc781c13 100644 --- a/paddle/fluid/operators/distributed/parameter_send.cc +++ b/paddle/fluid/operators/distributed/parameter_send.cc @@ -80,7 +80,7 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, auto &send_slr = send_var->Get(); auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections); - auto send_rows = send_slr.rows(); + auto &send_rows = send_slr.rows(); std::vector> outs_rows_idx; std::vector> outs_dense_idx; @@ -88,7 +88,7 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, outs_dense_idx.resize(out_num); auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0]; - auto src = send_slr.value().data(); + auto *src = send_slr.value().data(); // create output var in local scope std::vector outs; @@ -110,8 +110,8 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, outs[i]->set_height(rpc_ctx.height_sections[i]); auto dims = send_slr.GetCompleteDims(); dims[0] = rows_idx.size(); - outs[i]->mutable_value()->mutable_data(dims, send_slr.place()); outs[i]->mutable_rows()->clear(); + outs[i]->mutable_value()->mutable_data(dims, send_slr.place()); if (rows_idx.size() > 0) { for (auto idx : rows_idx) { outs[i]->mutable_rows()->push_back(idx - abs_sections[i]); diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index e777d515ce9..991158ac720 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -71,15 +71,13 @@ class VarHandle { VarHandle(const std::string ep, const std::string& method, const std::string& name, const platform::DeviceContext* p_ctx = nullptr, - const framework::Scope* p_scope = nullptr, - bool delete_local_scope = false) + const framework::Scope* p_scope = nullptr) : status_(kDefaultState) { ep_ = ep; ctx_ = p_ctx; scope_ = p_scope; name_ = name; method_ = method; - delete_local_scope_ = delete_local_scope; } virtual ~VarHandle() {} @@ -101,7 +99,6 @@ class VarHandle { std::unique_lock lk(sync_mutex_); status_ = ok ? kFinishState : kErrorState; } - if (delete_local_scope_ && scope_) delete scope_; VLOG(7) << "VarHandle finish:" << ok; wait_cond_.notify_all(); } @@ -128,7 +125,6 @@ class VarHandle { std::string name_; // RPC method name. std::string method_; - bool delete_local_scope_; protected: std::mutex sync_mutex_; diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index e5318f98ca9..e289ec929db 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -59,15 +59,8 @@ bool RequestSendHandler::Handle(const std::string& varname, "async mode should not recv BATCH_BARRIER_MESSAGE or " "COMPLETE_MESSAGE"); } - - try { - executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), - scope); - delete scope; - } catch (std::exception& e) { - LOG(ERROR) << "async: run sub program error " << e.what(); - return false; - } + executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), + scope); return true; } else { // sync rpc_server_->WaitCond(kRequestSend); diff --git a/paddle/fluid/operators/distributed/variable_response.h b/paddle/fluid/operators/distributed/variable_response.h index edc12e2091f..eb3265e0923 100644 --- a/paddle/fluid/operators/distributed/variable_response.h +++ b/paddle/fluid/operators/distributed/variable_response.h @@ -60,13 +60,14 @@ class VariableResponse { bool create_scope = false) : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) { if (create_scope) { - local_scope_ = &scope->NewScope(); + local_scope_ = scope->NewTmpScope(); } } virtual ~VariableResponse() { if (local_scope_) { - scope_->DeleteScope(local_scope_); + delete local_scope_; + local_scope_ = nullptr; } } @@ -86,12 +87,6 @@ class VariableResponse { inline std::string Varname() const { return meta_.varname(); } inline std::string OutVarname() const { return meta_.out_varname(); } inline std::string TableName() const { return meta_.table_name(); } - inline void ReleaseOwnershipOfLocalScope() { - PADDLE_ENFORCE(create_scope_, - "only when create_scope_ is true can you release the " - "ownership of local scope"); - local_scope_ = nullptr; - } // should call parse first. framework::Variable* GetVar() { diff --git a/paddle/fluid/operators/distributed_ops/send_recv_util.h b/paddle/fluid/operators/distributed_ops/send_recv_util.h index 1e91f0dd51a..01caee9a925 100644 --- a/paddle/fluid/operators/distributed_ops/send_recv_util.h +++ b/paddle/fluid/operators/distributed_ops/send_recv_util.h @@ -54,6 +54,7 @@ inline int FindOutIdx(int row, const std::vector& abs_sections) { return i - 1; } } + PADDLE_ENFORCE_LT(row, abs_sections.back(), "row should be less then max id"); return abs_sections.size() - 1; } -- GitLab