提交 103c9bb3 编写于 作者: Q Qiao Longfei

update rpc_client

上级 b7661d7e
......@@ -37,7 +37,16 @@ class ConcurrentSet {
~ConcurrentSet() {}
std::future<void> Update(const std::vector<int64_t>& rows) {
auto task = [this, &rows] {
auto task = [this, rows] {
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& id : rows) {
sstream << id << ", ";
}
sstream << "]";
VLOG(3) << "update ids -> " << sstream.str();
}
for (auto row : rows) {
set_.insert(row);
}
......@@ -46,9 +55,21 @@ class ConcurrentSet {
}
std::future<void> GetAndClear(std::vector<int64_t>* result) {
auto task = [this, result] {
auto task = [this, &result] {
result->clear();
result->insert(result->end(), set_.begin(), set_.end());
for (auto& id : set_) {
result->push_back(id);
}
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& id : *result) {
sstream << id << ", ";
}
sstream << "]";
VLOG(3) << "result ids size: " << result->size() << " "
<< sstream.str();
}
set_.clear();
};
return pool_->enqueue(std::move(task));
......@@ -67,14 +88,16 @@ class AsyncSparseParamUpdateRecorder {
int trainer_num,
const std::unordered_map<std::string, std::string>& grad_to_param)
: trainer_num_(trainer_num), grad_to_param_(grad_to_param) {
std::ostringstream sstream;
sstream << "[";
for (auto& item : grad_to_param) {
sstream << item.first << ":" << item.second << ", ";
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& item : grad_to_param) {
sstream << item.first << ":" << item.second << ", ";
}
sstream << "]";
VLOG(3) << "trainer_num: " << trainer_num
<< " grad_to_param_: " << sstream.str();
}
sstream << "]";
VLOG(3) << "trainer_num: " << trainer_num
<< "grad_to_param_: " << sstream.str();
for (auto& iter : grad_to_param) {
param_to_grad_[iter.second] = iter.first;
auto& param_name = iter.second;
......@@ -103,13 +126,12 @@ class AsyncSparseParamUpdateRecorder {
void GetAndClear(const std::string& param_name, int trainer_id,
std::vector<int64_t>* result) {
VLOG(3) << "GetAndClear param: " << param_name
<< " for trainer: " << trainer_id;
PADDLE_ENFORCE_LT(trainer_id, trainer_num_);
param_to_updated_rows_.at(param_name)[trainer_id]
->GetAndClear(result)
.wait();
VLOG(3) << "GetAndClear param: " << param_name
<< " for trainer: " << trainer_id
<< " with size: " << result->size();
}
bool HasParam(const std::string& param_name) {
......
......@@ -234,9 +234,10 @@ VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC,
time_out);
table_name time_out);
}
VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
......
......@@ -66,6 +66,7 @@ class BRPCClient : public RPCClient {
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerBarrier(
......@@ -107,13 +108,11 @@ class BRPCClient : public RPCClient {
void SendComplete() override;
private:
VarHandlePtr _AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_var_name,
const std::string& method_name,
int64_t time_out = FLAGS_rpc_deadline);
VarHandlePtr _AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_var_name, const std::string& method_name,
const std::string& table_name, int64_t time_out = FLAGS_rpc_deadline);
void Proceed();
ChannelQueuePtr GetChannel(const std::string& ep);
......
......@@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
const std::string& table_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
"/sendrecv.SendRecvService/GetVariable", time_out);
"/sendrecv.SendRecvService/GetVariable", table_name,
time_out);
}
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
......@@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
return _AsyncGetVar(
ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
"/sendrecv.SendRecvService/GetVariableNoBarrier", time_out);
"/sendrecv.SendRecvService/GetVariableNoBarrier", "", time_out);
}
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
......@@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
"/sendrecv.SendRecvService/GetMonomerVariable", time_out);
"/sendrecv.SendRecvService/GetMonomerVariable", "",
time_out);
}
VarHandlePtr GRPCClient::_AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, int64_t time_out) {
const std::string& rpc_path, const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const std::string out_varname_val = out_varname;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
......@@ -169,32 +174,33 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO(
[var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method,
p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
req.set_table_name(table_name_val);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = ProcGetResponse;
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordRPCEvent record_event(method);
platform::RecordRPCEvent record_event(method);
auto call =
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
auto call =
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
......
......@@ -187,6 +187,7 @@ class GRPCClient : public RPCClient {
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVarNoBarrier(
......@@ -239,7 +240,8 @@ class GRPCClient : public RPCClient {
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline);
const std::string& rpc_path, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline);
private:
grpc::CompletionQueue cq_;
......
......@@ -136,6 +136,7 @@ class RequestGet final : public RequestBase {
// proc request.
std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
std::string table_name = request_.table_name();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGet " << out_varname << " from " << varname;
......@@ -146,12 +147,14 @@ class RequestGet final : public RequestBase {
auto* tmp_scope = scope->NewTmpScope();
request_handler_->Handle(varname, tmp_scope, invar, &outvar, trainer_id,
out_varname);
out_varname, table_name);
VLOG(1) << "before SerializeToByteBuffer";
if (outvar) {
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
VLOG(1) << "after SerializeToByteBuffer";
delete tmp_scope;
Finish(reply_, &responder_);
}
......
......@@ -41,7 +41,7 @@ using DDim = framework::DDim;
template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope) {
VLOG(3) << "ParameterRecv in";
VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
framework::Scope *local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -61,7 +61,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope, recv_var_name,
recv_var_name));
recv_var_name, recv_var_name));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
......@@ -73,6 +73,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
// concat recved tensor into one var
{
size_t output_offset = 0;
size_t row_offset = 0;
framework::Tensor *recv_tensor =
recv_var->GetMutable<framework::LoDTensor>();
auto dev_ctx = paddle::platform::CPUDeviceContext();
......@@ -92,16 +93,28 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
auto &recv_slr = recv_var->Get<framework::SelectedRows>();
auto &recv_dims = recv_tensor->dims();
int64_t width = recv_dims[1];
PADDLE_ENFORCE_EQ(recv_slr.height(), recv_dims[0]);
recv_numel += recv_slr.height() * width;
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
VLOG(3) << "recv slr " << recv_var_name << " dims "
<< recv_slr.value().dims();
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto &row_id : recv_slr.rows()) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " "
<< sstream.str();
}
for (auto i = 0; i < recv_slr.rows().size(); ++i) {
auto row_id = recv_slr.rows()[i];
auto row_id = recv_slr.rows()[i] + row_offset;
PADDLE_ENFORCE_LT(row_id, recv_dims[1]);
memcpy(recv_tensor->data<T>() + row_id * width,
recv_slr.value().data<T>() + i * width, sizeof(T) * width);
}
row_offset += recv_slr.height();
} else {
PADDLE_THROW("unsupported recieved var type");
}
......@@ -110,7 +123,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
}
delete local_scope;
VLOG(3) << "ParameterRecv out";
VLOG(3) << "ParameterRecv out" << rpc_ctx.var_name;
}
template struct ParameterRecv<float>;
......
......@@ -89,8 +89,9 @@ bool RequestGetHandler::Handle(const std::string& varname,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name;
VLOG(3) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name;
if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) {
......@@ -115,10 +116,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname)) {
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto& origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto* origin_tensor_data = origin_tensor.data<float>();
......@@ -133,6 +145,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
out_dims, origin_tensor.place());
auto width = dims[1];
for (auto i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width);
}
......
......@@ -44,6 +44,7 @@ class RPCClient {
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetVarNoBarrier(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册