diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index 3adeeda90645ca983d9d9229b4cc1c4c90302206..719a7465b8d58ef8588ff1e83c2b971eb6fbb00f 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE) set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc) - cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) + cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op) endif() diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index d79ba6d291950e1f089eb11713bd1c3e4d154b27..9a0bd8a04ff02900e924b7dc7a5972387550bd46 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -136,7 +136,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; - SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req); + SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); // var handle VarHandle var_h; diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 7c978b28b6873d05afb435de4caf7f4ce5d33193..c685a8bde84fadd21b5e254ecf2b50ddacd90002 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -157,9 +157,12 @@ class RequestPrefetch final : public RequestBase { virtual void Process() { // prefetch process... ::grpc::ByteBuffer reply; - // TODO(Yancey1989): execute the Block which containers prefetch ops - VLOG(3) << "RequestPrefetch Process in"; + executor_->Run(*program_, scope_, blkid_, false, false); + + std::string var_name = request_.out_varname(); + auto* var = scope_->FindVar(var_name); + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 1ad62863a1a98c28cb08f47dfa8a5bfae463ba91..c69917ff2c97aba51e0569c93789ba393186de1f 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -20,43 +20,106 @@ limitations under the License. */ #include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + namespace framework = paddle::framework; namespace platform = paddle::platform; namespace detail = paddle::operators::detail; +USE_OP(lookup_table); + std::unique_ptr rpc_service_; +framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc& program) { + const auto &root_block = program.Block(0); + auto *block= program.AppendBlock(root_block); + + framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}}); + framework::VariableNameMap output({{"Output", {"out"}}}); + auto op = block->AppendOp(); + op->SetType("lookup_table"); + op->SetInput("W", {"w"}); + op->SetInput("Ids", {"ids"}); + op->SetOutput("Out", {"out"}); + return block; +} + +void InitTensorsInScope(framework::Scope &scope, platform::CPUPlace &place) { + auto w_var = scope.Var("w"); + auto w = w_var->GetMutable(); + w->Resize({10, 10}); + float *ptr = w->mutable_data(place); + for (int64_t i = 0; i < w->numel(); ++i) { + ptr[i] = static_cast(i/10); + } + + auto out_var = scope.Var("out"); + auto out = out_var->GetMutable(); + out->Resize({5, 10}); + out->mutable_data(place); + + auto ids_var = scope.Var("ids"); + auto ids = ids_var->GetMutable(); + ids->Resize({5, 1}); + auto ids_ptr = ids->mutable_data(place); + for (int64_t i = 0; i < ids->numel(); ++i) { + ids_ptr[i] = i * 2; + } +} + + void StartServer(const std::string& endpoint) { rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + framework::ProgramDesc program; + framework::Scope scope; + platform::CPUPlace place; + framework::Executor exe(place); + platform::CPUDeviceContext ctx(place); + auto* block = AppendPrefetchBlcok(program); + InitTensorsInScope(scope, place); + + rpc_service_->SetProgram(&program); + rpc_service_->SetPrefetchBlkdId(block->ID()); + rpc_service_->SetDevCtx(&ctx); + rpc_service_->SetScope(&scope); + rpc_service_->SetExecutor(&exe); + rpc_service_->RunSyncUpdate(); } + TEST(PREFETCH, CPU) { // start up a server instance backend // TODO(Yancey1989): Need to start a server with optimize blocks and // prefetch blocks. std::thread server_thread(StartServer, "127.0.0.1:8889"); + sleep(3); framework::Scope scope; platform::CPUPlace place; platform::CPUDeviceContext ctx(place); // create var on local scope - std::string in_var_name("in"); + InitTensorsInScope(scope, place); + std::string in_var_name("ids"); std::string out_var_name("out"); - auto* in_var = scope.Var(in_var_name); - auto* in_tensor = in_var->GetMutable(); - in_tensor->Resize({10, 10}); - VLOG(3) << "before mutable_data"; - in_tensor->mutable_data(place); - scope.Var(out_var_name); - VLOG(3) << "before fetch"; detail::RPCClient client; client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, out_var_name); client.Wait(); + auto out_var = scope.Var(out_var_name); + auto out = out_var->Get(); + auto out_ptr = out.data(); rpc_service_->ShutDown(); server_thread.join(); rpc_service_.reset(nullptr); + + EXPECT_EQ(out.dims().size(), 2); + EXPECT_EQ(out_ptr[0], static_cast(0)); + EXPECT_EQ(out_ptr[0 + 1 * out.dims()[1]], static_cast(2)); + EXPECT_EQ(out_ptr[0 + 2 * out.dims()[1]], static_cast(4)); + EXPECT_EQ(out_ptr[0 + 3 * out.dims()[1]], static_cast(6)); + EXPECT_EQ(out_ptr[0 + 4 * out.dims()[1]], static_cast(8)); } diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index fc12e82a7e6bd10262092d1ca367980df64e91c2..48afa02ab89450e1585751b9faa3d64d3853e090 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -67,6 +67,8 @@ message VariableMessage { bytes serialized = 8; // selected_rows data bytes rows = 9; + // prefetch var name + string out_varname = 10; } message VoidMessage {} diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 7e3f015dabdb3fd6190d1ca2f422aa526e8889cd..7fca7fc4e741665832d58175d756a891b727ac94 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -28,7 +28,8 @@ namespace detail { void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, - ::grpc::ByteBuffer* msg) { + ::grpc::ByteBuffer* msg, + const std::string& out_name) { using VarMsg = sendrecv::VariableMessage; sendrecv::VariableMessage request; std::string header; @@ -50,6 +51,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, e.WriteUint64(VarMsg::kTypeFieldNumber, 1); } + if (!out_name.empty()) { + e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name); + } switch (framework::ToVarType(var->Type())) { case framework::proto::VarType_Type_LOD_TENSOR: { auto tensor = var->Get(); diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index b3b2b8469c8f19313038f2551ab04708a05656d5..3d5ec421e48d1797ef851e2049aeff743fab9d30 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*); void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, - ::grpc::ByteBuffer* msg); + ::grpc::ByteBuffer* msg, + const std::string& out_varname = std::string()); void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx,