提交 11fe3c79 编写于 作者: X Xin Pan

clean up

上级 b4dd4c04
......@@ -204,7 +204,7 @@ def main():
with profiler.profiler('All', 'total',
'/tmp/profile_vgg_%d' % args.task_index):
for batch_id, data in enumerate(train_reader()):
if batch_id > 4: break
if batch_id > 5: break
run_step(batch_id, data)
total_time = 0.0
......
......@@ -33,7 +33,7 @@ ExternalProject_Add(
extern_grpc
DEPENDS protobuf zlib
GIT_REPOSITORY "https://github.com/grpc/grpc.git"
GIT_TAG "v1.8.x"
GIT_TAG "v1.10.x"
PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
......
......@@ -66,11 +66,11 @@ class RequestSend final : public RequestBase {
explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx, int i)
const platform::DeviceContext* dev_ctx, int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx),
queue_(queue),
responder_(&ctx_),
i_(i) {
req_id_(req_id) {
if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else {
......@@ -79,7 +79,7 @@ class RequestSend final : public RequestBase {
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(i)));
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestSend() {}
......@@ -93,7 +93,7 @@ class RequestSend final : public RequestBase {
status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(i_)));
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
}
protected:
......@@ -101,7 +101,7 @@ class RequestSend final : public RequestBase {
std::shared_ptr<VariableResponse> request_;
ReceivedQueue* queue_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
int i_;
int req_id_;
};
class RequestGet final : public RequestBase {
......@@ -110,16 +110,17 @@ class RequestGet final : public RequestBase {
::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
framework::BlockingQueue<MessageWithName>* queue, int i)
framework::BlockingQueue<MessageWithName>* queue,
int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_),
scope_(scope),
queue_(queue),
i_(i) {
req_id_(req_id) {
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(i)));
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
}
virtual ~RequestGet() {}
......@@ -138,7 +139,7 @@ class RequestGet final : public RequestBase {
status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(i_)));
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
if (var_name == FETCH_BARRIER_MESSAGE) {
sendrecv::VariableMessage msg;
......@@ -153,7 +154,7 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
framework::BlockingQueue<MessageWithName>* queue_;
int i_;
int req_id_;
};
class RequestPrefetch final : public RequestBase {
......@@ -165,14 +166,14 @@ class RequestPrefetch final : public RequestBase {
framework::Executor* executor,
framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx,
int i)
int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_),
scope_(scope),
executor_(executor),
program_(program),
prefetch_ctx_(prefetch_ctx),
i_(i) {
req_id_(req_id) {
if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else {
......@@ -202,7 +203,7 @@ class RequestPrefetch final : public RequestBase {
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
responder_.Finish(reply, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(i_)));
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
status_ = FINISH;
}
......@@ -213,7 +214,7 @@ class RequestPrefetch final : public RequestBase {
framework::Executor* executor_;
framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
int i_;
int req_id_;
};
void AsyncGRPCServer::WaitClientGet(int count) {
......@@ -291,21 +292,6 @@ void AsyncGRPCServer::RunSyncUpdate() {
for (int i = 0; i < kNumHandleGetThreads; ++i) {
t_gets_[i]->join();
}
{
std::lock_guard<std::mutex> l(cq_mutex_);
for (int i = 0; i < kSendReqsBufSize; ++i) {
if (send_reqs_[i]) {
delete send_reqs_[i];
send_reqs_[i] = nullptr;
}
}
for (int i = 0; i < kGetReqsBufSize; ++i) {
if (get_reqs_[i]) {
delete get_reqs_[i];
get_reqs_[i] = nullptr;
}
}
}
t_prefetch_->join();
}
......@@ -335,19 +321,19 @@ void AsyncGRPCServer::TryToRegisterNewSendOne(int i) {
VLOG(4) << "Create RequestSend status:" << send->Status();
}
void AsyncGRPCServer::TryToRegisterNewGetOne(int i) {
void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
dev_ctx_, &var_get_queue_, i);
get_reqs_[i] = static_cast<RequestBase*>(get);
dev_ctx_, &var_get_queue_, req_id);
get_reqs_[req_id] = static_cast<RequestBase*>(get);
VLOG(4) << "Create RequestGet status:" << get->Status();
}
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) {
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
......@@ -355,7 +341,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) {
}
RequestPrefetch* prefetch = new RequestPrefetch(
&service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
program_, prefetch_ctx_.get(), i);
program_, prefetch_ctx_.get(), req_id);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
......@@ -374,7 +360,7 @@ void AsyncGRPCServer::HandleRequest(
break;
}
VLOG(3) << "HandleRequest for " << cq_name << " get Next";
int i = static_cast<int>(reinterpret_cast<intptr_t>(tag));
int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
if (sync_mode_) {
// FIXME(typhoonzero): de-couple the barriers with recv_op
......@@ -387,9 +373,9 @@ void AsyncGRPCServer::HandleRequest(
{
std::lock_guard<std::mutex> l(cq_mutex_);
if (cq_name == "cq_get") {
base = get_reqs_[i];
base = get_reqs_[req_id];
} else if (cq_name == "cq_send") {
base = send_reqs_[i];
base = send_reqs_[req_id];
} else {
CHECK(false);
}
......@@ -401,7 +387,7 @@ void AsyncGRPCServer::HandleRequest(
if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event:argument name["
<< base->GetReqName() << "]";
TryToRegisterNewOne(i);
TryToRegisterNewOne(req_id);
delete base;
continue;
}
......@@ -413,7 +399,7 @@ void AsyncGRPCServer::HandleRequest(
break;
}
case FINISH: {
TryToRegisterNewOne(i);
TryToRegisterNewOne(req_id);
VLOG(4) << cq_name << " FINISH status:" << base->Status();
delete base;
break;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册