未验证 提交 d41836ef 编写于 作者: Z zmxdream 提交者: GitHub

[HeterPS]fix ut for heteps comm op (#39684)

* fix. test=develop

* fix. test=develop

* fix code style. test=develop

* fix. test=develop

* fix. test=develop
上级 65ced1fa
...@@ -17,6 +17,9 @@ limitations under the License. */ ...@@ -17,6 +17,9 @@ limitations under the License. */
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <random>
#include <sstream>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h" #include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/heter_server.h" #include "paddle/fluid/distributed/ps/service/heter_server.h"
...@@ -36,6 +39,19 @@ DECLARE_double(eager_delete_tensor_gb); ...@@ -36,6 +39,19 @@ DECLARE_double(eager_delete_tensor_gb);
USE_OP_ITSELF(scale); USE_OP_ITSELF(scale);
USE_NO_KERNEL_OP(heter_listen_and_serv); USE_NO_KERNEL_OP(heter_listen_and_serv);
std::string get_ip_port() {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist(4444, 25000);
int port = dist(rng);
std::string ip_port;
std::stringstream temp_str;
temp_str << "127.0.0.1:";
temp_str << port;
temp_str >> ip_port;
return ip_port;
}
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) { framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
framework::BlockDesc* block = framework::BlockDesc* block =
program->AppendBlock(*(program->MutableBlock(0))); program->AppendBlock(*(program->MutableBlock(0)));
...@@ -53,16 +69,13 @@ framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) { ...@@ -53,16 +69,13 @@ framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
return block; return block;
} }
void GetHeterListenAndServProgram(framework::ProgramDesc* program) { void GetHeterListenAndServProgram(framework::ProgramDesc* program,
std::string endpoint) {
auto root_block = program->MutableBlock(0); auto root_block = program->MutableBlock(0);
auto* sub_block = AppendSendAndRecvBlock(program); auto* sub_block = AppendSendAndRecvBlock(program);
std::vector<framework::BlockDesc*> optimize_blocks; std::vector<framework::BlockDesc*> optimize_blocks;
optimize_blocks.push_back(sub_block); optimize_blocks.push_back(sub_block);
std::vector<std::string> message_to_block_id = {"x:1"}; std::vector<std::string> message_to_block_id = {"x:1"};
std::string endpoint = "127.0.0.1:19944";
framework::OpDesc* op = root_block->AppendOp(); framework::OpDesc* op = root_block->AppendOp();
op->SetType("heter_listen_and_serv"); op->SetType("heter_listen_and_serv");
op->SetInput("X", {}); op->SetInput("X", {});
...@@ -129,7 +142,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, ...@@ -129,7 +142,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
CreateVarsOnScope(scope, place); CreateVarsOnScope(scope, place);
} }
void StartHeterServer() { void StartHeterServer(std::string endpoint) {
framework::ProgramDesc program; framework::ProgramDesc program;
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CPUPlace place;
...@@ -137,7 +150,7 @@ void StartHeterServer() { ...@@ -137,7 +150,7 @@ void StartHeterServer() {
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
LOG(INFO) << "before GetHeterListenAndServProgram"; LOG(INFO) << "before GetHeterListenAndServProgram";
GetHeterListenAndServProgram(&program); GetHeterListenAndServProgram(&program, endpoint);
auto prepared = exe.Prepare(program, 0); auto prepared = exe.Prepare(program, 0);
LOG(INFO) << "before InitTensorsOnServer"; LOG(INFO) << "before InitTensorsOnServer";
...@@ -150,13 +163,12 @@ void StartHeterServer() { ...@@ -150,13 +163,12 @@ void StartHeterServer() {
TEST(HETER_LISTEN_AND_SERV, CPU) { TEST(HETER_LISTEN_AND_SERV, CPU) {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:19944"; std::string endpoint = get_ip_port();
std::string previous_endpoint = "127.0.0.1:19944"; std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer"; LOG(INFO) << "before StartSendAndRecvServer";
FLAGS_eager_delete_tensor_gb = -1; FLAGS_eager_delete_tensor_gb = -1;
std::thread server_thread(StartHeterServer); std::thread server_thread(StartHeterServer, endpoint);
sleep(1); sleep(1);
auto b_rpc_service = distributed::HeterServer::GetInstance(); auto b_rpc_service = distributed::HeterServer::GetInstance();
b_rpc_service->WaitServerReady(); b_rpc_service->WaitServerReady();
using MicroScope = using MicroScope =
......
...@@ -17,6 +17,9 @@ limitations under the License. */ ...@@ -17,6 +17,9 @@ limitations under the License. */
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <random>
#include <sstream>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h" #include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/heter_server.h" #include "paddle/fluid/distributed/ps/service/heter_server.h"
...@@ -33,6 +36,19 @@ USE_OP_ITSELF(scale); ...@@ -33,6 +36,19 @@ USE_OP_ITSELF(scale);
std::shared_ptr<distributed::HeterServer> b_rpc_service; std::shared_ptr<distributed::HeterServer> b_rpc_service;
std::string get_ip_port() {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist(4444, 25000);
int port = dist(rng);
std::string ip_port;
std::stringstream temp_str;
temp_str << "127.0.0.1:";
temp_str << port;
temp_str >> ip_port;
return ip_port;
}
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) { framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0); auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block); auto* block = program->AppendBlock(*root_block);
...@@ -178,16 +194,17 @@ void StartSendAndRecvServer(std::string endpoint) { ...@@ -178,16 +194,17 @@ void StartSendAndRecvServer(std::string endpoint) {
b_rpc_service->SetRequestHandler(b_req_handler); b_rpc_service->SetRequestHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer"; LOG(INFO) << "before HeterServer::RunServer";
std::thread server_thread(std::bind(RunServer, b_rpc_service)); RunServer(b_rpc_service);
// std::thread server_thread(std::bind(RunServer, b_rpc_service));
server_thread.join(); // server_thread.join();
} }
TEST(SENDANDRECV, CPU) { TEST(SENDANDRECV, CPU) {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:4444"; std::string endpoint = get_ip_port();
std::string previous_endpoint = "127.0.0.1:4444"; std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer"; LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service = distributed::HeterServer::GetInstance(); b_rpc_service = distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint); std::thread server_thread(StartSendAndRecvServer, endpoint);
......
...@@ -18,6 +18,8 @@ limitations under the License. */ ...@@ -18,6 +18,8 @@ limitations under the License. */
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <random>
#include <sstream>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h" #include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/heter_server.h" #include "paddle/fluid/distributed/ps/service/heter_server.h"
...@@ -36,6 +38,19 @@ USE_OP(send_and_recv); ...@@ -36,6 +38,19 @@ USE_OP(send_and_recv);
std::shared_ptr<distributed::HeterServer> b_rpc_service; std::shared_ptr<distributed::HeterServer> b_rpc_service;
std::string get_ip_port() {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist(4444, 25000);
int port = dist(rng);
std::string ip_port;
std::stringstream temp_str;
temp_str << "127.0.0.1:";
temp_str << port;
temp_str >> ip_port;
return ip_port;
}
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) { framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0); auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block); auto* block = program->AppendBlock(*root_block);
...@@ -151,16 +166,18 @@ void StartSendAndRecvServer(std::string endpoint) { ...@@ -151,16 +166,18 @@ void StartSendAndRecvServer(std::string endpoint) {
b_rpc_service->SetRequestHandler(b_req_handler); b_rpc_service->SetRequestHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer"; LOG(INFO) << "before HeterServer::RunServer";
std::thread server_thread(std::bind(RunServer, b_rpc_service));
server_thread.join(); RunServer(b_rpc_service);
// std::thread server_thread(std::bind(RunServer, b_rpc_service));
// server_thread.join();
} }
TEST(SENDANDRECV, CPU) { TEST(SENDANDRECV, CPU) {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:4444"; std::string endpoint = get_ip_port();
std::string previous_endpoint = "127.0.0.1:4444"; std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer"; LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service = distributed::HeterServer::GetInstance(); b_rpc_service = distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint); std::thread server_thread(StartSendAndRecvServer, endpoint);
...@@ -260,8 +277,10 @@ TEST(SENDANDRECV, CPU) { ...@@ -260,8 +277,10 @@ TEST(SENDANDRECV, CPU) {
exe.RunPreparedContext(prepared.get(), scope, false); exe.RunPreparedContext(prepared.get(), scope, false);
LOG(INFO) << "client wait for Pop"; LOG(INFO) << "client wait for Pop";
auto task = (*task_queue_)[0]->Pop(); auto task = (*task_queue_)[0]->Pop();
LOG(INFO) << "client get from task queue"; LOG(INFO) << "client get from task queue";
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
task.first, "x", task.first, "x",
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
...@@ -19,6 +19,9 @@ limitations under the License. */ ...@@ -19,6 +19,9 @@ limitations under the License. */
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <random>
#include <sstream>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h" #include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/heter_server.h" #include "paddle/fluid/distributed/ps/service/heter_server.h"
...@@ -40,20 +43,30 @@ USE_OP(send_and_recv); ...@@ -40,20 +43,30 @@ USE_OP(send_and_recv);
std::shared_ptr<distributed::HeterServer> b_rpc_service2; std::shared_ptr<distributed::HeterServer> b_rpc_service2;
std::string get_ip_port() {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist(4444, 25000);
int port = dist(rng);
std::string ip_port;
std::stringstream temp_str;
temp_str << "127.0.0.1:";
temp_str << port;
temp_str >> ip_port;
return ip_port;
}
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) { framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0); auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block); auto* block = program->AppendBlock(*root_block);
framework::OpDesc* op = block->AppendOp(); framework::OpDesc* op = block->AppendOp();
op->SetType("scale"); op->SetType("scale");
op->SetInput("X", {"x"}); op->SetInput("X", {"x"});
op->SetOutput("Out", {"res"}); op->SetOutput("Out", {"res"});
op->SetAttr("scale", 0.5f); op->SetAttr("scale", 0.5f);
auto& out = *root_block->Var("res"); auto& out = *root_block->Var("res");
out.SetType(framework::proto::VarType::LOD_TENSOR); out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetShape({1, 10}); out.SetShape({1, 10});
return block; return block;
} }
...@@ -172,15 +185,17 @@ void StartSendAndRecvServer(std::string endpoint) { ...@@ -172,15 +185,17 @@ void StartSendAndRecvServer(std::string endpoint) {
b_rpc_service2->SetRequestHandler(b_req_handler); b_rpc_service2->SetRequestHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer"; LOG(INFO) << "before HeterServer::RunServer";
std::thread server_thread(std::bind(RunServer, b_rpc_service2));
server_thread.join(); RunServer(b_rpc_service2);
// std::thread server_thread(std::bind(RunServer, b_rpc_service2));
// server_thread.join();
} }
TEST(SENDANDRECV, GPU) { TEST(SENDANDRECV, GPU) {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:4445"; std::string endpoint = get_ip_port();
std::string previous_endpoint = "127.0.0.1:4445"; std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer"; LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service2 = distributed::HeterServer::GetInstance(); b_rpc_service2 = distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint); std::thread server_thread(StartSendAndRecvServer, endpoint);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册