未验证 提交 535fefb7 编写于 作者: G gongweibao 提交者: GitHub

Fix grpc bugs (#7435)

Fix grpc bugs
上级 448fee3d
......@@ -33,7 +33,7 @@ ExternalProject_Add(
extern_grpc
DEPENDS protobuf zlib
GIT_REPOSITORY "https://github.com/grpc/grpc.git"
GIT_TAG "v1.7.x"
GIT_TAG "v1.8.x"
PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
......
......@@ -87,7 +87,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}
bool RPCClient::wait() {
bool RPCClient::Wait() {
bool ok = true;
while (true) {
......@@ -96,7 +96,6 @@ bool RPCClient::wait() {
}
if (!Proceed()) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
}
......@@ -110,9 +109,9 @@ bool RPCClient::Proceed() {
// request counts.
if (!cq_.Next(&tag, &ok)) {
LOG(ERROR) << "Get meets CompletionQueue error";
return false;
}
req_count_--;
GPR_ASSERT(ok);
PADDLE_ENFORCE(tag);
......@@ -120,12 +119,15 @@ bool RPCClient::Proceed() {
// TODO(gongwb): add more retries.
ClientBase* c = static_cast<ClientBase*>(tag);
if (!c->status_.ok()) {
LOG(ERROR) << "proc param error:" << c->var_h_.String()
<< " grpc error:" << c->status_.error_message();
delete c;
return true;
return false;
}
c->Process();
delete c;
req_count_--;
return true;
}
......@@ -135,8 +137,12 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
return it->second;
}
grpc::ChannelArguments args;
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
auto ch = std::shared_ptr<grpc::Channel>(
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()));
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args));
channels_[ep] = ch;
return ch;
......
......@@ -130,7 +130,7 @@ class RPCClient {
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool wait();
bool Wait();
private:
bool Proceed();
......
......@@ -28,12 +28,15 @@ class RequestBase {
public:
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq)
: service_(service), cq_(cq), status_(PROCESS) {}
: service_(service), cq_(cq), status_(PROCESS) {
PADDLE_ENFORCE(cq_);
}
virtual ~RequestBase() {}
virtual void Process() { assert(false); }
CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { assert(false); }
protected:
grpc::ServerContext ctx_;
......@@ -56,12 +59,14 @@ class RequestSend final : public RequestBase {
virtual ~RequestSend() {}
virtual std::string GetReqName() { return request_.varname(); }
virtual void Process() {
MessageWithName msg_with_name =
std::make_pair(request_.varname(), std::move(request_));
queue_->Push(std::move(msg_with_name));
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
}
protected:
......@@ -81,6 +86,8 @@ class RequestGet final : public RequestBase {
virtual ~RequestGet() {}
virtual std::string GetReqName() { return request_.varname(); }
virtual void Process() {
// proc request.
std::string var_name = request_.varname();
......@@ -88,6 +95,7 @@ class RequestGet final : public RequestBase {
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
status_ = FINISH;
}
protected:
......@@ -100,6 +108,8 @@ class RequestGet final : public RequestBase {
void AsyncGRPCServer::RunSyncUpdate() {
grpc::ServerBuilder builder;
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_);
cq_send_ = builder.AddCompletionQueue();
......@@ -159,18 +169,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(4) << "create Requestget status:" << get->Status();
}
void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
delete last;
last = NULL;
return;
}
last->SetStatus(FINISH);
return;
}
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
......@@ -184,13 +182,19 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
break;
}
PADDLE_ENFORCE(tag);
if (wait && !done_) {
Wait();
}
RequestBase* base = (RequestBase*)tag;
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) {
VLOG(4) << cq_name << " recv no regular event";
LOG(WARNING) << cq_name << " recv no regular event:argument name"
<< base->GetReqName();
TryToRegisterNewOne();
delete base;
continue;
......@@ -201,7 +205,6 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
VLOG(4) << cq_name << " status:" << base->Status();
TryToRegisterNewOne();
base->Process();
SetFinishOrDelete(base);
break;
}
case FINISH: {
......
......@@ -60,7 +60,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
void SetFinishOrDelete(RequestBase *&last);
void ShutdownQueue();
private:
......
......@@ -96,6 +96,8 @@ class RecvOp : public framework::OperatorBase {
rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
VLOG(4) << "param_count:" << param_count
<< " trainer_count:" << trainer_count;
while (!exit_flag) {
// TODO(gognwb): simply this loop.
// Get from multiple trainers, we don't care about order in which
......
......@@ -48,7 +48,7 @@ class SendOp : public framework::OperatorBase {
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
client_.wait();
PADDLE_ENFORCE(client_.Wait());
}
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册