diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py index 0f5cd2a253554a7c37e1d34cd798f30ea2119bdd..e9360ab4c79d23bdf9f84d0c0d407af6d39bde3e 100644 --- a/benchmark/cluster/vgg16/vgg16_fluid.py +++ b/benchmark/cluster/vgg16/vgg16_fluid.py @@ -204,7 +204,7 @@ def main(): with profiler.profiler('All', 'total', '/tmp/profile_vgg_%d' % args.task_index): for batch_id, data in enumerate(train_reader()): - if batch_id > 4: break + if batch_id > 5: break run_step(batch_id, data) total_time = 0.0 diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake index ef520b128794d44df03fd11a916d8a791cba5c2a..e90948782bb5e333bbdb47ef9d61c1e37e3cf9e4 100644 --- a/cmake/external/grpc.cmake +++ b/cmake/external/grpc.cmake @@ -33,7 +33,7 @@ ExternalProject_Add( extern_grpc DEPENDS protobuf zlib GIT_REPOSITORY "https://github.com/grpc/grpc.git" - GIT_TAG "v1.8.x" + GIT_TAG "v1.10.x" PREFIX ${GRPC_SOURCES_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 604321cd1f35abd4237e4b2ecc34dc148c4fd0d8..c2c1df4cd6485933945077451464659d6b89f0f4 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -66,11 +66,11 @@ class RequestSend final : public RequestBase { explicit RequestSend(GrpcService::AsyncService* service, ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, ReceivedQueue* queue, - const platform::DeviceContext* dev_ctx, int i) + const platform::DeviceContext* dev_ctx, int req_id) : RequestBase(service, cq, sync_mode, dev_ctx), queue_(queue), responder_(&ctx_), - i_(i) { + req_id_(req_id) { if (sync_mode_) { request_.reset(new VariableResponse(scope, dev_ctx_, false)); } else { @@ -79,7 +79,7 @@ class RequestSend final : public RequestBase { int method_id = static_cast(detail::GrpcMethod::kSendVariable); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, - reinterpret_cast(static_cast(i))); + reinterpret_cast(static_cast(req_id))); } virtual ~RequestSend() {} @@ -93,7 +93,7 @@ class RequestSend final : public RequestBase { status_ = FINISH; responder_.Finish(reply_, ::grpc::Status::OK, - reinterpret_cast(static_cast(i_))); + reinterpret_cast(static_cast(req_id_))); } protected: @@ -101,7 +101,7 @@ class RequestSend final : public RequestBase { std::shared_ptr request_; ReceivedQueue* queue_; ServerAsyncResponseWriter responder_; - int i_; + int req_id_; }; class RequestGet final : public RequestBase { @@ -110,16 +110,17 @@ class RequestGet final : public RequestBase { ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, const platform::DeviceContext* dev_ctx, - framework::BlockingQueue* queue, int i) + framework::BlockingQueue* queue, + int req_id) : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), queue_(queue), - i_(i) { + req_id_(req_id) { auto method_id = static_cast(detail::GrpcMethod::kGetVariable); service_->RequestAsyncUnary( method_id, &ctx_, &request_, &responder_, cq_, cq_, - reinterpret_cast(static_cast(i))); + reinterpret_cast(static_cast(req_id_))); } virtual ~RequestGet() {} @@ -138,7 +139,7 @@ class RequestGet final : public RequestBase { status_ = FINISH; responder_.Finish(reply_, ::grpc::Status::OK, - reinterpret_cast(static_cast(i_))); + reinterpret_cast(static_cast(req_id_))); if (var_name == FETCH_BARRIER_MESSAGE) { sendrecv::VariableMessage msg; @@ -153,7 +154,7 @@ class RequestGet final : public RequestBase { ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; framework::BlockingQueue* queue_; - int i_; + int req_id_; }; class RequestPrefetch final : public RequestBase { @@ -165,14 +166,14 @@ class RequestPrefetch final : public RequestBase { framework::Executor* executor, framework::ProgramDesc* program, framework::ExecutorPrepareContext* prefetch_ctx, - int i) + int req_id) : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), executor_(executor), program_(program), prefetch_ctx_(prefetch_ctx), - i_(i) { + req_id_(req_id) { if (sync_mode_) { request_.reset(new VariableResponse(scope, dev_ctx_, false)); } else { @@ -202,7 +203,7 @@ class RequestPrefetch final : public RequestBase { SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); responder_.Finish(reply, ::grpc::Status::OK, - reinterpret_cast(static_cast(i_))); + reinterpret_cast(static_cast(req_id_))); status_ = FINISH; } @@ -213,7 +214,7 @@ class RequestPrefetch final : public RequestBase { framework::Executor* executor_; framework::ProgramDesc* program_; framework::ExecutorPrepareContext* prefetch_ctx_; - int i_; + int req_id_; }; void AsyncGRPCServer::WaitClientGet(int count) { @@ -291,21 +292,6 @@ void AsyncGRPCServer::RunSyncUpdate() { for (int i = 0; i < kNumHandleGetThreads; ++i) { t_gets_[i]->join(); } - { - std::lock_guard l(cq_mutex_); - for (int i = 0; i < kSendReqsBufSize; ++i) { - if (send_reqs_[i]) { - delete send_reqs_[i]; - send_reqs_[i] = nullptr; - } - } - for (int i = 0; i < kGetReqsBufSize; ++i) { - if (get_reqs_[i]) { - delete get_reqs_[i]; - get_reqs_[i] = nullptr; - } - } - } t_prefetch_->join(); } @@ -335,19 +321,19 @@ void AsyncGRPCServer::TryToRegisterNewSendOne(int i) { VLOG(4) << "Create RequestSend status:" << send->Status(); } -void AsyncGRPCServer::TryToRegisterNewGetOne(int i) { +void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; return; } RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, - dev_ctx_, &var_get_queue_, i); - get_reqs_[i] = static_cast(get); + dev_ctx_, &var_get_queue_, req_id); + get_reqs_[req_id] = static_cast(get); VLOG(4) << "Create RequestGet status:" << get->Status(); } -void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) { +void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne"; @@ -355,7 +341,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) { } RequestPrefetch* prefetch = new RequestPrefetch( &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_, - program_, prefetch_ctx_.get(), i); + program_, prefetch_ctx_.get(), req_id); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } @@ -374,7 +360,7 @@ void AsyncGRPCServer::HandleRequest( break; } VLOG(3) << "HandleRequest for " << cq_name << " get Next"; - int i = static_cast(reinterpret_cast(tag)); + int req_id = static_cast(reinterpret_cast(tag)); if (sync_mode_) { // FIXME(typhoonzero): de-couple the barriers with recv_op @@ -387,9 +373,9 @@ void AsyncGRPCServer::HandleRequest( { std::lock_guard l(cq_mutex_); if (cq_name == "cq_get") { - base = get_reqs_[i]; + base = get_reqs_[req_id]; } else if (cq_name == "cq_send") { - base = send_reqs_[i]; + base = send_reqs_[req_id]; } else { CHECK(false); } @@ -401,7 +387,7 @@ void AsyncGRPCServer::HandleRequest( if (!ok) { LOG(WARNING) << cq_name << " recv no regular event:argument name[" << base->GetReqName() << "]"; - TryToRegisterNewOne(i); + TryToRegisterNewOne(req_id); delete base; continue; } @@ -413,7 +399,7 @@ void AsyncGRPCServer::HandleRequest( break; } case FINISH: { - TryToRegisterNewOne(i); + TryToRegisterNewOne(req_id); VLOG(4) << cq_name << " FINISH status:" << base->Status(); delete base; break;