提交 722c078b 编写于 作者: X Xin Pan

fix test and clean up

上级 11fe3c79
...@@ -25,6 +25,7 @@ namespace detail { ...@@ -25,6 +25,7 @@ namespace detail {
namespace { namespace {
const int kNumHandleSendThreads = 20; const int kNumHandleSendThreads = 20;
const int kNumHandleGetThreads = 20; const int kNumHandleGetThreads = 20;
const int kNumHandlePrefetchThreads = 1;
} // namespace } // namespace
enum CallStatus { PROCESS = 0, FINISH }; enum CallStatus { PROCESS = 0, FINISH };
...@@ -180,8 +181,9 @@ class RequestPrefetch final : public RequestBase { ...@@ -180,8 +181,9 @@ class RequestPrefetch final : public RequestBase {
request_.reset(new VariableResponse(scope, dev_ctx_, true)); request_.reset(new VariableResponse(scope, dev_ctx_, true));
} }
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, service_->RequestAsyncUnary(
cq_, cq_, this); method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
virtual ~RequestPrefetch() {} virtual ~RequestPrefetch() {}
...@@ -190,7 +192,6 @@ class RequestPrefetch final : public RequestBase { ...@@ -190,7 +192,6 @@ class RequestPrefetch final : public RequestBase {
virtual void Process() { virtual void Process() {
// prefetch process... // prefetch process...
::grpc::ByteBuffer reply;
std::string var_name = request_->OutVarname(); std::string var_name = request_->OutVarname();
VLOG(3) << "RequestPrefetch " << var_name; VLOG(3) << "RequestPrefetch " << var_name;
...@@ -200,15 +201,16 @@ class RequestPrefetch final : public RequestBase { ...@@ -200,15 +201,16 @@ class RequestPrefetch final : public RequestBase {
InitializeVariable(var, var_desc->GetType()); InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_); executor_->RunPreparedContext(prefetch_ctx_, scope_);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
responder_.Finish(reply, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
protected: protected:
std::shared_ptr<VariableResponse> request_; std::shared_ptr<VariableResponse> request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_; framework::Scope* scope_;
framework::Executor* executor_; framework::Executor* executor_;
...@@ -262,6 +264,9 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -262,6 +264,9 @@ void AsyncGRPCServer::RunSyncUpdate() {
for (int i = 0; i < kGetReqsBufSize; ++i) { for (int i = 0; i < kGetReqsBufSize; ++i) {
TryToRegisterNewGetOne(i); TryToRegisterNewGetOne(i);
} }
for (int i = 0; i < kPrefetchReqsBufSize; ++i) {
TryToRegisterNewPrefetchOne(i);
}
for (int i = 0; i < kNumHandleSendThreads; ++i) { for (int i = 0; i < kNumHandleSendThreads; ++i) {
t_sends_.emplace_back( t_sends_.emplace_back(
...@@ -273,12 +278,11 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -273,12 +278,11 @@ void AsyncGRPCServer::RunSyncUpdate() {
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register))); cq_get_.get(), "cq_get", get_register)));
} }
for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
// TODO(wuyi): Run these "HandleRequest" in thread pool t_prefetchs_.emplace_back(new std::thread(
t_prefetch_.reset(new std::thread( std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), "cq_prefetch", prefetch_register)));
"cq_prefetch", prefetch_register))); }
{ {
std::lock_guard<std::mutex> lock(this->mutex_ready_); std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1; ready_ = 1;
...@@ -292,7 +296,9 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -292,7 +296,9 @@ void AsyncGRPCServer::RunSyncUpdate() {
for (int i = 0; i < kNumHandleGetThreads; ++i) { for (int i = 0; i < kNumHandleGetThreads; ++i) {
t_gets_[i]->join(); t_gets_[i]->join();
} }
t_prefetch_->join(); for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
t_prefetchs_[i]->join();
}
} }
void AsyncGRPCServer::ShutdownQueue() { void AsyncGRPCServer::ShutdownQueue() {
...@@ -342,6 +348,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) { ...@@ -342,6 +348,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
RequestPrefetch* prefetch = new RequestPrefetch( RequestPrefetch* prefetch = new RequestPrefetch(
&service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_, &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
program_, prefetch_ctx_.get(), req_id); program_, prefetch_ctx_.get(), req_id);
prefetch_reqs_[req_id] = static_cast<RequestBase*>(prefetch);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
} }
...@@ -376,8 +383,8 @@ void AsyncGRPCServer::HandleRequest( ...@@ -376,8 +383,8 @@ void AsyncGRPCServer::HandleRequest(
base = get_reqs_[req_id]; base = get_reqs_[req_id];
} else if (cq_name == "cq_send") { } else if (cq_name == "cq_send") {
base = send_reqs_[req_id]; base = send_reqs_[req_id];
} else { } else if (cq_name == "cq_prefetch") {
CHECK(false); base = prefetch_reqs_[req_id];
} }
} }
// reference: // reference:
......
...@@ -93,6 +93,7 @@ class AsyncGRPCServer final { ...@@ -93,6 +93,7 @@ class AsyncGRPCServer final {
private: private:
static const int kSendReqsBufSize = 100; static const int kSendReqsBufSize = 100;
static const int kGetReqsBufSize = 100; static const int kGetReqsBufSize = 100;
static const int kPrefetchReqsBufSize = 10;
std::mutex cq_mutex_; std::mutex cq_mutex_;
volatile bool is_shut_down_ = false; volatile bool is_shut_down_ = false;
...@@ -102,6 +103,7 @@ class AsyncGRPCServer final { ...@@ -102,6 +103,7 @@ class AsyncGRPCServer final {
RequestBase *send_reqs_[kSendReqsBufSize]; RequestBase *send_reqs_[kSendReqsBufSize];
RequestBase *get_reqs_[kGetReqsBufSize]; RequestBase *get_reqs_[kGetReqsBufSize];
RequestBase *prefetch_reqs_[kPrefetchReqsBufSize];
GrpcService::AsyncService service_; GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
...@@ -123,6 +125,7 @@ class AsyncGRPCServer final { ...@@ -123,6 +125,7 @@ class AsyncGRPCServer final {
std::vector<std::unique_ptr<std::thread>> t_sends_; std::vector<std::unique_ptr<std::thread>> t_sends_;
std::vector<std::unique_ptr<std::thread>> t_gets_; std::vector<std::unique_ptr<std::thread>> t_gets_;
std::vector<std::unique_ptr<std::thread>> t_prefetchs_;
std::unique_ptr<std::thread> t_prefetch_; std::unique_ptr<std::thread> t_prefetch_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册