提交 11fe3c79 编写于 作者: X Xin Pan

clean up

上级 b4dd4c04
...@@ -204,7 +204,7 @@ def main(): ...@@ -204,7 +204,7 @@ def main():
with profiler.profiler('All', 'total', with profiler.profiler('All', 'total',
'/tmp/profile_vgg_%d' % args.task_index): '/tmp/profile_vgg_%d' % args.task_index):
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
if batch_id > 4: break if batch_id > 5: break
run_step(batch_id, data) run_step(batch_id, data)
total_time = 0.0 total_time = 0.0
......
...@@ -33,7 +33,7 @@ ExternalProject_Add( ...@@ -33,7 +33,7 @@ ExternalProject_Add(
extern_grpc extern_grpc
DEPENDS protobuf zlib DEPENDS protobuf zlib
GIT_REPOSITORY "https://github.com/grpc/grpc.git" GIT_REPOSITORY "https://github.com/grpc/grpc.git"
GIT_TAG "v1.8.x" GIT_TAG "v1.10.x"
PREFIX ${GRPC_SOURCES_DIR} PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -66,11 +66,11 @@ class RequestSend final : public RequestBase { ...@@ -66,11 +66,11 @@ class RequestSend final : public RequestBase {
explicit RequestSend(GrpcService::AsyncService* service, explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, ReceivedQueue* queue, framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx, int i) const platform::DeviceContext* dev_ctx, int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx), : RequestBase(service, cq, sync_mode, dev_ctx),
queue_(queue), queue_(queue),
responder_(&ctx_), responder_(&ctx_),
i_(i) { req_id_(req_id) {
if (sync_mode_) { if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false)); request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else { } else {
...@@ -79,7 +79,7 @@ class RequestSend final : public RequestBase { ...@@ -79,7 +79,7 @@ class RequestSend final : public RequestBase {
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(i))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
} }
virtual ~RequestSend() {} virtual ~RequestSend() {}
...@@ -93,7 +93,7 @@ class RequestSend final : public RequestBase { ...@@ -93,7 +93,7 @@ class RequestSend final : public RequestBase {
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(i_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
protected: protected:
...@@ -101,7 +101,7 @@ class RequestSend final : public RequestBase { ...@@ -101,7 +101,7 @@ class RequestSend final : public RequestBase {
std::shared_ptr<VariableResponse> request_; std::shared_ptr<VariableResponse> request_;
ReceivedQueue* queue_; ReceivedQueue* queue_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
int i_; int req_id_;
}; };
class RequestGet final : public RequestBase { class RequestGet final : public RequestBase {
...@@ -110,16 +110,17 @@ class RequestGet final : public RequestBase { ...@@ -110,16 +110,17 @@ class RequestGet final : public RequestBase {
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
framework::BlockingQueue<MessageWithName>* queue, int i) framework::BlockingQueue<MessageWithName>* queue,
int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx), : RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
queue_(queue), queue_(queue),
i_(i) { req_id_(req_id) {
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable); auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_, method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(i))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
virtual ~RequestGet() {} virtual ~RequestGet() {}
...@@ -138,7 +139,7 @@ class RequestGet final : public RequestBase { ...@@ -138,7 +139,7 @@ class RequestGet final : public RequestBase {
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(i_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
if (var_name == FETCH_BARRIER_MESSAGE) { if (var_name == FETCH_BARRIER_MESSAGE) {
sendrecv::VariableMessage msg; sendrecv::VariableMessage msg;
...@@ -153,7 +154,7 @@ class RequestGet final : public RequestBase { ...@@ -153,7 +154,7 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_; framework::Scope* scope_;
framework::BlockingQueue<MessageWithName>* queue_; framework::BlockingQueue<MessageWithName>* queue_;
int i_; int req_id_;
}; };
class RequestPrefetch final : public RequestBase { class RequestPrefetch final : public RequestBase {
...@@ -165,14 +166,14 @@ class RequestPrefetch final : public RequestBase { ...@@ -165,14 +166,14 @@ class RequestPrefetch final : public RequestBase {
framework::Executor* executor, framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx, framework::ExecutorPrepareContext* prefetch_ctx,
int i) int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx), : RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
executor_(executor), executor_(executor),
program_(program), program_(program),
prefetch_ctx_(prefetch_ctx), prefetch_ctx_(prefetch_ctx),
i_(i) { req_id_(req_id) {
if (sync_mode_) { if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false)); request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else { } else {
...@@ -202,7 +203,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -202,7 +203,7 @@ class RequestPrefetch final : public RequestBase {
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
responder_.Finish(reply, ::grpc::Status::OK, responder_.Finish(reply, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(i_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
status_ = FINISH; status_ = FINISH;
} }
...@@ -213,7 +214,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -213,7 +214,7 @@ class RequestPrefetch final : public RequestBase {
framework::Executor* executor_; framework::Executor* executor_;
framework::ProgramDesc* program_; framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_; framework::ExecutorPrepareContext* prefetch_ctx_;
int i_; int req_id_;
}; };
void AsyncGRPCServer::WaitClientGet(int count) { void AsyncGRPCServer::WaitClientGet(int count) {
...@@ -291,21 +292,6 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -291,21 +292,6 @@ void AsyncGRPCServer::RunSyncUpdate() {
for (int i = 0; i < kNumHandleGetThreads; ++i) { for (int i = 0; i < kNumHandleGetThreads; ++i) {
t_gets_[i]->join(); t_gets_[i]->join();
} }
{
std::lock_guard<std::mutex> l(cq_mutex_);
for (int i = 0; i < kSendReqsBufSize; ++i) {
if (send_reqs_[i]) {
delete send_reqs_[i];
send_reqs_[i] = nullptr;
}
}
for (int i = 0; i < kGetReqsBufSize; ++i) {
if (get_reqs_[i]) {
delete get_reqs_[i];
get_reqs_[i] = nullptr;
}
}
}
t_prefetch_->join(); t_prefetch_->join();
} }
...@@ -335,19 +321,19 @@ void AsyncGRPCServer::TryToRegisterNewSendOne(int i) { ...@@ -335,19 +321,19 @@ void AsyncGRPCServer::TryToRegisterNewSendOne(int i) {
VLOG(4) << "Create RequestSend status:" << send->Status(); VLOG(4) << "Create RequestSend status:" << send->Status();
} }
void AsyncGRPCServer::TryToRegisterNewGetOne(int i) { void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return; return;
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
dev_ctx_, &var_get_queue_, i); dev_ctx_, &var_get_queue_, req_id);
get_reqs_[i] = static_cast<RequestBase*>(get); get_reqs_[req_id] = static_cast<RequestBase*>(get);
VLOG(4) << "Create RequestGet status:" << get->Status(); VLOG(4) << "Create RequestGet status:" << get->Status();
} }
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) { void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne"; VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
...@@ -355,7 +341,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) { ...@@ -355,7 +341,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) {
} }
RequestPrefetch* prefetch = new RequestPrefetch( RequestPrefetch* prefetch = new RequestPrefetch(
&service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_, &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
program_, prefetch_ctx_.get(), i); program_, prefetch_ctx_.get(), req_id);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
} }
...@@ -374,7 +360,7 @@ void AsyncGRPCServer::HandleRequest( ...@@ -374,7 +360,7 @@ void AsyncGRPCServer::HandleRequest(
break; break;
} }
VLOG(3) << "HandleRequest for " << cq_name << " get Next"; VLOG(3) << "HandleRequest for " << cq_name << " get Next";
int i = static_cast<int>(reinterpret_cast<intptr_t>(tag)); int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
if (sync_mode_) { if (sync_mode_) {
// FIXME(typhoonzero): de-couple the barriers with recv_op // FIXME(typhoonzero): de-couple the barriers with recv_op
...@@ -387,9 +373,9 @@ void AsyncGRPCServer::HandleRequest( ...@@ -387,9 +373,9 @@ void AsyncGRPCServer::HandleRequest(
{ {
std::lock_guard<std::mutex> l(cq_mutex_); std::lock_guard<std::mutex> l(cq_mutex_);
if (cq_name == "cq_get") { if (cq_name == "cq_get") {
base = get_reqs_[i]; base = get_reqs_[req_id];
} else if (cq_name == "cq_send") { } else if (cq_name == "cq_send") {
base = send_reqs_[i]; base = send_reqs_[req_id];
} else { } else {
CHECK(false); CHECK(false);
} }
...@@ -401,7 +387,7 @@ void AsyncGRPCServer::HandleRequest( ...@@ -401,7 +387,7 @@ void AsyncGRPCServer::HandleRequest(
if (!ok) { if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event:argument name[" LOG(WARNING) << cq_name << " recv no regular event:argument name["
<< base->GetReqName() << "]"; << base->GetReqName() << "]";
TryToRegisterNewOne(i); TryToRegisterNewOne(req_id);
delete base; delete base;
continue; continue;
} }
...@@ -413,7 +399,7 @@ void AsyncGRPCServer::HandleRequest( ...@@ -413,7 +399,7 @@ void AsyncGRPCServer::HandleRequest(
break; break;
} }
case FINISH: { case FINISH: {
TryToRegisterNewOne(i); TryToRegisterNewOne(req_id);
VLOG(4) << cq_name << " FINISH status:" << base->Status(); VLOG(4) << cq_name << " FINISH status:" << base->Status();
delete base; delete base;
break; break;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册