From 0cafe39010eb6d69699c4dccbfa70715aec8bd85 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 3 Apr 2018 14:25:21 +0800 Subject: [PATCH] run prefetch prog on server --- paddle/fluid/operators/detail/CMakeLists.txt | 2 +- paddle/fluid/operators/detail/grpc_client.cc | 2 +- paddle/fluid/operators/detail/grpc_server.cc | 7 +- .../operators/detail/grpc_server_test.cc | 79 +++++++++++++++++-- paddle/fluid/operators/detail/send_recv.proto | 2 + .../operators/detail/sendrecvop_utils.cc | 6 +- .../fluid/operators/detail/sendrecvop_utils.h | 3 +- 7 files changed, 87 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index 3adeeda9064..719a7465b8d 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE) set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc) - cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) + cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op) endif() diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index d79ba6d2919..9a0bd8a04ff 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -136,7 +136,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; - SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req); + SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); // var handle VarHandle var_h; diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 7c978b28b68..c685a8bde84 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -157,9 +157,12 @@ class RequestPrefetch final : public RequestBase { virtual void Process() { // prefetch process... ::grpc::ByteBuffer reply; - // TODO(Yancey1989): execute the Block which containers prefetch ops - VLOG(3) << "RequestPrefetch Process in"; + executor_->Run(*program_, scope_, blkid_, false, false); + + std::string var_name = request_.out_varname(); + auto* var = scope_->FindVar(var_name); + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 1ad62863a1a..c69917ff2c9 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -20,43 +20,106 @@ limitations under the License. */ #include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + namespace framework = paddle::framework; namespace platform = paddle::platform; namespace detail = paddle::operators::detail; +USE_OP(lookup_table); + std::unique_ptr rpc_service_; +framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc& program) { + const auto &root_block = program.Block(0); + auto *block= program.AppendBlock(root_block); + + framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}}); + framework::VariableNameMap output({{"Output", {"out"}}}); + auto op = block->AppendOp(); + op->SetType("lookup_table"); + op->SetInput("W", {"w"}); + op->SetInput("Ids", {"ids"}); + op->SetOutput("Out", {"out"}); + return block; +} + +void InitTensorsInScope(framework::Scope &scope, platform::CPUPlace &place) { + auto w_var = scope.Var("w"); + auto w = w_var->GetMutable(); + w->Resize({10, 10}); + float *ptr = w->mutable_data(place); + for (int64_t i = 0; i < w->numel(); ++i) { + ptr[i] = static_cast(i/10); + } + + auto out_var = scope.Var("out"); + auto out = out_var->GetMutable(); + out->Resize({5, 10}); + out->mutable_data(place); + + auto ids_var = scope.Var("ids"); + auto ids = ids_var->GetMutable(); + ids->Resize({5, 1}); + auto ids_ptr = ids->mutable_data(place); + for (int64_t i = 0; i < ids->numel(); ++i) { + ids_ptr[i] = i * 2; + } +} + + void StartServer(const std::string& endpoint) { rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + framework::ProgramDesc program; + framework::Scope scope; + platform::CPUPlace place; + framework::Executor exe(place); + platform::CPUDeviceContext ctx(place); + auto* block = AppendPrefetchBlcok(program); + InitTensorsInScope(scope, place); + + rpc_service_->SetProgram(&program); + rpc_service_->SetPrefetchBlkdId(block->ID()); + rpc_service_->SetDevCtx(&ctx); + rpc_service_->SetScope(&scope); + rpc_service_->SetExecutor(&exe); + rpc_service_->RunSyncUpdate(); } + TEST(PREFETCH, CPU) { // start up a server instance backend // TODO(Yancey1989): Need to start a server with optimize blocks and // prefetch blocks. std::thread server_thread(StartServer, "127.0.0.1:8889"); + sleep(3); framework::Scope scope; platform::CPUPlace place; platform::CPUDeviceContext ctx(place); // create var on local scope - std::string in_var_name("in"); + InitTensorsInScope(scope, place); + std::string in_var_name("ids"); std::string out_var_name("out"); - auto* in_var = scope.Var(in_var_name); - auto* in_tensor = in_var->GetMutable(); - in_tensor->Resize({10, 10}); - VLOG(3) << "before mutable_data"; - in_tensor->mutable_data(place); - scope.Var(out_var_name); - VLOG(3) << "before fetch"; detail::RPCClient client; client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, out_var_name); client.Wait(); + auto out_var = scope.Var(out_var_name); + auto out = out_var->Get(); + auto out_ptr = out.data(); rpc_service_->ShutDown(); server_thread.join(); rpc_service_.reset(nullptr); + + EXPECT_EQ(out.dims().size(), 2); + EXPECT_EQ(out_ptr[0], static_cast(0)); + EXPECT_EQ(out_ptr[0 + 1 * out.dims()[1]], static_cast(2)); + EXPECT_EQ(out_ptr[0 + 2 * out.dims()[1]], static_cast(4)); + EXPECT_EQ(out_ptr[0 + 3 * out.dims()[1]], static_cast(6)); + EXPECT_EQ(out_ptr[0 + 4 * out.dims()[1]], static_cast(8)); } diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index fc12e82a7e6..48afa02ab89 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -67,6 +67,8 @@ message VariableMessage { bytes serialized = 8; // selected_rows data bytes rows = 9; + // prefetch var name + string out_varname = 10; } message VoidMessage {} diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 7e3f015dabd..7fca7fc4e74 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -28,7 +28,8 @@ namespace detail { void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, - ::grpc::ByteBuffer* msg) { + ::grpc::ByteBuffer* msg, + const std::string& out_name) { using VarMsg = sendrecv::VariableMessage; sendrecv::VariableMessage request; std::string header; @@ -50,6 +51,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, e.WriteUint64(VarMsg::kTypeFieldNumber, 1); } + if (!out_name.empty()) { + e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name); + } switch (framework::ToVarType(var->Type())) { case framework::proto::VarType_Type_LOD_TENSOR: { auto tensor = var->Get(); diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index b3b2b8469c8..3d5ec421e48 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*); void SerializeToByteBuffer(const std::string& name, framework::Variable* var, const platform::DeviceContext& ctx, - ::grpc::ByteBuffer* msg); + ::grpc::ByteBuffer* msg, + const std::string& out_varname = std::string()); void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, -- GitLab