diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index aa7d33438730ca28cb398a115652271ca3043133..d5fc163bc25409e0607b149b61c6266b38119d9d 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -138,13 +138,14 @@ class RequestPrefetch final : public RequestBase { framework::Scope* scope, const platform::DeviceContext* dev_ctx, framework::Executor* executor, - framework::ProgramDesc* program, int blkid) + framework::ProgramDesc* program, + framework::ExecutorPrepareContext* prefetch_ctx) : RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope), executor_(executor), program_(program), - blkid_(blkid) { + prefetch_ctx_(prefetch_ctx) { request_.reset(new VariableResponse(scope, dev_ctx_)); int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, @@ -164,8 +165,7 @@ class RequestPrefetch final : public RequestBase { framework::Scope* local_scope = &scope_->NewScope(); auto* var = local_scope->FindVar(var_name); InitializeVariable(var, var_desc->GetType()); - - executor_->Run(*program_, local_scope, blkid_, false, false); + executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); @@ -179,6 +179,7 @@ class RequestPrefetch final : public RequestBase { framework::Scope* scope_; framework::Executor* executor_; framework::ProgramDesc* program_; + framework::ExecutorPrepareContext* prefetch_ctx_; int blkid_; }; @@ -276,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { } RequestPrefetch* prefetch = new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, - executor_, program_, prefetch_blk_id_); + executor_, program_, prefetch_ctx_); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 380447f47c142bdc16e60f78c4b2d94235ec5060..b6110f92ed4f38a156e0c99ecfb399f3f47a169e 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -63,6 +63,10 @@ class AsyncGRPCServer final { void SetExecutor(framework::Executor *executor) { executor_ = executor; } + void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) { + prefetch_ctx_ = prepared; + } + int GetSelectedPort() { return selected_port_; } const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } @@ -111,6 +115,7 @@ class AsyncGRPCServer final { std::unique_ptr t_prefetch_; int prefetch_blk_id_; + framework::ExecutorPrepareContext *prefetch_ctx_; framework::ProgramDesc *program_; framework::Executor *executor_; int selected_port_; diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 9ae96f8584b92a3d5bac434b0e9d886f73eb556b..c51933718f4ca78e87c77e007c485642000d247d 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -96,10 +96,11 @@ void StartServer(const std::string& endpoint) { framework::Executor exe(place); platform::CPUDeviceContext ctx(place); auto* block = AppendPrefetchBlcok(&program); + auto prepared = exe.Prepare(program, block->ID()); InitTensorsOnServer(&scope, &place, 10); rpc_service_->SetProgram(&program); - rpc_service_->SetPrefetchBlkdId(block->ID()); + rpc_service_->SetPrefetchPreparedCtx(prepared.get()); rpc_service_->SetDevCtx(&ctx); rpc_service_->SetScope(&scope); rpc_service_->SetExecutor(&exe); @@ -125,7 +126,6 @@ TEST(PREFETCH, CPU) { out_var_name); client.Wait(); - // auto out_var = scope.Var(out_var_name); auto var = scope.Var(out_var_name); auto value = var->GetMutable()->value(); auto ptr = value.mutable_data(place); diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index 48afa02ab89450e1585751b9faa3d64d3853e090..0616c1ae1c93e1e88d10698d1278691ca3a3d838 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -21,7 +21,7 @@ service SendRecvService { rpc SendVariable(VariableMessage) returns (VoidMessage) {} // Argument VariableMessage for GetVariable should only contain varname. rpc GetVariable(VariableMessage) returns (VariableMessage) {} - // Prefetch variable by Ids + // Look up table block execution output variable name. rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} }