提交 036a90f1 编写于 作者: W Wu Yi 提交者: gongweibao

Refine rpc client wait sync (#11132)

上级 a3858036
...@@ -38,6 +38,25 @@ void RPCClient::Init() { ...@@ -38,6 +38,25 @@ void RPCClient::Init() {
if (rpc_client_.get() == nullptr) { if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new RPCClient()); rpc_client_.reset(new RPCClient());
} }
rpc_client_->InitEventLoop();
}
void RPCClient::InitEventLoop() {
// start the client process thread
// TODO(wuyi): can make this in a threadpool
client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this)));
}
RPCClient::~RPCClient() {
Wait();
cq_.Shutdown();
{
std::lock_guard<std::mutex> guard(chan_mutex_);
for (auto& it : channels_) {
it.second.reset();
}
}
client_thread_->join();
} }
bool RPCClient::AsyncSendVariable(const std::string& ep, bool RPCClient::AsyncSendVariable(const std::string& ep,
...@@ -204,70 +223,37 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { ...@@ -204,70 +223,37 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
req_count_++; req_count_++;
} }
bool RPCClient::Wait() { void RPCClient::Wait() {
VLOG(3) << "RPCClient begin Wait()" std::unique_lock<std::mutex> lk(sync_mutex_);
<< " req_count_:" << req_count_; sync_cond_.wait(lk, [this] { return req_count_ == 0; });
if (req_count_ <= 0) {
return true;
}
const size_t kReqCnt = req_count_;
bool a[kReqCnt];
std::vector<std::future<void>> waits(req_count_);
std::mutex mu;
for (int i = 0; i < req_count_; i++) {
waits[i] = framework::AsyncIO([i, &a, &mu, this] {
bool ret = Proceed();
std::lock_guard<std::mutex> l(mu);
a[i] = ret;
});
}
for (int i = 0; i < req_count_; i++) {
waits[i].wait();
}
int last_req_count = req_count_;
req_count_ = 0;
for (int i = 0; i < last_req_count; i++) {
if (!a[i]) {
return false;
}
}
return true;
} }
bool RPCClient::Proceed() { void RPCClient::Proceed() {
void* tag = NULL; void* tag = nullptr;
bool ok = false; bool ok = false;
// request counts. while (cq_.Next(&tag, &ok)) {
if (!cq_.Next(&tag, &ok)) { BaseProcessor* c = static_cast<BaseProcessor*>(tag);
LOG(ERROR) << "Get meets CompletionQueue error"; GPR_ASSERT(ok);
return false; PADDLE_ENFORCE(c);
} if (c->status_.ok()) {
c->Process();
GPR_ASSERT(ok); } else {
PADDLE_ENFORCE(tag); LOG(ERROR) << "var: " << c->var_h_.String()
<< " grpc error:" << c->status_.error_message();
// TODO(gongwb): add more retries. }
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String()
<< " grpc error:" << c->status_.error_message();
delete c; delete c;
return false; {
std::lock_guard<std::mutex> lk(sync_mutex_);
req_count_--;
}
sync_cond_.notify_all();
} }
c->Process();
delete c;
return true;
} }
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
// TODO(Yancey1989): make grpc client completely thread-safe // TODO(Yancey1989): make grpc client completely thread-safe
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep); auto it = channels_.find(ep);
if (it != channels_.end()) { if (it != channels_.end()) {
return it->second; return it->second;
......
...@@ -16,15 +16,18 @@ limitations under the License. */ ...@@ -16,15 +16,18 @@ limitations under the License. */
#include <time.h> #include <time.h>
#include <chrono> // NOLINT #include <chrono> // NOLINT
#include <condition_variable> // NOLINT
#include <ctime> #include <ctime>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <string> #include <string>
#include <thread> // NOLINT
#include <vector> #include <vector>
#include "grpc++/channel.h"
#include "grpc++/generic/generic_stub.h" #include "grpc++/generic/generic_stub.h"
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
#include "grpc++/support/byte_buffer.h" #include "grpc++/support/byte_buffer.h"
...@@ -164,6 +167,7 @@ class FetchBarrierProcessor : public BaseProcessor { ...@@ -164,6 +167,7 @@ class FetchBarrierProcessor : public BaseProcessor {
class RPCClient { class RPCClient {
public: public:
RPCClient() {} RPCClient() {}
~RPCClient();
static RPCClient* GetInstance(); static RPCClient* GetInstance();
...@@ -192,19 +196,28 @@ class RPCClient { ...@@ -192,19 +196,28 @@ class RPCClient {
void AsyncSendFetchBarrier(const std::string& ep, void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000); int64_t time_out = 600 * 1000);
bool Wait(); void Wait();
// InitEventLoop should only be called by Init()
void InitEventLoop();
private: private:
bool Proceed(); void Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep); std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
// Init is called by GetInstance. // Init is called by GetInstance.
static void Init(); static void Init();
private: private:
grpc::CompletionQueue cq_; grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_; std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
std::unique_ptr<std::thread> client_thread_;
// mutex for Wait client sync
std::mutex sync_mutex_;
std::condition_variable sync_cond_;
std::atomic<int64_t> req_count_{0}; std::atomic<int64_t> req_count_{0};
std::mutex mutex_;
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
static std::unique_ptr<RPCClient> rpc_client_; static std::unique_ptr<RPCClient> rpc_client_;
static std::once_flag init_flag_; static std::once_flag init_flag_;
DISABLE_COPY_AND_ASSIGN(RPCClient); DISABLE_COPY_AND_ASSIGN(RPCClient);
......
...@@ -68,9 +68,7 @@ class RequestSend final : public RequestBase { ...@@ -68,9 +68,7 @@ class RequestSend final : public RequestBase {
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
} }
virtual ~RequestSend() {} virtual ~RequestSend() {}
std::string GetReqName() override { return request_->Varname(); } std::string GetReqName() override { return request_->Varname(); }
void Process() override { void Process() override {
...@@ -82,7 +80,6 @@ class RequestSend final : public RequestBase { ...@@ -82,7 +80,6 @@ class RequestSend final : public RequestBase {
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar); request_handler_->Handle(varname, scope, invar, &outvar);
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
...@@ -125,7 +122,6 @@ class RequestGet final : public RequestBase { ...@@ -125,7 +122,6 @@ class RequestGet final : public RequestBase {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_); &reply_);
} }
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
...@@ -170,10 +166,9 @@ class RequestPrefetch final : public RequestBase { ...@@ -170,10 +166,9 @@ class RequestPrefetch final : public RequestBase {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_); &reply_);
status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
status_ = FINISH;
} }
protected: protected:
......
...@@ -113,10 +113,6 @@ void StartServer() { ...@@ -113,10 +113,6 @@ void StartServer() {
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get())); std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
// FIXME(gongwb): don't use hard time.
sleep(10);
LOG(INFO) << "got nccl id and stop server...";
g_rpc_service->ShutDown();
server_thread.join(); server_thread.join();
} }
...@@ -127,7 +123,7 @@ TEST(PREFETCH, CPU) { ...@@ -127,7 +123,7 @@ TEST(PREFETCH, CPU) {
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
detail::RPCClient client; detail::RPCClient* client = detail::RPCClient::GetInstance();
int port = g_rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
...@@ -141,8 +137,8 @@ TEST(PREFETCH, CPU) { ...@@ -141,8 +137,8 @@ TEST(PREFETCH, CPU) {
std::string in_var_name("ids"); std::string in_var_name("ids");
std::string out_var_name("out"); std::string out_var_name("out");
client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name); client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
client.Wait(); client->Wait();
auto var = scope.Var(out_var_name); auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value(); auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place); auto ptr = value.mutable_data<float>(place);
...@@ -152,6 +148,7 @@ TEST(PREFETCH, CPU) { ...@@ -152,6 +148,7 @@ TEST(PREFETCH, CPU) {
} }
} }
g_rpc_service->ShutDown();
server_thread.join(); server_thread.join();
LOG(INFO) << "begin reset"; LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr); g_rpc_service.reset(nullptr);
......
...@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
auto rpc_client = detail::RPCClient::GetInstance(); auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE(rpc_client->Wait()); rpc_client->Wait();
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep; VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep); rpc_client->AsyncSendFetchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); rpc_client->Wait();
} }
}; };
......
...@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
PADDLE_ENFORCE(rpc_client->Wait()); rpc_client->Wait();
} }
}; };
......
...@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
if (sync_mode) { if (sync_mode) {
PADDLE_ENFORCE(rpc_client->Wait()); rpc_client->Wait();
} }
} }
}; };
......
...@@ -49,13 +49,13 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -49,13 +49,13 @@ class SendBarrierOp : public framework::OperatorBase {
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
// need to wait before sending send_barrier message // need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait()); rpc_client->Wait();
if (sync_mode) { if (sync_mode) {
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep; VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); rpc_client->Wait();
} }
} }
}; };
......
...@@ -59,14 +59,14 @@ class SendOp : public framework::OperatorBase { ...@@ -59,14 +59,14 @@ class SendOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
PADDLE_ENFORCE(rpc_client->Wait()); rpc_client->Wait();
if (sync_mode) { if (sync_mode) {
for (auto& ep : endpoints) { for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep; VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); rpc_client->Wait();
} }
if (outs.size() > 0) { if (outs.size() > 0) {
...@@ -74,13 +74,13 @@ class SendOp : public framework::OperatorBase { ...@@ -74,13 +74,13 @@ class SendOp : public framework::OperatorBase {
VLOG(2) << "getting " << outs[i] << " from " << epmap[i]; VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
PADDLE_ENFORCE(rpc_client->Wait()); rpc_client->Wait();
// tell pservers that current trainer have called fetch // tell pservers that current trainer have called fetch
for (auto& ep : endpoints) { for (auto& ep : endpoints) {
VLOG(2) << "send fetch barrier, ep: " << ep; VLOG(2) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep); rpc_client->AsyncSendFetchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); rpc_client->Wait();
} }
} }
}; };
......
...@@ -61,7 +61,6 @@ void StartServer() { ...@@ -61,7 +61,6 @@ void StartServer() {
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get())); std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(detail::kRequestSend); g_rpc_service->SetCond(detail::kRequestSend);
std::cout << "before WaitFanInOfSend" << std::endl;
g_rpc_service->WaitBarrier(detail::kRequestSend); g_rpc_service->WaitBarrier(detail::kRequestSend);
LOG(INFO) << "got nccl id and stop server..."; LOG(INFO) << "got nccl id and stop server...";
...@@ -88,12 +87,12 @@ TEST(SendNcclId, GrpcServer) { ...@@ -88,12 +87,12 @@ TEST(SendNcclId, GrpcServer) {
int port = g_rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port); std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient client; detail::RPCClient* client = detail::RPCClient::GetInstance();
LOG(INFO) << "connect to server" << ep; LOG(INFO) << "connect to server " << ep;
client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); client->AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client.Wait(); client->Wait();
client.AsyncSendBatchBarrier(ep); client->AsyncSendBatchBarrier(ep);
client.Wait(); client->Wait();
server_thread.join(); server_thread.join();
g_rpc_service.reset(nullptr); g_rpc_service.reset(nullptr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册