未验证 提交 0d598cf9 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #10822 from panyx0718/dist_opt

multi-thread handlerequest
...@@ -38,7 +38,7 @@ def str2bool(v): ...@@ -38,7 +38,7 @@ def str2bool(v):
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
'--batch_size', type=int, default=128, help="Batch size for training.") '--batch_size', type=int, default=16, help="Batch size for training.")
parser.add_argument( parser.add_argument(
'--learning_rate', '--learning_rate',
type=float, type=float,
...@@ -61,7 +61,7 @@ parser.add_argument( ...@@ -61,7 +61,7 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
'--data_set', '--data_set',
type=str, type=str,
default='cifar10', default='flowers',
choices=['cifar10', 'flowers'], choices=['cifar10', 'flowers'],
help='Optional dataset for benchmark.') help='Optional dataset for benchmark.')
parser.add_argument( parser.add_argument(
...@@ -200,26 +200,30 @@ def main(): ...@@ -200,26 +200,30 @@ def main():
fetch_list=[avg_cost, batch_acc, batch_size]) fetch_list=[avg_cost, batch_acc, batch_size])
return loss, acc, b_size return loss, acc, b_size
if args.profile and args.task_index == 0: if args.profile:
# warmup. with profiler.profiler('All', 'total',
for batch_id, data in enumerate(train_reader()): '/tmp/profile_vgg_%d' % args.task_index):
if batch_id > 5: break
run_step(batch_id, data)
with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
if batch_id > 5: break if batch_id > 5: break
run_step(batch_id, data) run_step(batch_id, data)
total_time = 0.0
count = 0
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
ts = time.time() ts = time.time()
loss, acc, b_size = run_step(batch_id, data) loss, acc, b_size = run_step(batch_id, data)
iters += 1 iters += 1
num_samples += len(data) num_samples += len(data)
train_pass_acc.add(value=acc, weight=b_size) train_pass_acc.add(value=acc, weight=b_size)
duration = time.time() - ts
total_time += duration
count += len(data)
print( print(
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, " "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
"Speed = %.2f img/s" % (pass_id, iters, loss, acc, "Speed = %.2f (%.2f) img/s" % (pass_id, iters, loss, acc,
len(data) / (time.time() - ts)) len(data) / duration,
count / total_time)
) # The accuracy is the accumulation of batches, but not the current batch. ) # The accuracy is the accumulation of batches, but not the current batch.
pass_elapsed = time.time() - start_time pass_elapsed = time.time() - start_time
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <limits> #include <limits>
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -196,9 +197,14 @@ bool RPCClient::Wait() { ...@@ -196,9 +197,14 @@ bool RPCClient::Wait() {
const size_t kReqCnt = req_count_; const size_t kReqCnt = req_count_;
bool a[kReqCnt]; bool a[kReqCnt];
std::vector<std::future<void>> waits(req_count_); std::vector<std::future<void>> waits(req_count_);
std::mutex mu;
for (int i = 0; i < req_count_; i++) { for (int i = 0; i < req_count_; i++) {
waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); }); waits[i] = framework::AsyncIO([i, &a, &mu, this] {
bool ret = Proceed();
std::lock_guard<std::mutex> l(mu);
a[i] = ret;
});
} }
for (int i = 0; i < req_count_; i++) { for (int i = 0; i < req_count_; i++) {
......
...@@ -19,10 +19,16 @@ limitations under the License. */ ...@@ -19,10 +19,16 @@ limitations under the License. */
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
DEFINE_int32(rpc_server_handle_send_threads, 20,
"Number of threads used to handle send at rpc server.");
DEFINE_int32(rpc_server_handle_get_threads, 20,
"Number of threads used to handle get at rpc server.");
DEFINE_int32(rpc_server_handle_prefetch_threads, 1,
"Number of threads used to handle prefetch at rpc server.");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
enum CallStatus { PROCESS = 0, FINISH }; enum CallStatus { PROCESS = 0, FINISH };
// reference: // reference:
...@@ -63,18 +69,20 @@ class RequestSend final : public RequestBase { ...@@ -63,18 +69,20 @@ class RequestSend final : public RequestBase {
explicit RequestSend(GrpcService::AsyncService* service, explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, ReceivedQueue* queue, framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx) const platform::DeviceContext* dev_ctx, int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx), : RequestBase(service, cq, sync_mode, dev_ctx),
queue_(queue), queue_(queue),
responder_(&ctx_) { responder_(&ctx_),
req_id_(req_id) {
if (sync_mode_) { if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false)); request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else { } else {
request_.reset(new VariableResponse(scope, dev_ctx_, true)); request_.reset(new VariableResponse(scope, dev_ctx_, true));
} }
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, service_->RequestAsyncUnary(
cq_, cq_, this); method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
} }
virtual ~RequestSend() {} virtual ~RequestSend() {}
...@@ -86,15 +94,17 @@ class RequestSend final : public RequestBase { ...@@ -86,15 +94,17 @@ class RequestSend final : public RequestBase {
VLOG(3) << "RequestSend " << var_name; VLOG(3) << "RequestSend " << var_name;
queue_->Push(std::make_pair(var_name, request_)); queue_->Push(std::make_pair(var_name, request_));
sendrecv::VoidMessage reply;
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
protected: protected:
sendrecv::VoidMessage reply_;
std::shared_ptr<VariableResponse> request_; std::shared_ptr<VariableResponse> request_;
ReceivedQueue* queue_; ReceivedQueue* queue_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
int req_id_;
}; };
class RequestGet final : public RequestBase { class RequestGet final : public RequestBase {
...@@ -103,14 +113,17 @@ class RequestGet final : public RequestBase { ...@@ -103,14 +113,17 @@ class RequestGet final : public RequestBase {
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
framework::BlockingQueue<MessageWithName>* queue) framework::BlockingQueue<MessageWithName>* queue,
int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx), : RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
queue_(queue) { queue_(queue),
req_id_(req_id) {
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable); auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, service_->RequestAsyncUnary(
cq_, this); method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
virtual ~RequestGet() {} virtual ~RequestGet() {}
...@@ -123,13 +136,13 @@ class RequestGet final : public RequestBase { ...@@ -123,13 +136,13 @@ class RequestGet final : public RequestBase {
VLOG(3) << "RequestGet " << var_name; VLOG(3) << "RequestGet " << var_name;
auto* var = scope_->FindVar(var_name); auto* var = scope_->FindVar(var_name);
::grpc::ByteBuffer reply;
if (var_name != FETCH_BARRIER_MESSAGE) { if (var_name != FETCH_BARRIER_MESSAGE) {
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
} }
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
if (var_name == FETCH_BARRIER_MESSAGE) { if (var_name == FETCH_BARRIER_MESSAGE) {
sendrecv::VariableMessage msg; sendrecv::VariableMessage msg;
...@@ -140,9 +153,11 @@ class RequestGet final : public RequestBase { ...@@ -140,9 +153,11 @@ class RequestGet final : public RequestBase {
protected: protected:
sendrecv::VariableMessage request_; sendrecv::VariableMessage request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_; framework::Scope* scope_;
framework::BlockingQueue<MessageWithName>* queue_; framework::BlockingQueue<MessageWithName>* queue_;
int req_id_;
}; };
class RequestPrefetch final : public RequestBase { class RequestPrefetch final : public RequestBase {
...@@ -153,21 +168,24 @@ class RequestPrefetch final : public RequestBase { ...@@ -153,21 +168,24 @@ class RequestPrefetch final : public RequestBase {
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
framework::Executor* executor, framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx) framework::ExecutorPrepareContext* prefetch_ctx,
int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx), : RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
executor_(executor), executor_(executor),
program_(program), program_(program),
prefetch_ctx_(prefetch_ctx) { prefetch_ctx_(prefetch_ctx),
req_id_(req_id) {
if (sync_mode_) { if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false)); request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else { } else {
request_.reset(new VariableResponse(scope, dev_ctx_, true)); request_.reset(new VariableResponse(scope, dev_ctx_, true));
} }
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, service_->RequestAsyncUnary(
cq_, cq_, this); method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
virtual ~RequestPrefetch() {} virtual ~RequestPrefetch() {}
...@@ -176,7 +194,6 @@ class RequestPrefetch final : public RequestBase { ...@@ -176,7 +194,6 @@ class RequestPrefetch final : public RequestBase {
virtual void Process() { virtual void Process() {
// prefetch process... // prefetch process...
::grpc::ByteBuffer reply;
std::string var_name = request_->OutVarname(); std::string var_name = request_->OutVarname();
VLOG(3) << "RequestPrefetch " << var_name; VLOG(3) << "RequestPrefetch " << var_name;
...@@ -186,19 +203,22 @@ class RequestPrefetch final : public RequestBase { ...@@ -186,19 +203,22 @@ class RequestPrefetch final : public RequestBase {
InitializeVariable(var, var_desc->GetType()); InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_); executor_->RunPreparedContext(prefetch_ctx_, scope_);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
protected: protected:
std::shared_ptr<VariableResponse> request_; std::shared_ptr<VariableResponse> request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_; framework::Scope* scope_;
framework::Executor* executor_; framework::Executor* executor_;
framework::ProgramDesc* program_; framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_; framework::ExecutorPrepareContext* prefetch_ctx_;
int req_id_;
}; };
void AsyncGRPCServer::WaitClientGet(int count) { void AsyncGRPCServer::WaitClientGet(int count) {
...@@ -232,24 +252,39 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -232,24 +252,39 @@ void AsyncGRPCServer::RunSyncUpdate() {
LOG(INFO) << "Server listening on " << address_ LOG(INFO) << "Server listening on " << address_
<< " selected port: " << selected_port_; << " selected port: " << selected_port_;
std::function<void()> send_register = std::function<void(int)> send_register = std::bind(
std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); &AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1);
std::function<void()> get_register = std::function<void(int)> get_register = std::bind(
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); &AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1);
std::function<void()> prefetch_register = std::function<void(int)> prefetch_register =
std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this); std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this,
std::placeholders::_1);
// TODO(wuyi): Run these "HandleRequest" in thread pool
t_send_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));
t_get_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));
t_prefetch_.reset(new std::thread(
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register)));
for (int i = 0; i < kSendReqsBufSize; ++i) {
TryToRegisterNewSendOne(i);
}
for (int i = 0; i < kGetReqsBufSize; ++i) {
TryToRegisterNewGetOne(i);
}
for (int i = 0; i < kPrefetchReqsBufSize; ++i) {
TryToRegisterNewPrefetchOne(i);
}
for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
t_sends_.emplace_back(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));
}
for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
t_gets_.emplace_back(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));
}
for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
t_prefetchs_.emplace_back(new std::thread(
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register)));
}
{ {
std::lock_guard<std::mutex> lock(this->mutex_ready_); std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1; ready_ = 1;
...@@ -257,9 +292,15 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -257,9 +292,15 @@ void AsyncGRPCServer::RunSyncUpdate() {
condition_ready_.notify_all(); condition_ready_.notify_all();
// wait server // wait server
server_->Wait(); server_->Wait();
t_send_->join(); for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
t_get_->join(); t_sends_[i]->join();
t_prefetch_->join(); }
for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
t_gets_[i]->join();
}
for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
t_prefetchs_[i]->join();
}
} }
void AsyncGRPCServer::ShutdownQueue() { void AsyncGRPCServer::ShutdownQueue() {
...@@ -276,47 +317,48 @@ void AsyncGRPCServer::ShutDown() { ...@@ -276,47 +317,48 @@ void AsyncGRPCServer::ShutDown() {
server_->Shutdown(); server_->Shutdown();
} }
void AsyncGRPCServer::TryToRegisterNewSendOne() { void AsyncGRPCServer::TryToRegisterNewSendOne(int i) {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return; return;
} }
RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_, RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
scope_, &var_recv_queue_, dev_ctx_); scope_, &var_recv_queue_, dev_ctx_, i);
send_reqs_[i] = static_cast<RequestBase*>(send);
VLOG(4) << "Create RequestSend status:" << send->Status(); VLOG(4) << "Create RequestSend status:" << send->Status();
} }
void AsyncGRPCServer::TryToRegisterNewGetOne() { void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return; return;
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
dev_ctx_, &var_get_queue_); dev_ctx_, &var_get_queue_, req_id);
get_reqs_[req_id] = static_cast<RequestBase*>(get);
VLOG(4) << "Create RequestGet status:" << get->Status(); VLOG(4) << "Create RequestGet status:" << get->Status();
} }
void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne"; VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
return; return;
} }
RequestPrefetch* prefetch = RequestPrefetch* prefetch = new RequestPrefetch(
new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_, &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
dev_ctx_, executor_, program_, prefetch_ctx_.get()); program_, prefetch_ctx_.get(), req_id);
prefetch_reqs_[req_id] = static_cast<RequestBase*>(prefetch);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
} }
// FIXME(typhoonzero): change cq_name to enum. // FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, void AsyncGRPCServer::HandleRequest(
const std::string& cq_name, ::grpc::ServerCompletionQueue* cq, const std::string& cq_name,
std::function<void()> TryToRegisterNewOne) { std::function<void(int)> TryToRegisterNewOne) {
TryToRegisterNewOne();
void* tag = NULL; void* tag = NULL;
bool ok = false; bool ok = false;
...@@ -327,8 +369,7 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, ...@@ -327,8 +369,7 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
break; break;
} }
VLOG(3) << "HandleRequest for " << cq_name << " get Next"; VLOG(3) << "HandleRequest for " << cq_name << " get Next";
int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
PADDLE_ENFORCE(tag);
if (sync_mode_) { if (sync_mode_) {
// FIXME(typhoonzero): de-couple the barriers with recv_op // FIXME(typhoonzero): de-couple the barriers with recv_op
...@@ -337,7 +378,17 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, ...@@ -337,7 +378,17 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond"; VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
} }
RequestBase* base = reinterpret_cast<RequestBase*>(tag); RequestBase* base = nullptr;
{
std::lock_guard<std::mutex> l(cq_mutex_);
if (cq_name == "cq_get") {
base = get_reqs_[req_id];
} else if (cq_name == "cq_send") {
base = send_reqs_[req_id];
} else if (cq_name == "cq_prefetch") {
base = prefetch_reqs_[req_id];
}
}
// reference: // reference:
// https://github.com/tensorflow/tensorflow/issues/5596 // https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
...@@ -345,19 +396,19 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, ...@@ -345,19 +396,19 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
if (!ok) { if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event:argument name[" LOG(WARNING) << cq_name << " recv no regular event:argument name["
<< base->GetReqName() << "]"; << base->GetReqName() << "]";
TryToRegisterNewOne(); TryToRegisterNewOne(req_id);
delete base; delete base;
continue; continue;
} }
switch (base->Status()) { switch (base->Status()) {
case PROCESS: { case PROCESS: {
TryToRegisterNewOne();
base->Process(); base->Process();
VLOG(4) << cq_name << " PROCESS status:" << base->Status(); VLOG(4) << cq_name << " PROCESS status:" << base->Status();
break; break;
} }
case FINISH: { case FINISH: {
TryToRegisterNewOne(req_id);
VLOG(4) << cq_name << " FINISH status:" << base->Status(); VLOG(4) << cq_name << " FINISH status:" << base->Status();
delete base; delete base;
break; break;
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> #include <utility>
#include <vector>
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/blocking_queue.h"
...@@ -30,6 +31,7 @@ limitations under the License. */ ...@@ -30,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -82,19 +84,27 @@ class AsyncGRPCServer final { ...@@ -82,19 +84,27 @@ class AsyncGRPCServer final {
protected: protected:
void HandleRequest(::grpc::ServerCompletionQueue *cq, void HandleRequest(::grpc::ServerCompletionQueue *cq,
const std::string &cq_name, const std::string &cq_name,
std::function<void()> TryToRegisterNewOne); std::function<void(int)> TryToRegisterNewOne);
void TryToRegisterNewSendOne(); void TryToRegisterNewSendOne(int req_id);
void TryToRegisterNewGetOne(); void TryToRegisterNewGetOne(int req_id);
void TryToRegisterNewPrefetchOne(); void TryToRegisterNewPrefetchOne(int req_id);
void ShutdownQueue(); void ShutdownQueue();
private: private:
static const int kSendReqsBufSize = 100;
static const int kGetReqsBufSize = 100;
static const int kPrefetchReqsBufSize = 10;
std::mutex cq_mutex_; std::mutex cq_mutex_;
volatile bool is_shut_down_ = false; volatile bool is_shut_down_ = false;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_;
RequestBase *send_reqs_[kSendReqsBufSize];
RequestBase *get_reqs_[kGetReqsBufSize];
RequestBase *prefetch_reqs_[kPrefetchReqsBufSize];
GrpcService::AsyncService service_; GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
...@@ -113,8 +123,10 @@ class AsyncGRPCServer final { ...@@ -113,8 +123,10 @@ class AsyncGRPCServer final {
mutable int barrier_cond_step_; mutable int barrier_cond_step_;
std::condition_variable barrier_condition_; std::condition_variable barrier_condition_;
std::unique_ptr<std::thread> t_send_; std::vector<std::unique_ptr<std::thread>> t_sends_;
std::unique_ptr<std::thread> t_get_; std::vector<std::unique_ptr<std::thread>> t_gets_;
std::vector<std::unique_ptr<std::thread>> t_prefetchs_;
std::unique_ptr<std::thread> t_prefetch_; std::unique_ptr<std::thread> t_prefetch_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_; std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
......
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
#include <grpc++/support/byte_buffer.h> #include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/platform/profiler.h"
// NOTE: This method was originally created by tensorflow // NOTE: This method was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this // (https://github.com/tensorflow/tensorflow/) we borrow this
// method and did some modifications so that we can parse gRPC // method and did some modifications so that we can parse gRPC
......
...@@ -70,10 +70,10 @@ message VariableMessage { ...@@ -70,10 +70,10 @@ message VariableMessage {
bytes rows = 9; bytes rows = 9;
// Look up table block execution output variable name. // Look up table block execution output variable name.
string out_varname = 10; string out_varname = 10;
// If true, the ps server will start profiling, the ps // If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_* // server stops profiling and generates a profile to /tmp/profile_ps_*
// when profile switches from true to false. // when profile switches from 1 to 2.
bool profile = 11; int64 profile = 11;
} }
message VoidMessage {} message VoidMessage {}
...@@ -123,7 +123,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -123,7 +123,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
// 1 trainer returns true for ShouldSendProfileState(). It tells PS // 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the // servers the trainer's profiling state so that PS can follow the
// trainer. // trainer.
request.set_profile(platform::IsProfileEnabled()); if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request.set_profile(platform::kEnableProfiler);
} else {
request.set_profile(platform::kDisableProfiler);
}
}
if (!out_name.empty()) { if (!out_name.empty()) {
request.set_out_varname(out_name); request.set_out_varname(out_name);
} }
......
...@@ -449,8 +449,8 @@ int VariableResponse::Parse(Source* source) { ...@@ -449,8 +449,8 @@ int VariableResponse::Parse(Source* source) {
break; break;
} }
case sendrecv::VariableMessage::kProfileFieldNumber: { case sendrecv::VariableMessage::kProfileFieldNumber: {
bool profiling; uint64_t profiling = 0;
if (!input.ReadRaw(reinterpret_cast<void*>(&profiling), 1)) { if (!input.ReadVarint64(&profiling)) {
return tag; return tag;
} }
meta_.set_profile(profiling); meta_.set_profile(profiling);
...@@ -458,9 +458,11 @@ int VariableResponse::Parse(Source* source) { ...@@ -458,9 +458,11 @@ int VariableResponse::Parse(Source* source) {
if (listener_id <= 0) { if (listener_id <= 0) {
break; break;
} }
if (profiling && !platform::IsProfileEnabled()) { if (profiling == platform::kEnableProfiler &&
!platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU); platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (!profiling && platform::IsProfileEnabled()) { } else if (profiling == platform::kDisableProfiler &&
platform::IsProfileEnabled()) {
// TODO(panyx0718): Should we allow to customize file dir. // TODO(panyx0718): Should we allow to customize file dir.
platform::DisableProfiler( platform::DisableProfiler(
platform::EventSortingKey::kDefault, platform::EventSortingKey::kDefault,
......
...@@ -245,7 +245,6 @@ class DeviceTracerImpl : public DeviceTracer { ...@@ -245,7 +245,6 @@ class DeviceTracerImpl : public DeviceTracer {
void Enable() { void Enable() {
std::lock_guard<std::mutex> l(trace_mu_); std::lock_guard<std::mutex> l(trace_mu_);
if (enabled_) { if (enabled_) {
fprintf(stderr, "DeviceTracer already enabled\n");
return; return;
} }
EnableActivity(); EnableActivity();
......
...@@ -116,6 +116,8 @@ void ResetProfiler(); ...@@ -116,6 +116,8 @@ void ResetProfiler();
void DisableProfiler(EventSortingKey sorted_key, void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path); const std::string& profile_path);
const int kEnableProfiler = 1;
const int kDisableProfiler = 2;
// Test if the profiler is currently enabled. // Test if the profiler is currently enabled.
bool IsProfileEnabled(); bool IsProfileEnabled();
// Whether the trainer should send profiling state to PS. // Whether the trainer should send profiling state to PS.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册