From f132f51eb4976dd3873b2fe8bbb086b0f885a51f Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Sun, 8 Apr 2018 15:35:15 +0800 Subject: [PATCH] prepare prefetch context --- paddle/fluid/operators/detail/grpc_server.cc | 11 ++++++----- paddle/fluid/operators/detail/grpc_server.h | 5 +++++ paddle/fluid/operators/detail/grpc_server_test.cc | 4 ++-- paddle/fluid/operators/detail/send_recv.proto | 2 +- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index aa7d3343873..d5fc163bc25 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -138,13 +138,14 @@ class RequestPrefetch final : public RequestBase { framework::Scope* scope, const platform::DeviceContext* dev_ctx, framework::Executor* executor, - framework::ProgramDesc* program, int blkid) + framework::ProgramDesc* program, + framework::ExecutorPrepareContext* prefetch_ctx) : RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope), executor_(executor), program_(program), - blkid_(blkid) { + prefetch_ctx_(prefetch_ctx) { request_.reset(new VariableResponse(scope, dev_ctx_)); int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, @@ -164,8 +165,7 @@ class RequestPrefetch final : public RequestBase { framework::Scope* local_scope = &scope_->NewScope(); auto* var = local_scope->FindVar(var_name); InitializeVariable(var, var_desc->GetType()); - - executor_->Run(*program_, local_scope, blkid_, false, false); + executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); @@ -179,6 +179,7 @@ class RequestPrefetch final : public RequestBase { framework::Scope* scope_; framework::Executor* executor_; framework::ProgramDesc* program_; + framework::ExecutorPrepareContext* prefetch_ctx_; int blkid_; }; @@ -276,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { } RequestPrefetch* prefetch = new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, - executor_, program_, prefetch_blk_id_); + executor_, program_, prefetch_ctx_); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 380447f47c1..b6110f92ed4 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -63,6 +63,10 @@ class AsyncGRPCServer final { void SetExecutor(framework::Executor *executor) { executor_ = executor; } + void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) { + prefetch_ctx_ = prepared; + } + int GetSelectedPort() { return selected_port_; } const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } @@ -111,6 +115,7 @@ class AsyncGRPCServer final { std::unique_ptr t_prefetch_; int prefetch_blk_id_; + framework::ExecutorPrepareContext *prefetch_ctx_; framework::ProgramDesc *program_; framework::Executor *executor_; int selected_port_; diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 9ae96f8584b..c51933718f4 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -96,10 +96,11 @@ void StartServer(const std::string& endpoint) { framework::Executor exe(place); platform::CPUDeviceContext ctx(place); auto* block = AppendPrefetchBlcok(&program); + auto prepared = exe.Prepare(program, block->ID()); InitTensorsOnServer(&scope, &place, 10); rpc_service_->SetProgram(&program); - rpc_service_->SetPrefetchBlkdId(block->ID()); + rpc_service_->SetPrefetchPreparedCtx(prepared.get()); rpc_service_->SetDevCtx(&ctx); rpc_service_->SetScope(&scope); rpc_service_->SetExecutor(&exe); @@ -125,7 +126,6 @@ TEST(PREFETCH, CPU) { out_var_name); client.Wait(); - // auto out_var = scope.Var(out_var_name); auto var = scope.Var(out_var_name); auto value = var->GetMutable()->value(); auto ptr = value.mutable_data(place); diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index 48afa02ab89..0616c1ae1c9 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -21,7 +21,7 @@ service SendRecvService { rpc SendVariable(VariableMessage) returns (VoidMessage) {} // Argument VariableMessage for GetVariable should only contain varname. rpc GetVariable(VariableMessage) returns (VariableMessage) {} - // Prefetch variable by Ids + // Look up table block execution output variable name. rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} } -- GitLab