From d741576e68ed8cf889fa16593e15502117f7bea1 Mon Sep 17 00:00:00 2001 From: seiriosPlus Date: Sun, 27 Sep 2020 19:48:00 +0800 Subject: [PATCH] add large scale UT --- .../operators/distributed/rpc_server_test.cc | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index 5ce7ac85269..bfb9c77dcc0 100644 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h" +#include "paddle/fluid/operators/distributed/large_scale_kv.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_server.h" @@ -230,3 +231,85 @@ TEST(SENDANDRECV, CPU) { g_rpc_service.reset(nullptr); g_req_handler.reset(nullptr); } + +void StartCheckpointServer(const std::string& rpc_name) { + framework::ProgramDesc program; + framework::Scope scope; + platform::CPUPlace place; + framework::Executor exe(place); + platform::CPUDeviceContext ctx(place); + + std::vector metas; + + auto meta = distributed::SparseMeta(); + meta.name = "embedding.block0"; + meta.value_names = "Param"; + meta.value_dims = "64"; + meta.mode = "0"; + meta.grad_name = "embedding@Grad"; + meta.cached_varnames = "kSparseIds"; + meta.initializer_attrs = "fill_constant&1.0"; + meta.entry = "none"; + metas.push_back(meta); + distributed::LargeScaleKV::Init(metas); + + std::unordered_map> + prefetch_var_name_to_prepared; + + g_req_handler->SetProgram(&program); + g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared); + g_req_handler->SetDevCtx(&ctx); + g_req_handler->SetScope(&scope); + g_req_handler->SetExecutor(&exe); + + g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); + + g_req_handler->SetRPCServer(g_rpc_service.get()); + + std::thread server_thread( + std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); + + server_thread.join(); +} + +TEST(LARGE_SCALE_CHECKPOINT, CPU) { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + g_req_handler.reset(new distributed::RequestNotifyHandler( + distributed::DistributedMode::kAsync)); + g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); + distributed::RPCClient* client = + distributed::RPCClient::GetInstance(0); + PADDLE_ENFORCE_NE(client, nullptr, + platform::errors::InvalidArgument( + "Client Start Fail, Check Your Code & Env")); + + std::thread server_thread(StartCheckpointServer, distributed::kRequestNotify); + + g_rpc_service->WaitServerReady(); + int port = g_rpc_service->GetSelectedPort(); + std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); + + framework::Scope scope; + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); + + auto save_path = string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/base", + "embedding", "embedding.block0"); + int mode = 0; + client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode); + client->Wait(); + + save_path = string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/delta", + "embedding", "embedding.block0"); + mode = 1; + client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode); + client->Wait(); + + g_rpc_service->ShutDown(); + server_thread.join(); + LOG(INFO) << "begin reset"; + g_rpc_service.reset(nullptr); + g_req_handler.reset(nullptr); +} -- GitLab