提交 036a90f1 编写于 作者: W Wu Yi 提交者: gongweibao

Refine rpc client wait sync (#11132)

上级 a3858036
......@@ -38,6 +38,25 @@ void RPCClient::Init() {
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new RPCClient());
}
rpc_client_->InitEventLoop();
}
void RPCClient::InitEventLoop() {
// start the client process thread
// TODO(wuyi): can make this in a threadpool
client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this)));
}
RPCClient::~RPCClient() {
Wait();
cq_.Shutdown();
{
std::lock_guard<std::mutex> guard(chan_mutex_);
for (auto& it : channels_) {
it.second.reset();
}
}
client_thread_->join();
}
bool RPCClient::AsyncSendVariable(const std::string& ep,
......@@ -204,70 +223,37 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
req_count_++;
}
bool RPCClient::Wait() {
VLOG(3) << "RPCClient begin Wait()"
<< " req_count_:" << req_count_;
if (req_count_ <= 0) {
return true;
}
const size_t kReqCnt = req_count_;
bool a[kReqCnt];
std::vector<std::future<void>> waits(req_count_);
std::mutex mu;
for (int i = 0; i < req_count_; i++) {
waits[i] = framework::AsyncIO([i, &a, &mu, this] {
bool ret = Proceed();
std::lock_guard<std::mutex> l(mu);
a[i] = ret;
});
}
for (int i = 0; i < req_count_; i++) {
waits[i].wait();
}
int last_req_count = req_count_;
req_count_ = 0;
for (int i = 0; i < last_req_count; i++) {
if (!a[i]) {
return false;
}
}
return true;
void RPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
}
bool RPCClient::Proceed() {
void* tag = NULL;
void RPCClient::Proceed() {
void* tag = nullptr;
bool ok = false;
// request counts.
if (!cq_.Next(&tag, &ok)) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
GPR_ASSERT(ok);
PADDLE_ENFORCE(tag);
// TODO(gongwb): add more retries.
while (cq_.Next(&tag, &ok)) {
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String()
GPR_ASSERT(ok);
PADDLE_ENFORCE(c);
if (c->status_.ok()) {
c->Process();
} else {
LOG(ERROR) << "var: " << c->var_h_.String()
<< " grpc error:" << c->status_.error_message();
delete c;
return false;
}
c->Process();
delete c;
return true;
{
std::lock_guard<std::mutex> lk(sync_mutex_);
req_count_--;
}
sync_cond_.notify_all();
}
}
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
// TODO(Yancey1989): make grpc client completely thread-safe
std::unique_lock<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep);
if (it != channels_.end()) {
return it->second;
......
......@@ -17,14 +17,17 @@ limitations under the License. */
#include <time.h>
#include <chrono> // NOLINT
#include <condition_variable> // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "grpc++/channel.h"
#include "grpc++/generic/generic_stub.h"
#include "grpc++/grpc++.h"
#include "grpc++/support/byte_buffer.h"
......@@ -164,6 +167,7 @@ class FetchBarrierProcessor : public BaseProcessor {
class RPCClient {
public:
RPCClient() {}
~RPCClient();
static RPCClient* GetInstance();
......@@ -192,19 +196,28 @@ class RPCClient {
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
bool Wait();
void Wait();
// InitEventLoop should only be called by Init()
void InitEventLoop();
private:
bool Proceed();
void Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
// Init is called by GetInstance.
static void Init();
private:
grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
std::unique_ptr<std::thread> client_thread_;
// mutex for Wait client sync
std::mutex sync_mutex_;
std::condition_variable sync_cond_;
std::atomic<int64_t> req_count_{0};
std::mutex mutex_;
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
static std::unique_ptr<RPCClient> rpc_client_;
static std::once_flag init_flag_;
DISABLE_COPY_AND_ASSIGN(RPCClient);
......
......@@ -68,9 +68,7 @@ class RequestSend final : public RequestBase {
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestSend() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
......@@ -82,7 +80,6 @@ class RequestSend final : public RequestBase {
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar);
status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
......@@ -125,7 +122,6 @@ class RequestGet final : public RequestBase {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
......@@ -170,10 +166,9 @@ class RequestPrefetch final : public RequestBase {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_);
status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
status_ = FINISH;
}
protected:
......
......@@ -113,10 +113,6 @@ void StartServer() {
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
// FIXME(gongwb): don't use hard time.
sleep(10);
LOG(INFO) << "got nccl id and stop server...";
g_rpc_service->ShutDown();
server_thread.join();
}
......@@ -127,7 +123,7 @@ TEST(PREFETCH, CPU) {
std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady();
detail::RPCClient client;
detail::RPCClient* client = detail::RPCClient::GetInstance();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
......@@ -141,8 +137,8 @@ TEST(PREFETCH, CPU) {
std::string in_var_name("ids");
std::string out_var_name("out");
client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
client.Wait();
client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place);
......@@ -152,6 +148,7 @@ TEST(PREFETCH, CPU) {
}
}
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
......
......@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE(rpc_client->Wait());
rpc_client->Wait();
for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
rpc_client->Wait();
}
};
......
......@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
PADDLE_ENFORCE(rpc_client->Wait());
rpc_client->Wait();
}
};
......
......@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
if (sync_mode) {
PADDLE_ENFORCE(rpc_client->Wait());
rpc_client->Wait();
}
}
};
......
......@@ -49,13 +49,13 @@ class SendBarrierOp : public framework::OperatorBase {
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait());
rpc_client->Wait();
if (sync_mode) {
for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
rpc_client->Wait();
}
}
};
......
......@@ -59,14 +59,14 @@ class SendOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
PADDLE_ENFORCE(rpc_client->Wait());
rpc_client->Wait();
if (sync_mode) {
for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
rpc_client->Wait();
}
if (outs.size() > 0) {
......@@ -74,13 +74,13 @@ class SendOp : public framework::OperatorBase {
VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
PADDLE_ENFORCE(rpc_client->Wait());
rpc_client->Wait();
// tell pservers that current trainer have called fetch
for (auto& ep : endpoints) {
VLOG(2) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
rpc_client->Wait();
}
}
};
......
......@@ -61,7 +61,6 @@ void StartServer() {
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(detail::kRequestSend);
std::cout << "before WaitFanInOfSend" << std::endl;
g_rpc_service->WaitBarrier(detail::kRequestSend);
LOG(INFO) << "got nccl id and stop server...";
......@@ -88,12 +87,12 @@ TEST(SendNcclId, GrpcServer) {
int port = g_rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient client;
LOG(INFO) << "connect to server" << ep;
client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client.Wait();
client.AsyncSendBatchBarrier(ep);
client.Wait();
detail::RPCClient* client = detail::RPCClient::GetInstance();
LOG(INFO) << "connect to server " << ep;
client->AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client->Wait();
client->AsyncSendBatchBarrier(ep);
client->Wait();
server_thread.join();
g_rpc_service.reset(nullptr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册