提交 d741576e 编写于 作者: S seiriosPlus

add large scale UT

上级 8def0e34
...@@ -26,6 +26,7 @@ limitations under the License. */ ...@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
...@@ -230,3 +231,85 @@ TEST(SENDANDRECV, CPU) { ...@@ -230,3 +231,85 @@ TEST(SENDANDRECV, CPU) {
g_rpc_service.reset(nullptr); g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr); g_req_handler.reset(nullptr);
} }
void StartCheckpointServer(const std::string& rpc_name) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
std::vector<distributed::SparseMeta> metas;
auto meta = distributed::SparseMeta();
meta.name = "embedding.block0";
meta.value_names = "Param";
meta.value_dims = "64";
meta.mode = "0";
meta.grad_name = "embedding@Grad";
meta.cached_varnames = "kSparseIds";
meta.initializer_attrs = "fill_constant&1.0";
meta.entry = "none";
metas.push_back(meta);
distributed::LargeScaleKV::Init(metas);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join();
}
TEST(LARGE_SCALE_CHECKPOINT, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestNotifyHandler(
distributed::DistributedMode::kAsync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE_NE(client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
std::thread server_thread(StartCheckpointServer, distributed::kRequestNotify);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
auto save_path = string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/base",
"embedding", "embedding.block0");
int mode = 0;
client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode);
client->Wait();
save_path = string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/delta",
"embedding", "embedding.block0");
mode = 1;
client->AsyncCheckpointNotify(ep, save_path, "embedding.block0", mode);
client->Wait();
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册