提交 103c9bb3 编写于 作者: Q Qiao Longfei

update rpc_client

上级 b7661d7e
...@@ -37,7 +37,16 @@ class ConcurrentSet { ...@@ -37,7 +37,16 @@ class ConcurrentSet {
~ConcurrentSet() {} ~ConcurrentSet() {}
std::future<void> Update(const std::vector<int64_t>& rows) { std::future<void> Update(const std::vector<int64_t>& rows) {
auto task = [this, &rows] { auto task = [this, rows] {
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& id : rows) {
sstream << id << ", ";
}
sstream << "]";
VLOG(3) << "update ids -> " << sstream.str();
}
for (auto row : rows) { for (auto row : rows) {
set_.insert(row); set_.insert(row);
} }
...@@ -46,9 +55,21 @@ class ConcurrentSet { ...@@ -46,9 +55,21 @@ class ConcurrentSet {
} }
std::future<void> GetAndClear(std::vector<int64_t>* result) { std::future<void> GetAndClear(std::vector<int64_t>* result) {
auto task = [this, result] { auto task = [this, &result] {
result->clear(); result->clear();
result->insert(result->end(), set_.begin(), set_.end()); for (auto& id : set_) {
result->push_back(id);
}
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& id : *result) {
sstream << id << ", ";
}
sstream << "]";
VLOG(3) << "result ids size: " << result->size() << " "
<< sstream.str();
}
set_.clear(); set_.clear();
}; };
return pool_->enqueue(std::move(task)); return pool_->enqueue(std::move(task));
...@@ -67,6 +88,7 @@ class AsyncSparseParamUpdateRecorder { ...@@ -67,6 +88,7 @@ class AsyncSparseParamUpdateRecorder {
int trainer_num, int trainer_num,
const std::unordered_map<std::string, std::string>& grad_to_param) const std::unordered_map<std::string, std::string>& grad_to_param)
: trainer_num_(trainer_num), grad_to_param_(grad_to_param) { : trainer_num_(trainer_num), grad_to_param_(grad_to_param) {
if (VLOG_IS_ON(3)) {
std::ostringstream sstream; std::ostringstream sstream;
sstream << "["; sstream << "[";
for (auto& item : grad_to_param) { for (auto& item : grad_to_param) {
...@@ -74,7 +96,8 @@ class AsyncSparseParamUpdateRecorder { ...@@ -74,7 +96,8 @@ class AsyncSparseParamUpdateRecorder {
} }
sstream << "]"; sstream << "]";
VLOG(3) << "trainer_num: " << trainer_num VLOG(3) << "trainer_num: " << trainer_num
<< "grad_to_param_: " << sstream.str(); << " grad_to_param_: " << sstream.str();
}
for (auto& iter : grad_to_param) { for (auto& iter : grad_to_param) {
param_to_grad_[iter.second] = iter.first; param_to_grad_[iter.second] = iter.first;
auto& param_name = iter.second; auto& param_name = iter.second;
...@@ -103,13 +126,12 @@ class AsyncSparseParamUpdateRecorder { ...@@ -103,13 +126,12 @@ class AsyncSparseParamUpdateRecorder {
void GetAndClear(const std::string& param_name, int trainer_id, void GetAndClear(const std::string& param_name, int trainer_id,
std::vector<int64_t>* result) { std::vector<int64_t>* result) {
VLOG(3) << "GetAndClear param: " << param_name
<< " for trainer: " << trainer_id;
PADDLE_ENFORCE_LT(trainer_id, trainer_num_); PADDLE_ENFORCE_LT(trainer_id, trainer_num_);
param_to_updated_rows_.at(param_name)[trainer_id] param_to_updated_rows_.at(param_name)[trainer_id]
->GetAndClear(result) ->GetAndClear(result)
.wait(); .wait();
VLOG(3) << "GetAndClear param: " << param_name
<< " for trainer: " << trainer_id
<< " with size: " << result->size();
} }
bool HasParam(const std::string& param_name) { bool HasParam(const std::string& param_name) {
......
...@@ -234,9 +234,10 @@ VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep, ...@@ -234,9 +234,10 @@ VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
const std::string& out_var_name, const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) { int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC, return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC,
time_out); table_name time_out);
} }
VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep, VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
......
...@@ -66,6 +66,7 @@ class BRPCClient : public RPCClient { ...@@ -66,6 +66,7 @@ class BRPCClient : public RPCClient {
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
const std::string& out_var_name, const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerBarrier( VarHandlePtr AsyncGetMonomerBarrier(
...@@ -107,13 +108,11 @@ class BRPCClient : public RPCClient { ...@@ -107,13 +108,11 @@ class BRPCClient : public RPCClient {
void SendComplete() override; void SendComplete() override;
private: private:
VarHandlePtr _AsyncGetVar(const std::string& ep, VarHandlePtr _AsyncGetVar(
const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope, const std::string& var_name,
const std::string& var_name, const std::string& out_var_name, const std::string& method_name,
const std::string& out_var_name, const std::string& table_name, int64_t time_out = FLAGS_rpc_deadline);
const std::string& method_name,
int64_t time_out = FLAGS_rpc_deadline);
void Proceed(); void Proceed();
ChannelQueuePtr GetChannel(const std::string& ep); ChannelQueuePtr GetChannel(const std::string& ep);
......
...@@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, ...@@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
const std::string& out_varname, const std::string& out_varname,
const std::string& table_name,
int64_t time_out) { int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname, return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
"/sendrecv.SendRecvService/GetVariable", time_out); "/sendrecv.SendRecvService/GetVariable", table_name,
time_out);
} }
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier( VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
...@@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier( ...@@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
return _AsyncGetVar( return _AsyncGetVar(
ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname, ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
"/sendrecv.SendRecvService/GetVariableNoBarrier", time_out); "/sendrecv.SendRecvService/GetVariableNoBarrier", "", time_out);
} }
VarHandlePtr GRPCClient::AsyncGetMonomerVariable( VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
...@@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable( ...@@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
const framework::Scope& scope, const std::string& var_name, const framework::Scope& scope, const std::string& var_name,
int64_t time_out) { int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name, return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
"/sendrecv.SendRecvService/GetMonomerVariable", time_out); "/sendrecv.SendRecvService/GetMonomerVariable", "",
time_out);
} }
VarHandlePtr GRPCClient::_AsyncGetVar( VarHandlePtr GRPCClient::_AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method, const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname, const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, int64_t time_out) { const std::string& rpc_path, const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const std::string out_varname_val = out_varname; const std::string out_varname_val = out_varname;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch); GetProcessor* s = new GetProcessor(ch);
...@@ -169,13 +174,14 @@ VarHandlePtr GRPCClient::_AsyncGetVar( ...@@ -169,13 +174,14 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope)); VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
s->Prepare(h, time_out); s->Prepare(h, time_out);
framework::AsyncIO( framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method,
[var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] { p_ctx, h, rpc_path, this] {
// prepare input // prepare input
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(var_name_val); req.set_varname(var_name_val);
req.set_out_varname(out_varname_val); req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_); req.set_trainer_id(trainer_id_);
req.set_table_name(table_name_val);
::grpc::ByteBuffer buf; ::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf); RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
......
...@@ -187,6 +187,7 @@ class GRPCClient : public RPCClient { ...@@ -187,6 +187,7 @@ class GRPCClient : public RPCClient {
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
const std::string& out_varname, const std::string& out_varname,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVarNoBarrier( VarHandlePtr AsyncGetVarNoBarrier(
...@@ -239,7 +240,8 @@ class GRPCClient : public RPCClient { ...@@ -239,7 +240,8 @@ class GRPCClient : public RPCClient {
const std::string& ep, const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method, const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname, const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline); const std::string& rpc_path, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline);
private: private:
grpc::CompletionQueue cq_; grpc::CompletionQueue cq_;
......
...@@ -136,6 +136,7 @@ class RequestGet final : public RequestBase { ...@@ -136,6 +136,7 @@ class RequestGet final : public RequestBase {
// proc request. // proc request.
std::string varname = request_.varname(); std::string varname = request_.varname();
std::string out_varname = request_.out_varname(); std::string out_varname = request_.out_varname();
std::string table_name = request_.table_name();
int trainer_id = request_.trainer_id(); int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGet " << out_varname << " from " << varname; VLOG(4) << "RequestGet " << out_varname << " from " << varname;
...@@ -146,12 +147,14 @@ class RequestGet final : public RequestBase { ...@@ -146,12 +147,14 @@ class RequestGet final : public RequestBase {
auto* tmp_scope = scope->NewTmpScope(); auto* tmp_scope = scope->NewTmpScope();
request_handler_->Handle(varname, tmp_scope, invar, &outvar, trainer_id, request_handler_->Handle(varname, tmp_scope, invar, &outvar, trainer_id,
out_varname); out_varname, table_name);
VLOG(1) << "before SerializeToByteBuffer";
if (outvar) { if (outvar) {
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(), SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_); &reply_);
} }
VLOG(1) << "after SerializeToByteBuffer";
delete tmp_scope; delete tmp_scope;
Finish(reply_, &responder_); Finish(reply_, &responder_);
} }
......
...@@ -41,7 +41,7 @@ using DDim = framework::DDim; ...@@ -41,7 +41,7 @@ using DDim = framework::DDim;
template <typename T> template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope) { const framework::Scope &scope) {
VLOG(3) << "ParameterRecv in"; VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
framework::Scope *local_scope = scope.NewTmpScope(); framework::Scope *local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
...@@ -61,7 +61,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, ...@@ -61,7 +61,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx, rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope, recv_var_name, *local_scope, recv_var_name,
recv_var_name)); recv_var_name, recv_var_name));
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
...@@ -73,6 +73,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, ...@@ -73,6 +73,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
// concat recved tensor into one var // concat recved tensor into one var
{ {
size_t output_offset = 0; size_t output_offset = 0;
size_t row_offset = 0;
framework::Tensor *recv_tensor = framework::Tensor *recv_tensor =
recv_var->GetMutable<framework::LoDTensor>(); recv_var->GetMutable<framework::LoDTensor>();
auto dev_ctx = paddle::platform::CPUDeviceContext(); auto dev_ctx = paddle::platform::CPUDeviceContext();
...@@ -92,16 +93,28 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, ...@@ -92,16 +93,28 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
auto &recv_slr = recv_var->Get<framework::SelectedRows>(); auto &recv_slr = recv_var->Get<framework::SelectedRows>();
auto &recv_dims = recv_tensor->dims(); auto &recv_dims = recv_tensor->dims();
int64_t width = recv_dims[1]; int64_t width = recv_dims[1];
PADDLE_ENFORCE_EQ(recv_slr.height(), recv_dims[0]); recv_numel += recv_slr.height() * width;
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width); PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size()); PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
VLOG(3) << "recv slr " << recv_var_name << " dims " VLOG(3) << "recv slr " << recv_var_name << " dims "
<< recv_slr.value().dims(); << recv_slr.value().dims();
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto &row_id : recv_slr.rows()) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " "
<< sstream.str();
}
for (auto i = 0; i < recv_slr.rows().size(); ++i) { for (auto i = 0; i < recv_slr.rows().size(); ++i) {
auto row_id = recv_slr.rows()[i]; auto row_id = recv_slr.rows()[i] + row_offset;
PADDLE_ENFORCE_LT(row_id, recv_dims[1]);
memcpy(recv_tensor->data<T>() + row_id * width, memcpy(recv_tensor->data<T>() + row_id * width,
recv_slr.value().data<T>() + i * width, sizeof(T) * width); recv_slr.value().data<T>() + i * width, sizeof(T) * width);
} }
row_offset += recv_slr.height();
} else { } else {
PADDLE_THROW("unsupported recieved var type"); PADDLE_THROW("unsupported recieved var type");
} }
...@@ -110,7 +123,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, ...@@ -110,7 +123,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
} }
delete local_scope; delete local_scope;
VLOG(3) << "ParameterRecv out"; VLOG(3) << "ParameterRecv out" << rpc_ctx.var_name;
} }
template struct ParameterRecv<float>; template struct ParameterRecv<float>;
......
...@@ -89,8 +89,9 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -89,8 +89,9 @@ bool RequestGetHandler::Handle(const std::string& varname,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string& out_var_name,
const std::string& table_name) { const std::string& table_name) {
VLOG(4) << "RequestGetHandler:" << varname VLOG(3) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name; << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name;
if (sync_mode_) { if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) { if (varname == FETCH_BARRIER_MESSAGE) {
...@@ -115,10 +116,21 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -115,10 +116,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG(3) << "copying " << varname << " to " << param_bak_name; VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
} }
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname)) { if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
std::vector<int64_t> updated_rows; std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows); varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto& origin_tensor = auto& origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>(); scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto* origin_tensor_data = origin_tensor.data<float>(); auto* origin_tensor_data = origin_tensor.data<float>();
...@@ -133,6 +145,7 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -133,6 +145,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
out_dims, origin_tensor.place()); out_dims, origin_tensor.place());
auto width = dims[1]; auto width = dims[1];
for (auto i = 0; i < updated_rows.size(); ++i) { for (auto i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width, memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width); sizeof(float) * width);
} }
......
...@@ -44,6 +44,7 @@ class RPCClient { ...@@ -44,6 +44,7 @@ class RPCClient {
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
const std::string& out_varname, const std::string& out_varname,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetVarNoBarrier( virtual VarHandlePtr AsyncGetVarNoBarrier(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册