未验证 提交 d41836ef 编写于 作者: Z zmxdream 提交者: GitHub

[HeterPS] fix ut for heterps comm op (#39684)

* fix. test=develop

* fix. test=develop

* fix code style. test=develop

* fix. test=develop

* fix. test=develop
上级 65ced1fa
......@@ -17,6 +17,9 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <random>
#include <sstream>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/heter_server.h"
......@@ -36,6 +39,19 @@ DECLARE_double(eager_delete_tensor_gb);
// Operator hooks for the ops referenced by this test file.
// NOTE(review): presumably these macros force the named op registrations to be
// linked into the test binary — confirm against the macro definitions.
USE_OP_ITSELF(scale);
USE_NO_KERNEL_OP(heter_listen_and_serv);
// Returns a loopback endpoint of the form "127.0.0.1:<port>", where <port> is
// drawn uniformly at random from the closed range [4444, 25000].
//
// A random port replaces the previously hard-coded endpoint so that separate
// runs of the test do not all bind to the same address.
// NOTE(review): the chosen port is not probed for availability, so a collision
// with a port that is already in use remains possible.
std::string get_ip_port() {
  // Seed from the system entropy source so every call/run gets a fresh port.
  std::mt19937 rng(std::random_device{}());
  // Draw the port directly as int; no unsigned -> int narrowing afterwards.
  std::uniform_int_distribution<int> dist(4444, 25000);
  // std::to_string formats the number without a locale-aware stream round-trip.
  return "127.0.0.1:" + std::to_string(dist(rng));
}
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
framework::BlockDesc* block =
program->AppendBlock(*(program->MutableBlock(0)));
......@@ -53,16 +69,13 @@ framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
return block;
}
void GetHeterListenAndServProgram(framework::ProgramDesc* program) {
void GetHeterListenAndServProgram(framework::ProgramDesc* program,
std::string endpoint) {
auto root_block = program->MutableBlock(0);
auto* sub_block = AppendSendAndRecvBlock(program);
std::vector<framework::BlockDesc*> optimize_blocks;
optimize_blocks.push_back(sub_block);
std::vector<std::string> message_to_block_id = {"x:1"};
std::string endpoint = "127.0.0.1:19944";
framework::OpDesc* op = root_block->AppendOp();
op->SetType("heter_listen_and_serv");
op->SetInput("X", {});
......@@ -129,7 +142,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
CreateVarsOnScope(scope, place);
}
void StartHeterServer() {
void StartHeterServer(std::string endpoint) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
......@@ -137,7 +150,7 @@ void StartHeterServer() {
platform::CPUDeviceContext ctx(place);
LOG(INFO) << "before GetHeterListenAndServProgram";
GetHeterListenAndServProgram(&program);
GetHeterListenAndServProgram(&program, endpoint);
auto prepared = exe.Prepare(program, 0);
LOG(INFO) << "before InitTensorsOnServer";
......@@ -150,13 +163,12 @@ void StartHeterServer() {
TEST(HETER_LISTEN_AND_SERV, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:19944";
std::string previous_endpoint = "127.0.0.1:19944";
std::string endpoint = get_ip_port();
std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer";
FLAGS_eager_delete_tensor_gb = -1;
std::thread server_thread(StartHeterServer);
std::thread server_thread(StartHeterServer, endpoint);
sleep(1);
auto b_rpc_service = distributed::HeterServer::GetInstance();
b_rpc_service->WaitServerReady();
using MicroScope =
......
......@@ -17,6 +17,9 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <random>
#include <sstream>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/heter_server.h"
......@@ -33,6 +36,19 @@ USE_OP_ITSELF(scale);
// File-level handle to the heter server used by this test. The test body
// assigns it from distributed::HeterServer::GetInstance() before starting the
// server thread, and StartSendAndRecvServer configures and runs it.
std::shared_ptr<distributed::HeterServer> b_rpc_service;
// Returns a loopback endpoint of the form "127.0.0.1:<port>", where <port> is
// drawn uniformly at random from the closed range [4444, 25000].
//
// A random port replaces the previously hard-coded endpoint so that separate
// runs of the test do not all bind to the same address.
// NOTE(review): the chosen port is not probed for availability, so a collision
// with a port that is already in use remains possible.
std::string get_ip_port() {
  // Seed from the system entropy source so every call/run gets a fresh port.
  std::mt19937 rng(std::random_device{}());
  // Draw the port directly as int; no unsigned -> int narrowing afterwards.
  std::uniform_int_distribution<int> dist(4444, 25000);
  // std::to_string formats the number without a locale-aware stream round-trip.
  return "127.0.0.1:" + std::to_string(dist(rng));
}
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
......@@ -178,16 +194,17 @@ void StartSendAndRecvServer(std::string endpoint) {
b_rpc_service->SetRequestHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
std::thread server_thread(std::bind(RunServer, b_rpc_service));
RunServer(b_rpc_service);
// std::thread server_thread(std::bind(RunServer, b_rpc_service));
server_thread.join();
// server_thread.join();
}
TEST(SENDANDRECV, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:4444";
std::string previous_endpoint = "127.0.0.1:4444";
std::string endpoint = get_ip_port();
std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service = distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint);
......
......@@ -18,6 +18,8 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <random>
#include <sstream>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/heter_server.h"
......@@ -36,6 +38,19 @@ USE_OP(send_and_recv);
// File-level handle to the heter server used by this test. The test body
// assigns it from distributed::HeterServer::GetInstance() before starting the
// server thread, and StartSendAndRecvServer configures and runs it.
std::shared_ptr<distributed::HeterServer> b_rpc_service;
// Returns a loopback endpoint of the form "127.0.0.1:<port>", where <port> is
// drawn uniformly at random from the closed range [4444, 25000].
//
// A random port replaces the previously hard-coded endpoint so that separate
// runs of the test do not all bind to the same address.
// NOTE(review): the chosen port is not probed for availability, so a collision
// with a port that is already in use remains possible.
std::string get_ip_port() {
  // Seed from the system entropy source so every call/run gets a fresh port.
  std::mt19937 rng(std::random_device{}());
  // Draw the port directly as int; no unsigned -> int narrowing afterwards.
  std::uniform_int_distribution<int> dist(4444, 25000);
  // std::to_string formats the number without a locale-aware stream round-trip.
  return "127.0.0.1:" + std::to_string(dist(rng));
}
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
......@@ -151,16 +166,18 @@ void StartSendAndRecvServer(std::string endpoint) {
b_rpc_service->SetRequestHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
std::thread server_thread(std::bind(RunServer, b_rpc_service));
server_thread.join();
RunServer(b_rpc_service);
// std::thread server_thread(std::bind(RunServer, b_rpc_service));
// server_thread.join();
}
TEST(SENDANDRECV, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:4444";
std::string previous_endpoint = "127.0.0.1:4444";
std::string endpoint = get_ip_port();
std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service = distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint);
......@@ -260,8 +277,10 @@ TEST(SENDANDRECV, CPU) {
exe.RunPreparedContext(prepared.get(), scope, false);
LOG(INFO) << "client wait for Pop";
auto task = (*task_queue_)[0]->Pop();
LOG(INFO) << "client get from task queue";
PADDLE_ENFORCE_EQ(
task.first, "x",
platform::errors::InvalidArgument(
......
......@@ -19,6 +19,9 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <random>
#include <sstream>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/heter_server.h"
......@@ -40,20 +43,30 @@ USE_OP(send_and_recv);
// File-level handle to the heter server used by this test. The test body
// assigns it from distributed::HeterServer::GetInstance() before starting the
// server thread, and StartSendAndRecvServer configures and runs it.
std::shared_ptr<distributed::HeterServer> b_rpc_service2;
// Returns a loopback endpoint of the form "127.0.0.1:<port>", where <port> is
// drawn uniformly at random from the closed range [4444, 25000].
//
// A random port replaces the previously hard-coded endpoint so that separate
// runs of the test do not all bind to the same address.
// NOTE(review): the chosen port is not probed for availability, so a collision
// with a port that is already in use remains possible.
std::string get_ip_port() {
  // Seed from the system entropy source so every call/run gets a fresh port.
  std::mt19937 rng(std::random_device{}());
  // Draw the port directly as int; no unsigned -> int narrowing afterwards.
  std::uniform_int_distribution<int> dist(4444, 25000);
  // std::to_string formats the number without a locale-aware stream round-trip.
  return "127.0.0.1:" + std::to_string(dist(rng));
}
// Appends a new block (parented to block 0) to `program`, containing a single
// "scale" op that reads "x", multiplies by 0.5 and writes "res". The "res"
// variable is declared in block 0 as a LOD_TENSOR of shape {1, 10}.
// Returns the newly appended block.
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
  auto* parent = program->MutableBlock(0);
  auto* sub_block = program->AppendBlock(*parent);

  // The one op in the new block: res = scale(x, 0.5).
  framework::OpDesc* scale_op = sub_block->AppendOp();
  scale_op->SetType("scale");
  scale_op->SetInput("X", {"x"});
  scale_op->SetOutput("Out", {"res"});
  scale_op->SetAttr("scale", 0.5f);

  // The output variable is declared on the parent block, not the new one.
  auto* res_var = parent->Var("res");
  res_var->SetType(framework::proto::VarType::LOD_TENSOR);
  res_var->SetShape({1, 10});

  return sub_block;
}
......@@ -172,15 +185,17 @@ void StartSendAndRecvServer(std::string endpoint) {
b_rpc_service2->SetRequestHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
std::thread server_thread(std::bind(RunServer, b_rpc_service2));
server_thread.join();
RunServer(b_rpc_service2);
// std::thread server_thread(std::bind(RunServer, b_rpc_service2));
// server_thread.join();
}
TEST(SENDANDRECV, GPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
std::string endpoint = "127.0.0.1:4445";
std::string previous_endpoint = "127.0.0.1:4445";
std::string endpoint = get_ip_port();
std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service2 = distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册