diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 5c2979222a8121aeaa82222085aebdcd51f6c603..82d422d110eb155026dec471fff703816e5c7f41 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -158,13 +158,14 @@ class RequestPrefetch final : public RequestBase { std::string in_var_name = request_->Varname(); std::string out_var_name = request_->OutVarname(); VLOG(3) << "in_var_name: " << in_var_name + << "out_var_name: " << out_var_name << " RequestPrefetch: " << out_var_name; auto scope = request_->GetMutableLocalScope(); auto invar = scope->FindVar(in_var_name); - framework::Variable* outvar = nullptr; + framework::Variable* outvar = scope->FindVar(out_var_name); - request_handler_->Handle(in_var_name, scope, invar, &outvar); + request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name); SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(), &reply_); @@ -284,7 +285,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, } else if (rpc_name == kRequestPrefetch) { b = new RequestPrefetch(&service_, cq.get(), handler, req_id); } else { - PADDLE_ENFORCE(false, "not surpported rpc"); + PADDLE_ENFORCE(false, "not supported rpc"); } reqs[req_id] = b; diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h index 373a6aaa09c06899e36fdcab8c2c745a377924a3..e133df4896a8f370a45941932ebb56ac47e100af 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/detail/request_handler.h @@ -96,8 +96,8 @@ class RequestHandler { // *request_handler_->dev_ctx(), &reply_); // } virtual bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, - framework::Variable** outvar) = 0; + framework::Variable* var, framework::Variable** outvar, + const std::string& out_var_name = "") = 0; protected: const bool sync_mode_; diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index dc28740bf04e5b17727b02741a20a8c09d2bbc88..0f42daa5bc2eb14af77eb589b6a192241e1b02cd 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -33,7 +33,8 @@ namespace detail { bool RequestSendHandler::Handle(const std::string& varname, framework::Scope* scope, framework::Variable* invar, - framework::Variable** outvar) { + framework::Variable** outvar, + const std::string& out_var_name) { VLOG(4) << "RequestSendHandler:" << varname; // Async @@ -82,7 +83,8 @@ void RequestSendHandler::ResetSparseVarRecorder() { bool RequestGetHandler::Handle(const std::string& varname, framework::Scope* scope, framework::Variable* invar, - framework::Variable** outvar) { + framework::Variable** outvar, + const std::string& out_var_name) { VLOG(4) << "RequestGetHandler:" << varname; if (varname != FETCH_BARRIER_MESSAGE) { @@ -105,11 +107,11 @@ bool RequestGetHandler::Handle(const std::string& varname, bool RequestPrefetchHandler::Handle(const std::string& varname, framework::Scope* scope, framework::Variable* invar, - framework::Variable** outvar) { + framework::Variable** outvar, + const std::string& out_var_name) { VLOG(4) << "RequestPrefetchHandler " << varname; - auto var_desc = program_->Block(0).FindVar(varname); - *outvar = scope->FindVar(varname); + auto var_desc = program_->Block(0).FindVar(out_var_name); InitializeVariable(*outvar, var_desc->GetType()); executor_->RunPreparedContext( (*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope); diff --git a/paddle/fluid/operators/detail/request_handler_impl.h b/paddle/fluid/operators/detail/request_handler_impl.h index 443d951914dd0f40e8831abc637848363d9fef16..67bf277b24d76b5f76db2569947aed6d6bffd268 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.h +++ b/paddle/fluid/operators/detail/request_handler_impl.h @@ -40,7 +40,8 @@ class RequestSendHandler final : public RequestHandler { explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {} virtual ~RequestSendHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar) override; + framework::Variable* var, framework::Variable** outvar, + const std::string& out_var_name = "") override; void ResetSparseVarRecorder(); private: @@ -53,7 +54,8 @@ class RequestGetHandler final : public RequestHandler { explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {} virtual ~RequestGetHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar) override; + framework::Variable* var, framework::Variable** outvar, + const std::string& out_var_name = "") override; }; class RequestPrefetchHandler final : public RequestHandler { @@ -61,7 +63,8 @@ class RequestPrefetchHandler final : public RequestHandler { explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {} virtual ~RequestPrefetchHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, - framework::Variable* var, framework::Variable** outvar) override; + framework::Variable* var, framework::Variable** outvar, + const std::string& out_var_name = "") override; }; } // namespace detail