From 103c9bb3764d41db007c78c31e4a63c0c1f5bd53 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sun, 24 Mar 2019 22:53:32 +0800 Subject: [PATCH] update rpc_client --- .../async_sparse_param_update_recorder.h | 48 +++++++++++----- .../operators/distributed/brpc/brpc_client.cc | 3 +- .../operators/distributed/brpc/brpc_client.h | 13 ++--- .../operators/distributed/grpc/grpc_client.cc | 56 ++++++++++--------- .../operators/distributed/grpc/grpc_client.h | 4 +- .../operators/distributed/grpc/grpc_server.cc | 5 +- .../operators/distributed/parameter_recv.cc | 23 ++++++-- .../distributed/request_handler_impl.cc | 19 ++++++- .../fluid/operators/distributed/rpc_client.h | 1 + 9 files changed, 116 insertions(+), 56 deletions(-) diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h index 55d6577ef..037187ea9 100644 --- a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h +++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h @@ -37,7 +37,16 @@ class ConcurrentSet { ~ConcurrentSet() {} std::future Update(const std::vector& rows) { - auto task = [this, &rows] { + auto task = [this, rows] { + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& id : rows) { + sstream << id << ", "; + } + sstream << "]"; + VLOG(3) << "update ids -> " << sstream.str(); + } for (auto row : rows) { set_.insert(row); } @@ -46,9 +55,21 @@ class ConcurrentSet { } std::future GetAndClear(std::vector* result) { - auto task = [this, result] { + auto task = [this, &result] { result->clear(); - result->insert(result->end(), set_.begin(), set_.end()); + for (auto& id : set_) { + result->push_back(id); + } + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& id : *result) { + sstream << id << ", "; + } + sstream << "]"; + VLOG(3) << "result ids size: " << result->size() << " " + << sstream.str(); + } set_.clear(); }; return pool_->enqueue(std::move(task)); @@ -67,14 +88,16 @@ class AsyncSparseParamUpdateRecorder { int trainer_num, const std::unordered_map& grad_to_param) : trainer_num_(trainer_num), grad_to_param_(grad_to_param) { - std::ostringstream sstream; - sstream << "["; - for (auto& item : grad_to_param) { - sstream << item.first << ":" << item.second << ", "; + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& item : grad_to_param) { + sstream << item.first << ":" << item.second << ", "; + } + sstream << "]"; + VLOG(3) << "trainer_num: " << trainer_num + << " grad_to_param_: " << sstream.str(); } - sstream << "]"; - VLOG(3) << "trainer_num: " << trainer_num - << "grad_to_param_: " << sstream.str(); for (auto& iter : grad_to_param) { param_to_grad_[iter.second] = iter.first; auto& param_name = iter.second; @@ -103,13 +126,12 @@ class AsyncSparseParamUpdateRecorder { void GetAndClear(const std::string& param_name, int trainer_id, std::vector* result) { + VLOG(3) << "GetAndClear param: " << param_name + << " for trainer: " << trainer_id; PADDLE_ENFORCE_LT(trainer_id, trainer_num_); param_to_updated_rows_.at(param_name)[trainer_id] ->GetAndClear(result) .wait(); - VLOG(3) << "GetAndClear param: " << param_name - << " for trainer: " << trainer_id - << " with size: " << result->size(); } bool HasParam(const std::string& param_name) { diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.cc b/paddle/fluid/operators/distributed/brpc/brpc_client.cc index a1a344334..410cc6d1b 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.cc +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.cc @@ -234,9 +234,10 @@ VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep, const framework::Scope& scope, const std::string& var_name, const std::string& out_var_name, + const std::string& table_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC, - time_out); + table_name time_out); } VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep, diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.h b/paddle/fluid/operators/distributed/brpc/brpc_client.h index 501a593b1..33a6a805c 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.h +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.h @@ -66,6 +66,7 @@ class BRPCClient : public RPCClient { const framework::Scope& scope, const std::string& var_name, const std::string& out_var_name, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncGetMonomerBarrier( @@ -107,13 +108,11 @@ class BRPCClient : public RPCClient { void SendComplete() override; private: - VarHandlePtr _AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& out_var_name, - const std::string& method_name, - int64_t time_out = FLAGS_rpc_deadline); + VarHandlePtr _AsyncGetVar( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + const std::string& out_var_name, const std::string& method_name, + const std::string& table_name, int64_t time_out = FLAGS_rpc_deadline); void Proceed(); ChannelQueuePtr GetChannel(const std::string& ep); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 61e94dae3..8504110c6 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, const framework::Scope& scope, const std::string& var_name, const std::string& out_varname, + const std::string& table_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname, - "/sendrecv.SendRecvService/GetVariable", time_out); + "/sendrecv.SendRecvService/GetVariable", table_name, + time_out); } VarHandlePtr GRPCClient::AsyncGetVarNoBarrier( @@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier( return _AsyncGetVar( ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname, - "/sendrecv.SendRecvService/GetVariableNoBarrier", time_out); + "/sendrecv.SendRecvService/GetVariableNoBarrier", "", time_out); } VarHandlePtr GRPCClient::AsyncGetMonomerVariable( @@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable( const framework::Scope& scope, const std::string& var_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name, - "/sendrecv.SendRecvService/GetMonomerVariable", time_out); + "/sendrecv.SendRecvService/GetMonomerVariable", "", + time_out); } VarHandlePtr GRPCClient::_AsyncGetVar( const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& method, const std::string& var_name, const std::string& out_varname, - const std::string& rpc_path, int64_t time_out) { + const std::string& rpc_path, const std::string& table_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; const std::string out_varname_val = out_varname; + const std::string table_name_val = table_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); GetProcessor* s = new GetProcessor(ch); @@ -169,32 +174,33 @@ VarHandlePtr GRPCClient::_AsyncGetVar( VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope)); s->Prepare(h, time_out); - framework::AsyncIO( - [var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] { - // prepare input - sendrecv::VariableMessage req; - req.set_varname(var_name_val); - req.set_out_varname(out_varname_val); - req.set_trainer_id(trainer_id_); - ::grpc::ByteBuffer buf; - RequestToByteBuffer(req, &buf); + framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method, + p_ctx, h, rpc_path, this] { + // prepare input + sendrecv::VariableMessage req; + req.set_varname(var_name_val); + req.set_out_varname(out_varname_val); + req.set_trainer_id(trainer_id_); + req.set_table_name(table_name_val); + ::grpc::ByteBuffer buf; + RequestToByteBuffer(req, &buf); - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - // stub context - s->response_call_back_ = ProcGetResponse; + // stub context + s->response_call_back_ = ProcGetResponse; - platform::RecordRPCEvent record_event(method); + platform::RecordRPCEvent record_event(method); - auto call = - s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + auto call = + s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - }); + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + }); req_count_++; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h index ce0d2152a..7eb292676 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h @@ -187,6 +187,7 @@ class GRPCClient : public RPCClient { const framework::Scope& scope, const std::string& var_name, const std::string& out_varname, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncGetVarNoBarrier( @@ -239,7 +240,8 @@ class GRPCClient : public RPCClient { const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& method, const std::string& var_name, const std::string& out_varname, - const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline); + const std::string& rpc_path, const std::string& table_name = "", + int64_t time_out = FLAGS_rpc_deadline); private: grpc::CompletionQueue cq_; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index 78cfd3d0c..e1ec9884c 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -136,6 +136,7 @@ class RequestGet final : public RequestBase { // proc request. std::string varname = request_.varname(); std::string out_varname = request_.out_varname(); + std::string table_name = request_.table_name(); int trainer_id = request_.trainer_id(); VLOG(4) << "RequestGet " << out_varname << " from " << varname; @@ -146,12 +147,14 @@ class RequestGet final : public RequestBase { auto* tmp_scope = scope->NewTmpScope(); request_handler_->Handle(varname, tmp_scope, invar, &outvar, trainer_id, - out_varname); + out_varname, table_name); + VLOG(1) << "before SerializeToByteBuffer"; if (outvar) { SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(), &reply_); } + VLOG(1) << "after SerializeToByteBuffer"; delete tmp_scope; Finish(reply_, &responder_); } diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc index f40f25c75..a5983593c 100644 --- a/paddle/fluid/operators/distributed/parameter_recv.cc +++ b/paddle/fluid/operators/distributed/parameter_recv.cc @@ -41,7 +41,7 @@ using DDim = framework::DDim; template void ParameterRecv::operator()(const RpcContext &rpc_ctx, const framework::Scope &scope) { - VLOG(3) << "ParameterRecv in"; + VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name; framework::Scope *local_scope = scope.NewTmpScope(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); @@ -61,7 +61,7 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx, *local_scope, recv_var_name, - recv_var_name)); + recv_var_name, recv_var_name)); } for (size_t i = 0; i < rets.size(); i++) { PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); @@ -73,6 +73,7 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, // concat recved tensor into one var { size_t output_offset = 0; + size_t row_offset = 0; framework::Tensor *recv_tensor = recv_var->GetMutable(); auto dev_ctx = paddle::platform::CPUDeviceContext(); @@ -92,16 +93,28 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, auto &recv_slr = recv_var->Get(); auto &recv_dims = recv_tensor->dims(); int64_t width = recv_dims[1]; - PADDLE_ENFORCE_EQ(recv_slr.height(), recv_dims[0]); + recv_numel += recv_slr.height() * width; PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width); PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size()); VLOG(3) << "recv slr " << recv_var_name << " dims " << recv_slr.value().dims(); + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto &row_id : recv_slr.rows()) { + sstream << row_id << ", "; + } + sstream << "]"; + VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " " + << sstream.str(); + } for (auto i = 0; i < recv_slr.rows().size(); ++i) { - auto row_id = recv_slr.rows()[i]; + auto row_id = recv_slr.rows()[i] + row_offset; + PADDLE_ENFORCE_LT(row_id, recv_dims[1]); memcpy(recv_tensor->data() + row_id * width, recv_slr.value().data() + i * width, sizeof(T) * width); } + row_offset += recv_slr.height(); } else { PADDLE_THROW("unsupported recieved var type"); } @@ -110,7 +123,7 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, } delete local_scope; - VLOG(3) << "ParameterRecv out"; + VLOG(3) << "ParameterRecv out" << rpc_ctx.var_name; } template struct ParameterRecv; diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index e4c259722..a41536368 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -89,8 +89,9 @@ bool RequestGetHandler::Handle(const std::string& varname, const int trainer_id, const std::string& out_var_name, const std::string& table_name) { - VLOG(4) << "RequestGetHandler:" << varname - << " out_var_name: " << out_var_name; + VLOG(3) << "RequestGetHandler:" << varname + << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id + << " table_name: " << table_name; if (sync_mode_) { if (varname == FETCH_BARRIER_MESSAGE) { @@ -115,10 +116,21 @@ bool RequestGetHandler::Handle(const std::string& varname, VLOG(3) << "copying " << varname << " to " << param_bak_name; framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); } - if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname)) { + if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && + !table_name.empty()) { std::vector updated_rows; AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( varname, trainer_id, &updated_rows); + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& row_id : updated_rows) { + sstream << row_id << ", "; + } + sstream << "]"; + VLOG(3) << "updated_rows size: " << updated_rows.size() << " " + << sstream.str(); + } auto& origin_tensor = scope_->FindVar(varname)->Get(); auto* origin_tensor_data = origin_tensor.data(); @@ -133,6 +145,7 @@ bool RequestGetHandler::Handle(const std::string& varname, out_dims, origin_tensor.place()); auto width = dims[1]; for (auto i = 0; i < updated_rows.size(); ++i) { + PADDLE_ENFORCE_LT(updated_rows[i], dims[0]); memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width, sizeof(float) * width); } diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index ea54e0c29..f893510ba 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -44,6 +44,7 @@ class RPCClient { const framework::Scope& scope, const std::string& var_name, const std::string& out_varname, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) = 0; virtual VarHandlePtr AsyncGetVarNoBarrier( -- GitLab