From b4dd4c048d1d121109f9f7f03c91113e02b4f5d0 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 21 May 2018 21:59:52 -0700 Subject: [PATCH] multi-thread handlerequest Experiment on vgg flower, 2 trainers, 1ps. more trainer could have more speedup. After: Pass = 0, Iters = 327, Speed = (7.52) img/s Before: Pass = 0, Iters = 385, Speed = (6.77) img/s --- benchmark/cluster/vgg16/vgg16_fluid.py | 26 +-- cmake/external/grpc.cmake | 2 +- paddle/fluid/framework/executor.cc | 5 +- paddle/fluid/operators/detail/grpc_client.cc | 8 +- paddle/fluid/operators/detail/grpc_server.cc | 154 ++++++++++++------ paddle/fluid/operators/detail/grpc_server.h | 21 ++- paddle/fluid/operators/detail/grpc_service.h | 2 + paddle/fluid/operators/detail/send_recv.proto | 2 +- .../operators/detail/sendrecvop_utils.cc | 8 +- .../operators/detail/variable_response.cc | 8 +- paddle/fluid/platform/device_tracer.cc | 1 - 11 files changed, 158 insertions(+), 79 deletions(-) diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py index 05b5f3977cb..0f5cd2a2535 100644 --- a/benchmark/cluster/vgg16/vgg16_fluid.py +++ b/benchmark/cluster/vgg16/vgg16_fluid.py @@ -38,7 +38,7 @@ def str2bool(v): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - '--batch_size', type=int, default=128, help="Batch size for training.") + '--batch_size', type=int, default=16, help="Batch size for training.") parser.add_argument( '--learning_rate', type=float, @@ -61,7 +61,7 @@ parser.add_argument( parser.add_argument( '--data_set', type=str, - default='cifar10', + default='flowers', choices=['cifar10', 'flowers'], help='Optional dataset for benchmark.') parser.add_argument( @@ -200,26 +200,30 @@ def main(): fetch_list=[avg_cost, batch_acc, batch_size]) return loss, acc, b_size - if args.profile and args.task_index == 0: - # warmup. - for batch_id, data in enumerate(train_reader()): - if batch_id > 5: break - run_step(batch_id, data) - with profiler.profiler('All', 'total', '/tmp/profile_vgg'): + if args.profile: + with profiler.profiler('All', 'total', + '/tmp/profile_vgg_%d' % args.task_index): for batch_id, data in enumerate(train_reader()): - if batch_id > 5: break + if batch_id > 4: break run_step(batch_id, data) + total_time = 0.0 + count = 0 for batch_id, data in enumerate(train_reader()): ts = time.time() loss, acc, b_size = run_step(batch_id, data) iters += 1 num_samples += len(data) train_pass_acc.add(value=acc, weight=b_size) + + duration = time.time() - ts + total_time += duration + count += len(data) print( "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, " - "Speed = %.2f img/s" % (pass_id, iters, loss, acc, - len(data) / (time.time() - ts)) + "Speed = %.2f (%.2f) img/s" % (pass_id, iters, loss, acc, + len(data) / duration, + count / total_time) ) # The accuracy is the accumulation of batches, but not the current batch. pass_elapsed = time.time() - start_time diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake index e90948782bb..ef520b12879 100644 --- a/cmake/external/grpc.cmake +++ b/cmake/external/grpc.cmake @@ -33,7 +33,7 @@ ExternalProject_Add( extern_grpc DEPENDS protobuf zlib GIT_REPOSITORY "https://github.com/grpc/grpc.git" - GIT_TAG "v1.10.x" + GIT_TAG "v1.8.x" PREFIX ${GRPC_SOURCES_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 4e431561f81..55be9b6c3bb 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -350,12 +350,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } } } - platform::DeviceContextPool::Instance().Get(place_)->Wait(); + // platform::DeviceContextPool::Instance().Get(place_)->Wait(); if (create_vars && create_local_scope) { scope->DeleteScope(local_scope); - } else { - // Delete the local scopes created in operators. - scope->DropKids(); } if (FLAGS_benchmark) { VLOG(2) << "-------------------------------------------------------"; diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ae60ab15325..47892b1bcc0 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -196,9 +197,14 @@ bool RPCClient::Wait() { const size_t kReqCnt = req_count_; bool a[kReqCnt]; std::vector> waits(req_count_); + std::mutex mu; for (int i = 0; i < req_count_; i++) { - waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); }); + waits[i] = framework::AsyncIO([i, &a, &mu, this] { + bool ret = Proceed(); + std::lock_guard l(mu); + a[i] = ret; + }); } for (int i = 0; i < req_count_; i++) { diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index eb114a47d99..604321cd1f3 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -22,7 +22,10 @@ using ::grpc::ServerAsyncResponseWriter; namespace paddle { namespace operators { namespace detail { - +namespace { +const int kNumHandleSendThreads = 20; +const int kNumHandleGetThreads = 20; +} // namespace enum CallStatus { PROCESS = 0, FINISH }; // reference: @@ -63,18 +66,20 @@ class RequestSend final : public RequestBase { explicit RequestSend(GrpcService::AsyncService* service, ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, ReceivedQueue* queue, - const platform::DeviceContext* dev_ctx) + const platform::DeviceContext* dev_ctx, int i) : RequestBase(service, cq, sync_mode, dev_ctx), queue_(queue), - responder_(&ctx_) { + responder_(&ctx_), + i_(i) { if (sync_mode_) { request_.reset(new VariableResponse(scope, dev_ctx_, false)); } else { request_.reset(new VariableResponse(scope, dev_ctx_, true)); } int method_id = static_cast(detail::GrpcMethod::kSendVariable); - service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, - cq_, cq_, this); + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(i))); } virtual ~RequestSend() {} @@ -86,15 +91,17 @@ class RequestSend final : public RequestBase { VLOG(3) << "RequestSend " << var_name; queue_->Push(std::make_pair(var_name, request_)); - sendrecv::VoidMessage reply; - responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; + responder_.Finish(reply_, ::grpc::Status::OK, + reinterpret_cast(static_cast(i_))); } protected: + sendrecv::VoidMessage reply_; std::shared_ptr request_; ReceivedQueue* queue_; ServerAsyncResponseWriter responder_; + int i_; }; class RequestGet final : public RequestBase { @@ -103,14 +110,16 @@ class RequestGet final : public RequestBase { ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, const platform::DeviceContext* dev_ctx, - framework::BlockingQueue* queue) + framework::BlockingQueue* queue, int i) : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), - queue_(queue) { + queue_(queue), + i_(i) { auto method_id = static_cast(detail::GrpcMethod::kGetVariable); - service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, - cq_, this); + service_->RequestAsyncUnary( + method_id, &ctx_, &request_, &responder_, cq_, cq_, + reinterpret_cast(static_cast(i))); } virtual ~RequestGet() {} @@ -123,13 +132,13 @@ class RequestGet final : public RequestBase { VLOG(3) << "RequestGet " << var_name; auto* var = scope_->FindVar(var_name); - ::grpc::ByteBuffer reply; if (var_name != FETCH_BARRIER_MESSAGE) { - SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); } - responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; + responder_.Finish(reply_, ::grpc::Status::OK, + reinterpret_cast(static_cast(i_))); if (var_name == FETCH_BARRIER_MESSAGE) { sendrecv::VariableMessage msg; @@ -140,9 +149,11 @@ class RequestGet final : public RequestBase { protected: sendrecv::VariableMessage request_; + ::grpc::ByteBuffer reply_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; framework::BlockingQueue* queue_; + int i_; }; class RequestPrefetch final : public RequestBase { @@ -153,13 +164,15 @@ class RequestPrefetch final : public RequestBase { const platform::DeviceContext* dev_ctx, framework::Executor* executor, framework::ProgramDesc* program, - framework::ExecutorPrepareContext* prefetch_ctx) + framework::ExecutorPrepareContext* prefetch_ctx, + int i) : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), executor_(executor), program_(program), - prefetch_ctx_(prefetch_ctx) { + prefetch_ctx_(prefetch_ctx), + i_(i) { if (sync_mode_) { request_.reset(new VariableResponse(scope, dev_ctx_, false)); } else { @@ -188,7 +201,8 @@ class RequestPrefetch final : public RequestBase { SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); - responder_.Finish(reply, ::grpc::Status::OK, this); + responder_.Finish(reply, ::grpc::Status::OK, + reinterpret_cast(static_cast(i_))); status_ = FINISH; } @@ -199,6 +213,7 @@ class RequestPrefetch final : public RequestBase { framework::Executor* executor_; framework::ProgramDesc* program_; framework::ExecutorPrepareContext* prefetch_ctx_; + int i_; }; void AsyncGRPCServer::WaitClientGet(int count) { @@ -232,20 +247,33 @@ void AsyncGRPCServer::RunSyncUpdate() { LOG(INFO) << "Server listening on " << address_ << " selected port: " << selected_port_; - std::function send_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); - std::function get_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); - std::function prefetch_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this); + std::function send_register = std::bind( + &AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1); + std::function get_register = std::bind( + &AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1); + std::function prefetch_register = + std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this, + std::placeholders::_1); + + for (int i = 0; i < kSendReqsBufSize; ++i) { + TryToRegisterNewSendOne(i); + } + for (int i = 0; i < kGetReqsBufSize; ++i) { + TryToRegisterNewGetOne(i); + } + + for (int i = 0; i < kNumHandleSendThreads; ++i) { + t_sends_.emplace_back( + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, + cq_send_.get(), "cq_send", send_register))); + } + for (int i = 0; i < kNumHandleGetThreads; ++i) { + t_gets_.emplace_back( + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, + cq_get_.get(), "cq_get", get_register))); + } // TODO(wuyi): Run these "HandleRequest" in thread pool - t_send_.reset( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, - cq_send_.get(), "cq_send", send_register))); - t_get_.reset( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, - cq_get_.get(), "cq_get", get_register))); t_prefetch_.reset(new std::thread( std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), "cq_prefetch", prefetch_register))); @@ -257,8 +285,27 @@ void AsyncGRPCServer::RunSyncUpdate() { condition_ready_.notify_all(); // wait server server_->Wait(); - t_send_->join(); - t_get_->join(); + for (int i = 0; i < kNumHandleSendThreads; ++i) { + t_sends_[i]->join(); + } + for (int i = 0; i < kNumHandleGetThreads; ++i) { + t_gets_[i]->join(); + } + { + std::lock_guard l(cq_mutex_); + for (int i = 0; i < kSendReqsBufSize; ++i) { + if (send_reqs_[i]) { + delete send_reqs_[i]; + send_reqs_[i] = nullptr; + } + } + for (int i = 0; i < kGetReqsBufSize; ++i) { + if (get_reqs_[i]) { + delete get_reqs_[i]; + get_reqs_[i] = nullptr; + } + } + } t_prefetch_->join(); } @@ -276,47 +323,47 @@ void AsyncGRPCServer::ShutDown() { server_->Shutdown(); } -void AsyncGRPCServer::TryToRegisterNewSendOne() { +void AsyncGRPCServer::TryToRegisterNewSendOne(int i) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; return; } RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_, - scope_, &var_recv_queue_, dev_ctx_); + scope_, &var_recv_queue_, dev_ctx_, i); + send_reqs_[i] = static_cast(send); VLOG(4) << "Create RequestSend status:" << send->Status(); } -void AsyncGRPCServer::TryToRegisterNewGetOne() { +void AsyncGRPCServer::TryToRegisterNewGetOne(int i) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; return; } RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, - dev_ctx_, &var_get_queue_); + dev_ctx_, &var_get_queue_, i); + get_reqs_[i] = static_cast(get); VLOG(4) << "Create RequestGet status:" << get->Status(); } -void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { +void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne"; return; } - RequestPrefetch* prefetch = - new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_, - dev_ctx_, executor_, program_, prefetch_ctx_.get()); + RequestPrefetch* prefetch = new RequestPrefetch( + &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_, + program_, prefetch_ctx_.get(), i); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } // FIXME(typhoonzero): change cq_name to enum. -void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, - const std::string& cq_name, - std::function TryToRegisterNewOne) { - TryToRegisterNewOne(); - +void AsyncGRPCServer::HandleRequest( + ::grpc::ServerCompletionQueue* cq, const std::string& cq_name, + std::function TryToRegisterNewOne) { void* tag = NULL; bool ok = false; @@ -327,8 +374,7 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, break; } VLOG(3) << "HandleRequest for " << cq_name << " get Next"; - - PADDLE_ENFORCE(tag); + int i = static_cast(reinterpret_cast(tag)); if (sync_mode_) { // FIXME(typhoonzero): de-couple the barriers with recv_op @@ -337,7 +383,17 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond"; } - RequestBase* base = reinterpret_cast(tag); + RequestBase* base = nullptr; + { + std::lock_guard l(cq_mutex_); + if (cq_name == "cq_get") { + base = get_reqs_[i]; + } else if (cq_name == "cq_send") { + base = send_reqs_[i]; + } else { + CHECK(false); + } + } // reference: // https://github.com/tensorflow/tensorflow/issues/5596 // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM @@ -345,19 +401,19 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, if (!ok) { LOG(WARNING) << cq_name << " recv no regular event:argument name[" << base->GetReqName() << "]"; - TryToRegisterNewOne(); + TryToRegisterNewOne(i); delete base; continue; } switch (base->Status()) { case PROCESS: { - TryToRegisterNewOne(); base->Process(); VLOG(4) << cq_name << " PROCESS status:" << base->Status(); break; } case FINISH: { + TryToRegisterNewOne(i); VLOG(4) << cq_name << " FINISH status:" << base->Status(); delete base; break; diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 238aaa29634..d70be1b7ce9 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include // NOLINT #include +#include #include "grpc++/grpc++.h" #include "paddle/fluid/framework/blocking_queue.h" @@ -30,6 +31,7 @@ limitations under the License. */ #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -82,19 +84,25 @@ class AsyncGRPCServer final { protected: void HandleRequest(::grpc::ServerCompletionQueue *cq, const std::string &cq_name, - std::function TryToRegisterNewOne); - void TryToRegisterNewSendOne(); - void TryToRegisterNewGetOne(); - void TryToRegisterNewPrefetchOne(); + std::function TryToRegisterNewOne); + void TryToRegisterNewSendOne(int i); + void TryToRegisterNewGetOne(int i); + void TryToRegisterNewPrefetchOne(int i); void ShutdownQueue(); private: + static const int kSendReqsBufSize = 100; + static const int kGetReqsBufSize = 100; + std::mutex cq_mutex_; volatile bool is_shut_down_ = false; std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_; + RequestBase *send_reqs_[kSendReqsBufSize]; + RequestBase *get_reqs_[kGetReqsBufSize]; + GrpcService::AsyncService service_; std::unique_ptr<::grpc::Server> server_; @@ -113,8 +121,9 @@ class AsyncGRPCServer final { mutable int barrier_cond_step_; std::condition_variable barrier_condition_; - std::unique_ptr t_send_; - std::unique_ptr t_get_; + std::vector> t_sends_; + std::vector> t_gets_; + std::unique_ptr t_prefetch_; std::unique_ptr prefetch_ctx_; diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/detail/grpc_service.h index e6dab2f5a3a..e0505c2b9d0 100644 --- a/paddle/fluid/operators/detail/grpc_service.h +++ b/paddle/fluid/operators/detail/grpc_service.h @@ -25,6 +25,8 @@ #include #include "paddle/fluid/operators/detail/variable_response.h" +#include "paddle/fluid/platform/profiler.h" + // NOTE: This method was originally created by tensorflow // (https://github.com/tensorflow/tensorflow/) we borrow this // method and did some modifications so that we can parse gRPC diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index 9478c5702bc..078181909df 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -73,7 +73,7 @@ message VariableMessage { // If true, the ps server will start profiling, the ps // server stops profiling and generates a profile to /tmp/profile_ps_* // when profile switches from true to false. - bool profile = 11; + int64 profile = 11; } message VoidMessage {} diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 07c43554bc6..a9ea80c9173 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -122,7 +122,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, // 1 trainer returns true for ShouldSendProfileState(). It tells PS // servers the trainer's profiling state so that PS can follow the // trainer. - request.set_profile(platform::IsProfileEnabled()); + if (platform::ShouldSendProfileState()) { + if (platform::IsProfileEnabled()) { + request.set_profile(1); + } else { + request.set_profile(2); + } + } if (!out_name.empty()) { request.set_out_varname(out_name); } diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 462e303096e..2dfd9b2621e 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -449,8 +449,8 @@ int VariableResponse::Parse(Source* source) { break; } case sendrecv::VariableMessage::kProfileFieldNumber: { - bool profiling; - if (!input.ReadRaw(reinterpret_cast(&profiling), 1)) { + uint64_t profiling = 0; + if (!input.ReadVarint64(&profiling)) { return tag; } meta_.set_profile(profiling); @@ -458,9 +458,9 @@ int VariableResponse::Parse(Source* source) { if (listener_id <= 0) { break; } - if (profiling && !platform::IsProfileEnabled()) { + if (profiling == 1 && !platform::IsProfileEnabled()) { platform::EnableProfiler(platform::ProfilerState::kCPU); - } else if (!profiling && platform::IsProfileEnabled()) { + } else if (profiling == 2 && platform::IsProfileEnabled()) { // TODO(panyx0718): Should we allow to customize file dir. platform::DisableProfiler( platform::EventSortingKey::kDefault, diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index c9e10631680..1a9be044e02 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -245,7 +245,6 @@ class DeviceTracerImpl : public DeviceTracer { void Enable() { std::lock_guard l(trace_mu_); if (enabled_) { - fprintf(stderr, "DeviceTracer already enabled\n"); return; } EnableActivity(); -- GitLab