提交 722c078b 编写于 作者: X Xin Pan

fix test and clean up

上级 11fe3c79
......@@ -25,6 +25,7 @@ namespace detail {
namespace {
const int kNumHandleSendThreads = 20;
const int kNumHandleGetThreads = 20;
const int kNumHandlePrefetchThreads = 1;
} // namespace
enum CallStatus { PROCESS = 0, FINISH };
......@@ -180,8 +181,9 @@ class RequestPrefetch final : public RequestBase {
request_.reset(new VariableResponse(scope, dev_ctx_, true));
}
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
}
virtual ~RequestPrefetch() {}
......@@ -190,7 +192,6 @@ class RequestPrefetch final : public RequestBase {
virtual void Process() {
// prefetch process...
::grpc::ByteBuffer reply;
std::string var_name = request_->OutVarname();
VLOG(3) << "RequestPrefetch " << var_name;
......@@ -200,15 +201,16 @@ class RequestPrefetch final : public RequestBase {
InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
responder_.Finish(reply, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
}
protected:
std::shared_ptr<VariableResponse> request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
framework::Executor* executor_;
......@@ -262,6 +264,9 @@ void AsyncGRPCServer::RunSyncUpdate() {
for (int i = 0; i < kGetReqsBufSize; ++i) {
TryToRegisterNewGetOne(i);
}
for (int i = 0; i < kPrefetchReqsBufSize; ++i) {
TryToRegisterNewPrefetchOne(i);
}
for (int i = 0; i < kNumHandleSendThreads; ++i) {
t_sends_.emplace_back(
......@@ -273,12 +278,11 @@ void AsyncGRPCServer::RunSyncUpdate() {
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));
}
// TODO(wuyi): Run these "HandleRequest" in thread pool
t_prefetch_.reset(new std::thread(
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register)));
for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
t_prefetchs_.emplace_back(new std::thread(
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register)));
}
{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1;
......@@ -292,7 +296,9 @@ void AsyncGRPCServer::RunSyncUpdate() {
for (int i = 0; i < kNumHandleGetThreads; ++i) {
t_gets_[i]->join();
}
t_prefetch_->join();
for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
t_prefetchs_[i]->join();
}
}
void AsyncGRPCServer::ShutdownQueue() {
......@@ -342,6 +348,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
RequestPrefetch* prefetch = new RequestPrefetch(
&service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
program_, prefetch_ctx_.get(), req_id);
prefetch_reqs_[req_id] = static_cast<RequestBase*>(prefetch);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
......@@ -376,8 +383,8 @@ void AsyncGRPCServer::HandleRequest(
base = get_reqs_[req_id];
} else if (cq_name == "cq_send") {
base = send_reqs_[req_id];
} else {
CHECK(false);
} else if (cq_name == "cq_prefetch") {
base = prefetch_reqs_[req_id];
}
}
// reference:
......
......@@ -93,6 +93,7 @@ class AsyncGRPCServer final {
private:
static const int kSendReqsBufSize = 100;
static const int kGetReqsBufSize = 100;
static const int kPrefetchReqsBufSize = 10;
std::mutex cq_mutex_;
volatile bool is_shut_down_ = false;
......@@ -102,6 +103,7 @@ class AsyncGRPCServer final {
RequestBase *send_reqs_[kSendReqsBufSize];
RequestBase *get_reqs_[kGetReqsBufSize];
RequestBase *prefetch_reqs_[kPrefetchReqsBufSize];
GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_;
......@@ -123,6 +125,7 @@ class AsyncGRPCServer final {
std::vector<std::unique_ptr<std::thread>> t_sends_;
std::vector<std::unique_ptr<std::thread>> t_gets_;
std::vector<std::unique_ptr<std::thread>> t_prefetchs_;
std::unique_ptr<std::thread> t_prefetch_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册