提交 260bf5ac 编写于 作者: Q qiaolongfei

add sync_mode

上级 63fbdcf9
...@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH }; ...@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
class RequestBase { class RequestBase {
public: public:
explicit RequestBase(GrpcService::AsyncService* service, explicit RequestBase(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
const platform::DeviceContext* dev_ctx) const platform::DeviceContext* dev_ctx)
: service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) { : service_(service),
cq_(cq),
sync_mode_(sync_mode),
status_(PROCESS),
dev_ctx_(dev_ctx) {
PADDLE_ENFORCE(cq_); PADDLE_ENFORCE(cq_);
} }
virtual ~RequestBase() {} virtual ~RequestBase() {}
...@@ -49,6 +53,7 @@ class RequestBase { ...@@ -49,6 +53,7 @@ class RequestBase {
::grpc::ServerContext ctx_; ::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_; GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_; ::grpc::ServerCompletionQueue* cq_;
const bool sync_mode_;
CallStatus status_; CallStatus status_;
const platform::DeviceContext* dev_ctx_; const platform::DeviceContext* dev_ctx_;
}; };
...@@ -56,11 +61,17 @@ class RequestBase { ...@@ -56,11 +61,17 @@ class RequestBase {
class RequestSend final : public RequestBase { class RequestSend final : public RequestBase {
public: public:
explicit RequestSend(GrpcService::AsyncService* service, explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, ReceivedQueue* queue, framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx) const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { : RequestBase(service, cq, sync_mode, dev_ctx),
request_.reset(new VariableResponse(false, scope, dev_ctx_)); queue_(queue),
responder_(&ctx_) {
if (sync_mode_) {
request_.reset(new VariableResponse(false, scope, dev_ctx_));
} else {
request_.reset(new VariableResponse(true, scope, dev_ctx_));
}
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this); cq_, cq_, this);
...@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase { ...@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
class RequestGet final : public RequestBase { class RequestGet final : public RequestBase {
public: public:
explicit RequestGet(GrpcService::AsyncService* service, explicit RequestGet(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
SimpleBlockQueue<MessageWithName>* queue) SimpleBlockQueue<MessageWithName>* queue)
: RequestBase(service, cq, dev_ctx), : RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
queue_(queue) { queue_(queue) {
...@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase { ...@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
class RequestPrefetch final : public RequestBase { class RequestPrefetch final : public RequestBase {
public: public:
explicit RequestPrefetch(GrpcService::AsyncService* service, explicit RequestPrefetch(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq, bool sync_mode,
framework::Scope* scope, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
framework::Executor* executor, framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx) framework::ExecutorPrepareContext* prefetch_ctx)
: RequestBase(service, cq, dev_ctx), : RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
executor_(executor), executor_(executor),
program_(program), program_(program),
prefetch_ctx_(prefetch_ctx) { prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(false, scope, dev_ctx_)); if (sync_mode_) {
request_.reset(new VariableResponse(false, scope, dev_ctx_));
} else {
request_.reset(new VariableResponse(true, scope, dev_ctx_));
}
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this); cq_, cq_, this);
...@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase { ...@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
framework::Executor* executor_; framework::Executor* executor_;
framework::ProgramDesc* program_; framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_; framework::ExecutorPrepareContext* prefetch_ctx_;
int blkid_;
}; };
void AsyncGRPCServer::WaitClientGet(int count) { void AsyncGRPCServer::WaitClientGet(int count) {
...@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { ...@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return; return;
} }
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
&var_recv_queue_, dev_ctx_); scope_, &var_recv_queue_, dev_ctx_);
VLOG(4) << "Create RequestSend status:" << send->Status(); VLOG(4) << "Create RequestSend status:" << send->Status();
} }
...@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return; return;
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
&var_get_queue_); dev_ctx_, &var_get_queue_);
VLOG(4) << "Create RequestGet status:" << get->Status(); VLOG(4) << "Create RequestGet status:" << get->Status();
} }
...@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { ...@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
return; return;
} }
RequestPrefetch* prefetch = RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_,
executor_, program_, prefetch_ctx_); dev_ctx_, executor_, program_, prefetch_ctx_);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
} }
......
...@@ -44,7 +44,8 @@ class RequestBase; ...@@ -44,7 +44,8 @@ class RequestBase;
class AsyncGRPCServer final { class AsyncGRPCServer final {
public: public:
explicit AsyncGRPCServer(const std::string &address) : address_(address) {} explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode) {}
void RunSyncUpdate(); void RunSyncUpdate();
...@@ -95,6 +96,7 @@ class AsyncGRPCServer final { ...@@ -95,6 +96,7 @@ class AsyncGRPCServer final {
std::unique_ptr<::grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
std::string address_; std::string address_;
const bool sync_mode_;
framework::Scope *scope_; framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_; const platform::DeviceContext *dev_ctx_;
......
...@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, ...@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
} }
void StartServer(const std::string& endpoint) { void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
framework::ProgramDesc program; framework::ProgramDesc program;
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CPUPlace place;
......
...@@ -278,7 +278,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -278,7 +278,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode));
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock); auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册