From 46989663b11ab1e7dbedec9db0e91f1102fa2398 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Wed, 4 Apr 2018 16:59:03 +0800 Subject: [PATCH] prefetch selected rows --- .../operators/detail/grpc_server_test.cc | 65 +++++++++---------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 61b9484451f..9ae96f8584b 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -45,7 +45,7 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { op->SetOutput("Out", {"out"}); auto& out = *root_block->Var("out"); - out.SetType(framework::proto::VarType::LOD_TENSOR); + out.SetType(framework::proto::VarType::SELECTED_ROWS); out.SetShape({10, 10}); return block; @@ -53,35 +53,37 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { auto w_var = scope->Var("w"); - auto w = w_var->GetMutable(); - w->Resize({10, 10}); - w->mutable_data(*place); + w_var->GetMutable(); auto out_var = scope->Var("out"); - auto out = out_var->GetMutable(); - out->Resize({5, 10}); - out->mutable_data(*place); + out_var->GetMutable(); auto ids_var = scope->Var("ids"); - auto ids = ids_var->GetMutable(); - ids->Resize({5, 1}); + ids_var->GetMutable(); } -void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place) { +void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { CreateVarsOnScope(scope, place); - auto ids = scope->Var("ids")->GetMutable(); - auto ptr = ids->mutable_data(*place); - for (int64_t i = 0; i < ids->numel(); ++i) { - ptr[i] = i * 2; - } + auto ids_var = scope->Var("ids")->GetMutable(); + auto rows = ids_var->mutable_rows(); + for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2); + ids_var->mutable_value()->Resize({rows_numel, 1}); + ids_var->mutable_value()->mutable_data(*place); } -void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place) { +void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { CreateVarsOnScope(scope, place); - auto w_var = scope->Var("w"); - auto w = w_var->GetMutable(); - auto ptr = w->mutable_data(*place); - for (int64_t i = 0; i < w->numel(); ++i) { + auto w = scope->Var("w")->GetMutable(); + auto rows = w->mutable_rows(); + for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i); + auto w_value = w->mutable_value(); + w_value->Resize({rows_numel, 10}); + + auto ptr = w_value->mutable_data(*place); + + for (int64_t i = 0; i < w_value->numel(); ++i) { ptr[i] = static_cast(i / 10); } } @@ -94,7 +96,7 @@ void StartServer(const std::string& endpoint) { framework::Executor exe(place); platform::CPUDeviceContext ctx(place); auto* block = AppendPrefetchBlcok(&program); - InitTensorsOnServer(&scope, &place); + InitTensorsOnServer(&scope, &place, 10); rpc_service_->SetProgram(&program); rpc_service_->SetPrefetchBlkdId(block->ID()); @@ -107,15 +109,14 @@ void StartServer(const std::string& endpoint) { TEST(PREFETCH, CPU) { // start up a server instance backend - // TODO(Yancey1989): Need to start a server with optimize blocks and - // prefetch blocks. std::thread server_thread(StartServer, "127.0.0.1:8889"); sleep(2); framework::Scope scope; platform::CPUPlace place; platform::CPUDeviceContext ctx(place); // create var on local scope - InitTensorsOnClient(&scope, &place); + int64_t rows_numel = 5; + InitTensorsOnClient(&scope, &place, rows_numel); std::string in_var_name("ids"); std::string out_var_name("out"); @@ -124,18 +125,16 @@ TEST(PREFETCH, CPU) { out_var_name); client.Wait(); - auto out_var = scope.Var(out_var_name); - auto out = out_var->Get(); + // auto out_var = scope.Var(out_var_name); + auto var = scope.Var(out_var_name); + auto value = var->GetMutable()->value(); + auto ptr = value.mutable_data(place); - auto out_ptr = out.data(); rpc_service_->ShutDown(); server_thread.join(); rpc_service_.reset(nullptr); - EXPECT_EQ(out.dims().size(), 2); - EXPECT_EQ(out_ptr[0], static_cast(0)); - EXPECT_EQ(out_ptr[0 + 1 * out.dims()[1]], static_cast(2)); - EXPECT_EQ(out_ptr[0 + 2 * out.dims()[1]], static_cast(4)); - EXPECT_EQ(out_ptr[0 + 3 * out.dims()[1]], static_cast(6)); - EXPECT_EQ(out_ptr[0 + 4 * out.dims()[1]], static_cast(8)); + for (int64_t i = 0; i < rows_numel; ++i) { + EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast(i * 2)); + } } -- GitLab