未验证 提交 535fefb7 编写于 作者: G gongweibao 提交者: GitHub

Fix grpc bugs (#7435)

Fix grpc bugs
上级 448fee3d
...@@ -33,7 +33,7 @@ ExternalProject_Add( ...@@ -33,7 +33,7 @@ ExternalProject_Add(
extern_grpc extern_grpc
DEPENDS protobuf zlib DEPENDS protobuf zlib
GIT_REPOSITORY "https://github.com/grpc/grpc.git" GIT_REPOSITORY "https://github.com/grpc/grpc.git"
GIT_TAG "v1.7.x" GIT_TAG "v1.8.x"
PREFIX ${GRPC_SOURCES_DIR} PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -87,7 +87,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, ...@@ -87,7 +87,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true; return true;
} }
bool RPCClient::wait() { bool RPCClient::Wait() {
bool ok = true; bool ok = true;
while (true) { while (true) {
...@@ -96,7 +96,6 @@ bool RPCClient::wait() { ...@@ -96,7 +96,6 @@ bool RPCClient::wait() {
} }
if (!Proceed()) { if (!Proceed()) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false; return false;
} }
} }
...@@ -110,9 +109,9 @@ bool RPCClient::Proceed() { ...@@ -110,9 +109,9 @@ bool RPCClient::Proceed() {
// request counts. // request counts.
if (!cq_.Next(&tag, &ok)) { if (!cq_.Next(&tag, &ok)) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false; return false;
} }
req_count_--;
GPR_ASSERT(ok); GPR_ASSERT(ok);
PADDLE_ENFORCE(tag); PADDLE_ENFORCE(tag);
...@@ -120,12 +119,15 @@ bool RPCClient::Proceed() { ...@@ -120,12 +119,15 @@ bool RPCClient::Proceed() {
// TODO(gongwb): add more retries. // TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag); ClientBase* c = static_cast<ClientBase*>(tag);
if (!c->status_.ok()) { if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String()
<< " grpc error:" << c->status_.error_message();
delete c; delete c;
return true; return false;
} }
c->Process(); c->Process();
delete c; delete c;
req_count_--;
return true; return true;
} }
...@@ -135,8 +137,12 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { ...@@ -135,8 +137,12 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
return it->second; return it->second;
} }
grpc::ChannelArguments args;
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
auto ch = std::shared_ptr<grpc::Channel>( auto ch = std::shared_ptr<grpc::Channel>(
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials())); grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args));
channels_[ep] = ch; channels_[ep] = ch;
return ch; return ch;
......
...@@ -130,7 +130,7 @@ class RPCClient { ...@@ -130,7 +130,7 @@ class RPCClient {
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
int64_t time_out = 600 * 1000); int64_t time_out = 600 * 1000);
bool wait(); bool Wait();
private: private:
bool Proceed(); bool Proceed();
......
...@@ -28,12 +28,15 @@ class RequestBase { ...@@ -28,12 +28,15 @@ class RequestBase {
public: public:
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service, explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq) grpc::ServerCompletionQueue* cq)
: service_(service), cq_(cq), status_(PROCESS) {} : service_(service), cq_(cq), status_(PROCESS) {
PADDLE_ENFORCE(cq_);
}
virtual ~RequestBase() {} virtual ~RequestBase() {}
virtual void Process() { assert(false); } virtual void Process() { assert(false); }
CallStatus Status() { return status_; } CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; } void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { assert(false); }
protected: protected:
grpc::ServerContext ctx_; grpc::ServerContext ctx_;
...@@ -56,12 +59,14 @@ class RequestSend final : public RequestBase { ...@@ -56,12 +59,14 @@ class RequestSend final : public RequestBase {
virtual ~RequestSend() {} virtual ~RequestSend() {}
virtual std::string GetReqName() { return request_.varname(); }
virtual void Process() { virtual void Process() {
MessageWithName msg_with_name = MessageWithName msg_with_name =
std::make_pair(request_.varname(), std::move(request_)); std::make_pair(request_.varname(), std::move(request_));
queue_->Push(std::move(msg_with_name)); queue_->Push(std::move(msg_with_name));
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this); responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
} }
protected: protected:
...@@ -81,6 +86,8 @@ class RequestGet final : public RequestBase { ...@@ -81,6 +86,8 @@ class RequestGet final : public RequestBase {
virtual ~RequestGet() {} virtual ~RequestGet() {}
virtual std::string GetReqName() { return request_.varname(); }
virtual void Process() { virtual void Process() {
// proc request. // proc request.
std::string var_name = request_.varname(); std::string var_name = request_.varname();
...@@ -88,6 +95,7 @@ class RequestGet final : public RequestBase { ...@@ -88,6 +95,7 @@ class RequestGet final : public RequestBase {
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_); SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
// TODO(gongwb): check var's info. // TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this); responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
} }
protected: protected:
...@@ -100,6 +108,8 @@ class RequestGet final : public RequestBase { ...@@ -100,6 +108,8 @@ class RequestGet final : public RequestBase {
void AsyncGRPCServer::RunSyncUpdate() { void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder; grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_); builder.RegisterService(&service_);
cq_send_ = builder.AddCompletionQueue(); cq_send_ = builder.AddCompletionQueue();
...@@ -159,18 +169,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -159,18 +169,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(4) << "create Requestget status:" << get->Status(); VLOG(4) << "create Requestget status:" << get->Status();
} }
void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
delete last;
last = NULL;
return;
}
last->SetStatus(FINISH);
return;
}
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name, std::string cq_name,
std::function<void()> TryToRegisterNewOne) { std::function<void()> TryToRegisterNewOne) {
...@@ -184,13 +182,19 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, ...@@ -184,13 +182,19 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
break; break;
} }
PADDLE_ENFORCE(tag);
if (wait && !done_) { if (wait && !done_) {
Wait(); Wait();
} }
RequestBase* base = (RequestBase*)tag; RequestBase* base = (RequestBase*)tag;
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) { if (!ok) {
VLOG(4) << cq_name << " recv no regular event"; LOG(WARNING) << cq_name << " recv no regular event:argument name"
<< base->GetReqName();
TryToRegisterNewOne(); TryToRegisterNewOne();
delete base; delete base;
continue; continue;
...@@ -201,7 +205,6 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, ...@@ -201,7 +205,6 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
VLOG(4) << cq_name << " status:" << base->Status(); VLOG(4) << cq_name << " status:" << base->Status();
TryToRegisterNewOne(); TryToRegisterNewOne();
base->Process(); base->Process();
SetFinishOrDelete(base);
break; break;
} }
case FINISH: { case FINISH: {
......
...@@ -60,7 +60,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -60,7 +60,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
std::function<void()> TryToRegisterNewOne); std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne(); void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne(); void TryToRegisterNewGetOne();
void SetFinishOrDelete(RequestBase *&last);
void ShutdownQueue(); void ShutdownQueue();
private: private:
......
...@@ -96,6 +96,8 @@ class RecvOp : public framework::OperatorBase { ...@@ -96,6 +96,8 @@ class RecvOp : public framework::OperatorBase {
rpc_service_->Reset(); rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false; bool exit_flag = false;
VLOG(4) << "param_count:" << param_count
<< " trainer_count:" << trainer_count;
while (!exit_flag) { while (!exit_flag) {
// TODO(gongwb): simplify this loop. // TODO(gongwb): simplify this loop.
// Get from multiple trainers, we don't care about order in which // Get from multiple trainers, we don't care about order in which
......
...@@ -48,7 +48,7 @@ class SendOp : public framework::OperatorBase { ...@@ -48,7 +48,7 @@ class SendOp : public framework::OperatorBase {
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
client_.wait(); PADDLE_ENFORCE(client_.Wait());
} }
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册