diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 17e38b1cf042657834b4d0d1c12cbbb92f19fa45..194df3e4a8b50700e2be01ce5ebca83b92501fb8 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include // for unique_ptr -#include // for call_once #include #include "glog/logging.h" #include "paddle/fluid/framework/threadpool.h" @@ -39,6 +38,7 @@ Scope::~Scope() { } Scope& Scope::NewScope() const { + std::unique_lock lock(mutex_); kids_.push_back(new Scope(this)); return *kids_.back(); } @@ -92,6 +92,7 @@ std::vector Scope::LocalVarNames() const { } void Scope::DeleteScope(Scope* scope) { + std::unique_lock lock(mutex_); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); this->kids_.erase(it); @@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) { } } -void Scope::EraseVars(std::vector& var_names) { +void Scope::EraseVars(const std::vector& var_names) { std::set var_set(var_names.begin(), var_names.end()); for (auto it = vars_.begin(); it != vars_.end();) { if (var_set.find(it->first) != var_set.end()) { diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index c1e1f49caaa5a60df0e97289aada465b45213971..97a15c71773051dfc01c98f11cf9cb76adbcec7f 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include // NOLINT #include #include #include @@ -51,7 +52,7 @@ class Scope { /// Create a variable with a scope-unique name. Variable* Var(std::string* name = nullptr); - void EraseVars(std::vector& var_names); + void EraseVars(const std::vector& var_names); /// Find a variable in the scope or any of its ancestors. Returns /// nullptr if cannot find. @@ -88,6 +89,9 @@ class Scope { Scope const* parent_{nullptr}; DISABLE_COPY_AND_ASSIGN(Scope); + + private: + mutable std::mutex mutex_; }; } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index 3adeeda90645ca983d9d9229b4cc1c4c90302206..719a7465b8d58ef8588ff1e83c2b971eb6fbb00f 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE) set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc) - cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) + cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op) endif() diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ef987d07f08525bff5267cdc2076ae767417e4f1..8bbfd1f15925992efdeaaffbbe7b350ffbcee889 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -138,7 +138,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; - SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req); + SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); // var handle VarHandle var_h; diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 2e7bf1921a26fc88d854e4db2c501548695a136a..d5fc163bc25409e0607b149b61c6266b38119d9d 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -138,39 +138,48 @@ class RequestPrefetch final : public RequestBase { framework::Scope* scope, const platform::DeviceContext* dev_ctx, framework::Executor* executor, - framework::ProgramDesc* program, int blkid) + framework::ProgramDesc* program, + framework::ExecutorPrepareContext* prefetch_ctx) : RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope), executor_(executor), program_(program), - blkid_(blkid) { + prefetch_ctx_(prefetch_ctx) { + request_.reset(new VariableResponse(scope, dev_ctx_)); int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); - service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, - cq_, this); + service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, + cq_, cq_, this); } virtual ~RequestPrefetch() {} - virtual std::string GetReqName() { return request_.varname(); } + virtual std::string GetReqName() { return request_->Varname(); } virtual void Process() { // prefetch process... ::grpc::ByteBuffer reply; - // TODO(Yancey1989): execute the Block which containers prefetch ops - VLOG(3) << "RequestPrefetch Process in"; + std::string var_name = request_->OutVarname(); + auto var_desc = program_->Block(0).FindVar(var_name); + framework::Scope* local_scope = &scope_->NewScope(); + auto* var = local_scope->FindVar(var_name); + InitializeVariable(var, var_desc->GetType()); + executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false); + + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; } protected: - sendrecv::VariableMessage request_; + std::shared_ptr request_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; framework::Executor* executor_; framework::ProgramDesc* program_; + framework::ExecutorPrepareContext* prefetch_ctx_; int blkid_; }; @@ -268,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { } RequestPrefetch* prefetch = new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, - executor_, program_, prefetch_blk_id_); + executor_, program_, prefetch_ctx_); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 380447f47c142bdc16e60f78c4b2d94235ec5060..b6110f92ed4f38a156e0c99ecfb399f3f47a169e 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -63,6 +63,10 @@ class AsyncGRPCServer final { void SetExecutor(framework::Executor *executor) { executor_ = executor; } + void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) { + prefetch_ctx_ = prepared; + } + int GetSelectedPort() { return selected_port_; } const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } @@ -111,6 +115,7 @@ class AsyncGRPCServer final { std::unique_ptr t_prefetch_; int prefetch_blk_id_; + framework::ExecutorPrepareContext *prefetch_ctx_; framework::ProgramDesc *program_; framework::Executor *executor_; int selected_port_; diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index b89aed0157de8e95564015b3e7f42316a39537f5..c51933718f4ca78e87c77e007c485642000d247d 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -20,43 +20,121 @@ limitations under the License. */ #include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + namespace framework = paddle::framework; namespace platform = paddle::platform; namespace detail = paddle::operators::detail; +USE_OP(lookup_table); + std::unique_ptr rpc_service_; +framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { + auto root_block = program->MutableBlock(0); + auto* block = program->AppendBlock(*root_block); + + framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}}); + framework::VariableNameMap output({{"Output", {"out"}}}); + auto op = block->AppendOp(); + op->SetType("lookup_table"); + op->SetInput("W", {"w"}); + op->SetInput("Ids", {"ids"}); + op->SetOutput("Out", {"out"}); + + auto& out = *root_block->Var("out"); + out.SetType(framework::proto::VarType::SELECTED_ROWS); + out.SetShape({10, 10}); + + return block; +} + +void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { + auto w_var = scope->Var("w"); + w_var->GetMutable(); + + auto out_var = scope->Var("out"); + out_var->GetMutable(); + + auto ids_var = scope->Var("ids"); + ids_var->GetMutable(); +} + +void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope, place); + auto ids_var = scope->Var("ids")->GetMutable(); + auto rows = ids_var->mutable_rows(); + for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2); + ids_var->mutable_value()->Resize({rows_numel, 1}); + ids_var->mutable_value()->mutable_data(*place); +} + +void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope, place); + auto w = scope->Var("w")->GetMutable(); + auto rows = w->mutable_rows(); + for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i); + auto w_value = w->mutable_value(); + w_value->Resize({rows_numel, 10}); + + auto ptr = w_value->mutable_data(*place); + + for (int64_t i = 0; i < w_value->numel(); ++i) { + ptr[i] = static_cast(i / 10); + } +} + void StartServer(const std::string& endpoint) { rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + framework::ProgramDesc program; + framework::Scope scope; + platform::CPUPlace place; + framework::Executor exe(place); + platform::CPUDeviceContext ctx(place); + auto* block = AppendPrefetchBlcok(&program); + auto prepared = exe.Prepare(program, block->ID()); + InitTensorsOnServer(&scope, &place, 10); + + rpc_service_->SetProgram(&program); + rpc_service_->SetPrefetchPreparedCtx(prepared.get()); + rpc_service_->SetDevCtx(&ctx); + rpc_service_->SetScope(&scope); + rpc_service_->SetExecutor(&exe); + rpc_service_->RunSyncUpdate(); } TEST(PREFETCH, CPU) { // start up a server instance backend - // TODO(Yancey1989): Need to start a server with optimize blocks and - // prefetch blocks. std::thread server_thread(StartServer, "127.0.0.1:8889"); + sleep(2); framework::Scope scope; platform::CPUPlace place; platform::CPUDeviceContext ctx(place); // create var on local scope - std::string in_var_name("in"); + int64_t rows_numel = 5; + InitTensorsOnClient(&scope, &place, rows_numel); + std::string in_var_name("ids"); std::string out_var_name("out"); - auto* in_var = scope.Var(in_var_name); - auto* in_tensor = in_var->GetMutable(); - in_tensor->Resize({10, 10}); - VLOG(3) << "before mutable_data"; - in_tensor->mutable_data(place); - scope.Var(out_var_name); - - VLOG(3) << "before fetch"; detail::RPCClient client; client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, out_var_name); client.Wait(); + auto var = scope.Var(out_var_name); + auto value = var->GetMutable()->value(); + auto ptr = value.mutable_data(place); + rpc_service_->ShutDown(); server_thread.join(); rpc_service_.reset(nullptr); + + for (int64_t i = 0; i < rows_numel; ++i) { + EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast(i * 2)); + } } diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index fc12e82a7e6bd10262092d1ca367980df64e91c2..02bb2b9cebb87b83aa1cbef0c644f969b4d17284 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -21,7 +21,7 @@ service SendRecvService { rpc SendVariable(VariableMessage) returns (VoidMessage) {} // Argument VariableMessage for GetVariable should only contain varname. rpc GetVariable(VariableMessage) returns (VariableMessage) {} - // Prefetch variable by Ids + // pre-fetch variable by given variable name and Ids rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} } @@ -67,6 +67,8 @@ message VariableMessage { bytes serialized = 8; // selected_rows data bytes rows = 9; + // Look up table block execution output variable name. + string out_varname = 10; } message VoidMessage {} diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index f8576d01b10f4c0fda4d12d371b2966739acfc21..1577111a9628350b0cf3f01f2cf15f8c27994673 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -30,7 +30,8 @@ namespace detail { void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, - ::grpc::ByteBuffer* msg) { + ::grpc::ByteBuffer* msg, + const std::string& out_name) { using VarMsg = sendrecv::VariableMessage; sendrecv::VariableMessage request; std::string header; @@ -52,6 +53,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, e.WriteUint64(VarMsg::kTypeFieldNumber, 1); } + if (!out_name.empty()) { + e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name); + } switch (framework::ToVarType(var->Type())) { case framework::proto::VarType_Type_LOD_TENSOR: { auto tensor = var->Get(); diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index d7954440846b8db9a9add0110fb9a546a762774d..c72e1bd076f670458f3915072154847db6205092 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*); void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, - ::grpc::ByteBuffer* msg); + ::grpc::ByteBuffer* msg, + const std::string& out_varname = std::string()); void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 78e1d274a92241b5f2093beb63acdc8c497dfb83..c9d7fd6d1581f6f4182e9e3e0d633c13a3c336a5 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -416,6 +416,20 @@ int VariableResponse::Parse(Source* source) { } break; } + case sendrecv::VariableMessage::kOutVarnameFieldNumber: { + uint32_t length; + if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) { + return tag; + } + + std::string temp; + if (!input.ReadString(&temp, length)) { + return tag; + } + + meta_.set_out_varname(temp); + break; + } default: { // Unknown tag, return unknown error. diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h index 050b6b84010b4f3e95bc88e5bb738ff18b7fe423..93b0d3cfb4f7d7f336414361773f872d7b259482 100644 --- a/paddle/fluid/operators/detail/variable_response.h +++ b/paddle/fluid/operators/detail/variable_response.h @@ -55,6 +55,7 @@ class VariableResponse { int Parse(const ::grpc::ByteBuffer& byte_buffer); inline std::string Varname() { return meta_.varname(); } + inline std::string OutVarname() { return meta_.out_varname(); } // should call parse first. framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }