未验证 提交 2f41f389 编写于 作者: Z ziyoujiyi 提交者: GitHub

heter & multi-cloud brpc communication (#40965)

* back fl

* delete ssl cert

* .

* make warning

* .

* unittest paral degree

* solve unittest

* heter & multi cloud commm ready

* .

* .
上级 3a7761a0
......@@ -39,8 +39,8 @@ cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
cc_library(communicator SRCS communicator/communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS})
cc_library(ps_service SRCS ps_service/service.cc DEPS communicator client server boost ${RPC_DEPS})
cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_server SRCS heter_server.cc DEPS heter_client brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
set_source_files_properties(ps_service/graph_py_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_py_service SRCS ps_service/graph_py_service.cc DEPS ps_service)
......
......@@ -55,6 +55,8 @@ DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
DEFINE_int32(pserver_sparse_table_shard_num, 1000,
"sparse table shard for save & load");
DEFINE_int32(heter_world_size, 100, "group size"); // 可配置
namespace paddle {
namespace framework {
class Scope;
......
......@@ -13,18 +13,14 @@
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/split.h"
DECLARE_int32(rpc_deadline);
DECLARE_int32(pserver_timeout_ms);
namespace paddle {
namespace distributed {
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
bool HeterClient::is_initialized_ = false;
std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
int GetMicroId(const platform::DeviceContext& ctx,
const framework::Scope* scope) {
......@@ -54,58 +50,21 @@ int GetMicroId(const platform::DeviceContext& ctx,
return micro_id;
}
void HeterClient::MainThread() {
while (running_) {
RpcProfilerControl();
}
}
void HeterClient::Stop() {
running_ = false;
if (!is_initialized_) {
VLOG(3) << "HeterClient is not inited, do nothing";
} else {
if (main_thread_) {
auto status = StopHeterWorker();
status.wait();
main_thread_->join();
main_thread_.reset(nullptr);
}
VLOG(3) << "HeterClient Stop Done";
}
}
void HeterClient::FinalizeWorker() {
running_ = false;
if (!is_initialized_) {
VLOG(3) << "HeterClient is not inited, do nothing";
} else {
if (main_thread_) {
main_thread_->join();
main_thread_.reset(nullptr);
}
VLOG(3) << "HeterClient Stop Done";
}
}
std::future<int32_t> HeterClient::StopHeterWorker() {
return SendCmd(-1, PS_STOP_SERVER, {});
}
void HeterClient::RpcProfilerControl() {
if (trainer_id_ == 0) {
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
// send profiler start flag
do_server_profiler_ = true;
auto start_status = StartProfiler();
start_status.wait();
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
// send profiler end flag
auto stop_status = StopProfiler();
stop_status.wait();
do_server_profiler_ = false;
}
}
std::future<int32_t> HeterClient::StartProfiler() {
return SendCmd(-1, PS_START_PROFILER, {});
}
std::future<int32_t> HeterClient::StopProfiler() {
return SendCmd(-1, PS_STOP_PROFILER, {});
}
void HeterClient::CreateClient2XpuConnection() {
......@@ -156,27 +115,24 @@ void HeterClient::SendAndRecvAsync(
1);
const platform::DeviceContext* p_ctx = &ctx;
const framework::Scope* p_scope = &scope;
const std::string message_name_val = message_name;
const std::vector<std::string> send_var_name_val = send_var_name;
const std::vector<std::string> recv_var_name_val = recv_var_name;
VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: "
<< message_name_val;
VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: " << message_name;
brpc::Channel* channel = nullptr;
distributed::MultiVarMsg request;
OnHeterRpcDone* closure = new OnHeterRpcDone([p_ctx, p_scope](void* done) {
OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
PADDLE_ENFORCE_NE(
closure->cntl.Failed(), true,
platform::errors::Unimplemented(
"HeterClient::SendAndRecv meets brpc error, error message is %s",
closure->cntl.ErrorText()));
VLOG(4) << "call heter_worker success";
});
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
auto& request_io_buffer = closure->cntl.request_attachment();
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
message_name, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer);
int micro_id = GetMicroId(ctx, p_scope);
......@@ -188,6 +144,19 @@ void HeterClient::SendAndRecvAsync(
} else if (mode == "backward") {
int num = minibatch_id % previous_xpu_channels_.size();
channel = previous_xpu_channels_[num].get();
} else if (mode == "send_to_switch") {
VLOG(4) << "calling switch service";
// auto promise = std::make_shared<std::promise<int32_t>>();
// closure->add_promise(promise);
// std::future<int> fut = promise->get_future();
// int idx = 1; // for test
// LOG(INFO) << "xpu_channels_ size: " << xpu_channels_.size();
// channel = xpu_channels_[idx].get(); // 为了适配 send_and_recv op
// ::paddle::distributed::PsService_Stub stub(channel);
// stub.SendToSwitch(&closure->cntl, &request, &closure->response,
// closure); fut.wait();
VLOG(4) << "calling switch service done";
return;
}
::paddle::distributed::PsService_Stub stub(channel);
stub.SendAndRecvVariable(&closure->cntl, &request, &closure->response,
......@@ -229,13 +198,209 @@ std::future<int32_t> HeterClient::SendCmd(
return fut;
}
std::future<int32_t> HeterClient::StartProfiler() {
return SendCmd(-1, PS_START_PROFILER, {});
int HeterClient::Send(const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& message_name,
const std::vector<std::string>& send_var_names) {
const framework::Scope* p_scope = &scope; // 注意是 const
OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
int ret = 0;
closure->set_promise_value(ret);
if (closure->cntl.Failed()) {
PADDLE_ENFORCE_NE(
closure->cntl.Failed(), true,
platform::errors::Unimplemented(
"HeterClient::SendToSwitch meets brpc error, error message is %s",
closure->cntl.ErrorText()));
}
});
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
auto& request_io_buffer = closure->cntl.request_attachment();
distributed::MultiVarMsg request;
// 1. set req message_name(string)
request.set_message_name(message_name);
// 2. set req send_var_names(<string>)
for (auto& send_var_name : send_var_names) {
request.add_send_var_names(send_var_name);
}
// 3. set req var_messages(<VarMessage>)
for (auto& send_var_name : send_var_names) {
auto* send_var_msg = request.add_var_messages();
send_var_msg->set_varname(send_var_name);
framework::Variable* var = p_scope->FindVar(send_var_name);
butil::IOBuf temp_iobuf;
if (var->IsType<framework::LoDTensor>()) {
SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
} else if (var->IsType<phi::SelectedRows>()) {
SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
}
request_io_buffer.append(temp_iobuf);
}
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
if (send_switch_channels_.empty()) {
LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
if (xpu_channels_.empty()) {
LOG(ERROR) << "xpu_channels_ is null";
}
send_switch_channels_.push_back(xpu_channels_[0]);
}
brpc::Channel* channel = send_switch_channels_[0].get();
// brpc::Channel* channel = xpu_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel);
stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
VLOG(4) << "waiting SendToSwitch response result......";
fut.wait();
VLOG(4) << "Send done";
return 0;
}
std::future<int32_t> HeterClient::StopProfiler() {
return SendCmd(-1, PS_STOP_PROFILER, {});
int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
const std::vector<int>& vars_len, void* data_ptr,
int64_t data_size) {
OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
int ret = 0;
closure->set_promise_value(ret);
if (closure->cntl.Failed()) {
LOG(ERROR) << "Send meets brpc error, err msg is %s"
<< closure->cntl.ErrorText();
}
});
distributed::MultiVarMsg request;
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
std::string message_name = "send and save";
request.set_message_name(message_name);
request.set_group_id(group_id);
for (auto& send_var_name : var_names) {
request.add_send_var_names(send_var_name);
}
for (auto var_len : vars_len) {
request.add_vars_len(var_len);
}
auto& request_buffer = closure->cntl.request_attachment();
request_buffer.append(reinterpret_cast<void*>(data_ptr),
data_size * sizeof(float));
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
if (send_switch_channels_.empty()) {
LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
if (xpu_channels_.empty()) {
LOG(ERROR) << "xpu_channels_ is null";
}
send_switch_channels_.push_back(xpu_channels_[0]);
}
brpc::Channel* channel = send_switch_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel);
stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
fut.wait();
return 0;
}
int HeterClient::Recv(const platform::DeviceContext& ctx,
framework::Scope& recv_scope, // NOLINT
const std::string& message_name,
const std::vector<std::string>& recv_var_names) {
OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
VLOG(4) << "Recv service call done";
int ret = 0;
closure->set_promise_value(ret);
if (closure->cntl.Failed()) {
VLOG(4) << "HeterClient::RecvFromSwitch meets "
"brpc error, error message is %s"
<< closure->cntl.ErrorText();
}
});
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request;
// 1. set req message_name(string)
request.set_message_name(message_name);
// 2. set req recv_var_names(<string>)
for (auto& recv_var_name : recv_var_names) {
request.add_recv_var_names(recv_var_name);
}
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
if (recv_switch_channels_.empty()) {
LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
if (xpu_channels_.size() < 2) {
LOG(ERROR) << "xpu_channels_ is null";
}
recv_switch_channels_.push_back(xpu_channels_[1]);
}
brpc::Channel* channel = recv_switch_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel);
stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
fut.wait();
VLOG(4) << "RecvFromSwitch done";
// save in worker
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::CPUPlace cpu_place;
auto& cpu_dev_ctx = *pool.Get(cpu_place);
auto& res_io_buffer = closure->cntl.response_attachment();
VLOG(4) << "entering DeserializeFromMultiVarMsgAndIOBuf";
distributed::DeserializeFromMultiVarMsgAndIOBuf(
closure->response, &res_io_buffer, cpu_dev_ctx, &recv_scope);
VLOG(4) << "Recv done";
return 0;
}
} // end namespace distributed
int HeterClient::Recv(int group_id, const std::vector<std::string>& var_names,
void* data_ptr, int64_t data_size) {
OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
int ret = 0;
closure->set_promise_value(ret);
if (closure->cntl.Failed()) {
LOG(ERROR) << "Recv meets brpc error, err msg is %s"
<< closure->cntl.ErrorText();
}
});
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request;
std::string message_name = "query and recv";
request.set_message_name(message_name);
request.set_group_id(group_id);
for (auto& recv_var_name : var_names) {
request.add_recv_var_names(recv_var_name);
}
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
if (recv_switch_channels_.empty()) {
LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
if (xpu_channels_.size() < 2) {
LOG(ERROR) << "xpu_channels_ is null";
}
recv_switch_channels_.push_back(xpu_channels_[1]);
}
brpc::Channel* channel = recv_switch_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel);
stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
fut.wait();
VLOG(4) << "RecvFromSwitch done";
// save in worker
auto& res_io_buffer = closure->cntl.response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(data_ptr),
data_size * sizeof(float));
VLOG(4) << "Recv done";
return 0;
}
} // namespace distributed
} // end namespace paddle
......@@ -32,13 +32,14 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
DECLARE_int32(pserver_timeout_ms);
namespace paddle {
namespace distributed {
......@@ -51,24 +52,72 @@ class OnHeterRpcDone : public google::protobuf::Closure {
public:
explicit OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
virtual ~OnHeterRpcDone() {}
void Run() {
std::unique_ptr<OnHeterRpcDone> self_guard(this);
handler_(this);
void Run() { handler_(this); }
void add_promise(std::shared_ptr<std::promise<int32_t>>& promise) { // NOLINT
_promises.push_back(promise);
}
void set_promise_value(int value) {
for (auto& promise : _promises) {
promise->set_value(value);
}
}
int CheckResponse() { return 0; }
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
HeterRpcCallbackFunc handler_;
MultiVariableMessage request;
MultiVariableMessage response;
PsResponseMessage ps_response;
brpc::Controller cntl;
// PsRequestMessage *request(size_t i) { return &_requests[i]; }
// PsResponseMessage *response(size_t i) { return &_responses[i]; }
// std::vector<PsRequestMessage> _requests;
// std::vector<PsResponseMessage> _responses;
// std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
class HeterClient {
public:
virtual ~HeterClient() {}
HeterClient() {
running_ = true;
main_thread_.reset(
new std::thread(std::bind(&HeterClient::MainThread, this)));
void InitClientChannels(bool need_encrypt,
const std::vector<std::string>& node_list,
int32_t peer_role) {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connection_type = "single";
options.timeout_ms = FLAGS_pserver_timeout_ms;
std::vector<std::shared_ptr<brpc::Channel>>* client_channels = nullptr;
if (peer_role == PEER_ROLE_IS_SWITCH) {
options.ssl_options.enable = need_encrypt;
client_channels = &peer_switch_channels_;
} else if (peer_role == PEER_ROLE_IS_WORKER) {
client_channels = &peer_worker_channels_;
} else {
LOG(ERROR) << "init switch client failed, peer_role not valid";
}
(*client_channels).resize(node_list.size());
for (size_t i = 0; i < node_list.size(); ++i) {
(*client_channels)[i].reset(new brpc::Channel());
if ((*client_channels)[i]->Init(node_list[i].c_str(), "", &options) !=
0) {
VLOG(0) << "client channel init failed! try again";
auto ip_port = paddle::string::Split(node_list[i], ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if ((*client_channels)[i]->Init(int_ip_port.c_str(), "", &options) !=
0) {
LOG(ERROR) << "client channel init failed! peer ip_port = "
<< int_ip_port;
}
}
}
VLOG(4) << "InitClientChannels success";
}
void CreateClient2XpuConnection();
......@@ -80,14 +129,28 @@ class HeterClient {
const std::vector<std::string>& recv_var_name,
const std::string& mode = "forward");
int Send(int group_id, const std::vector<std::string>& var_names,
const std::vector<int>& vars_len, void* data_ptr, int64_t data_size);
int Send(const platform::DeviceContext& ctx, const framework::Scope& scope,
const std::string& message_name,
const std::vector<std::string>& send_var_names);
int Recv(int group_id, const std::vector<std::string>& var_names,
void* data_ptr, int64_t data_size);
int Recv(const platform::DeviceContext& ctx,
framework::Scope& recv_scope, // NOLINT
const std::string& message_name,
const std::vector<std::string>& recv_var_names);
// HeterClient singleton
static std::shared_ptr<HeterClient> GetInstance(
const std::vector<std::string>& endpoint,
const std::vector<std::string>& previous_endpoint,
const int& trainer_id) {
if (NULL == s_instance_) {
is_initialized_ = true;
s_instance_.reset(new paddle::distributed::HeterClient());
s_instance_.reset(new HeterClient());
s_instance_->SetXpuList(endpoint);
s_instance_->SetPreviousXpuList(previous_endpoint);
s_instance_->SetTrainerID(trainer_id);
......@@ -96,13 +159,29 @@ class HeterClient {
return s_instance_;
}
void Stop();
// switch client singleton
static HeterClient& GetSwitchInstance(
const std::vector<std::string>& peer_endpoints, int32_t peer_role) {
static HeterClient switch_s_instance_;
if (peer_endpoints.empty()) {
VLOG(4) << "init switch client failed, null peer_endpoints";
}
VLOG(4) << "peer role is: " << peer_role
<< ", addr is: " << peer_endpoints[0];
switch_s_instance_.SetPeerSwitchList(peer_endpoints);
switch_s_instance_.InitClientChannels(false, peer_endpoints, peer_role);
return switch_s_instance_;
}
void FinalizeWorker();
void SetPeerSwitchList(const std::vector<std::string>& peer_endpoints) {
peer_switch_list_ = peer_endpoints;
}
void MainThread();
void SetPeerWorkerList(const std::vector<std::string>& worker_endpoints) {
peer_worker_list_ = worker_endpoints;
}
void RpcProfilerControl();
void Stop();
std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string>& params);
......@@ -124,20 +203,32 @@ class HeterClient {
void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }
public:
std::vector<std::string> send_switch_list_;
std::vector<std::string> recv_switch_list_;
std::vector<std::string> peer_switch_list_;
std::vector<std::string> peer_worker_list_;
std::vector<std::shared_ptr<brpc::Channel>> send_switch_channels_;
std::vector<std::shared_ptr<brpc::Channel>> recv_switch_channels_;
std::vector<std::shared_ptr<brpc::Channel>> peer_switch_channels_;
std::vector<std::shared_ptr<brpc::Channel>> peer_worker_channels_;
private:
HeterClient() {}
HeterClient& operator=(const HeterClient&);
HeterClient(const HeterClient&);
static std::shared_ptr<HeterClient> s_instance_;
static bool is_initialized_;
std::unique_ptr<std::thread> main_thread_{nullptr};
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;
DISABLE_COPY_AND_ASSIGN(HeterClient);
// DISABLE_COPY_AND_ASSIGN(HeterClient);
std::vector<std::string> xpu_list_;
std::vector<std::string> previous_xpu_list_;
bool running_ = false;
int trainer_id_;
bool do_server_profiler_ = false;
};
} // end namespace distributed
......
......@@ -13,21 +13,28 @@
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_server.h"
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace distributed {
// DEFINE_string(cert_path, "./cert.pem", "cert.pem path");
// DEFINE_string(key_path, "./key.pem", "key.pem path");
std::shared_ptr<HeterServer> HeterServer::s_instance_ = NULL;
std::shared_ptr<HeterServer> HeterServer::s_instance_ = nullptr;
void HeterServer::RegisterServiceHandler(std::string message_name,
HeterServiceHandler func) {
service_.RegisterServiceHandler(message_name, func);
}
void HeterServer::StartHeterService() {
void HeterServer::StartHeterService(bool neeed_encrypt) {
server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
if (neeed_encrypt) {
options.ssl_options.default_cert.certificate = "/cert.pem";
options.ssl_options.default_cert.private_key = "/key.pem";
}
if (server_.Start(endpoint_.c_str(), &options) != 0) {
VLOG(0) << "HeterServer start fail. Try again.";
auto ip_port = paddle::string::Split(endpoint_, ':');
......@@ -47,16 +54,50 @@ void HeterServer::StartHeterService() {
ready_ = 1;
}
condition_ready_.notify_all();
VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_;
std::unique_lock<std::mutex> running_lock(mutex_);
cv_.wait(running_lock, [&] {
VLOG(1) << "Heter Server is Stop? " << stoped_;
VLOG(4) << "Heter Server is Stop? " << stoped_;
return stoped_;
});
VLOG(4) << "start service done";
}
void HeterServer::SetEndPoint(const std::string& endpoint) {
endpoint_ = endpoint;
service_.SetEndpoint(endpoint);
void HeterServer::StartHeterInterService(bool neeed_encrypt) {
server_inter_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
if (neeed_encrypt) {
options.ssl_options.default_cert.certificate = "/cert.pem";
options.ssl_options.default_cert.private_key = "/key.pem";
}
if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) {
VLOG(4) << "switch inter server start fail. Try again.";
auto ip_port = paddle::string::Split(endpoint_inter_, ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) {
LOG(ERROR) << "switch inter server start failed, ip_port= "
<< int_ip_port;
}
} else {
VLOG(4) << "switch inter server server start success! listen on "
<< endpoint_inter_;
}
{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
stoped_ = false;
ready_ = 1;
}
condition_ready_.notify_all();
VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_;
std::unique_lock<std::mutex> running_lock(mutex_);
cv_.wait(running_lock, [&] {
VLOG(4) << "Heter Server is Stop? " << stoped_;
return stoped_;
});
VLOG(4) << "start service done";
}
void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
......@@ -64,35 +105,180 @@ void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
void HeterServer::WaitServerReady() {
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
while (!this->ready_) {
sleep(1);
}
}
int32_t HeterService::stop_profiler(const PsRequestMessage& request,
PsResponseMessage& response,
int SendAndRecvVariableHandler::SaveInSwitchWithShard(
const MultiVarMsg* request, PsResponseMessage* response,
brpc::Controller* cntl) {
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("heter_worker_%s_profile", endpoint_));
VLOG(4) << "entering SaveInSwitchWithShard";
int32_t group_id = request->group_id();
auto& local_shard = _local_shards[group_id];
auto& request_io_buffer = cntl->request_attachment();
butil::IOBufBytesIterator io_buffer_itr(request_io_buffer);
for (int idx = 0; idx < request->send_var_names_size(); idx++) {
const auto& var_name = request->send_var_names(idx);
const auto& var_len = request->vars_len(idx);
auto itr = local_shard.find(var_name);
if (itr != local_shard.end()) {
LOG(INFO) << "var: " << var_name << "has not been consumed!"
<< "check again";
WaitForVarsConsumed(group_id, var_name);
}
auto& value = local_shard[var_name];
value.resize(var_len);
io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(value.data()),
var_len * sizeof(float));
VLOG(4) << "saved data in shards: ";
for (uint32_t i = 0; i < local_shard[var_name].size(); i++) {
VLOG(4) << *(local_shard[var_name].data() + i);
}
}
VLOG(4) << "SaveInSwitchWithShard success";
return 0;
}
int32_t HeterService::start_profiler(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl) {
platform::EnableProfiler(platform::ProfilerState::kAll);
int SendAndRecvVariableHandler::QueryInSwitchWithShard(
const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
VLOG(4) << "entering QueryInSwitchWithShard";
int32_t group_id = request->group_id();
VLOG(4) << "group id: " << group_id;
auto& local_shard = _local_shards[group_id];
auto& response_io_buffer = cntl->response_attachment();
auto req_var_nums = request->recv_var_names_size();
std::vector<std::string> req_var_names(req_var_nums);
for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
req_var_names[var_idx] = request->recv_var_names(var_idx);
}
auto msg_name = request->message_name();
response->set_message_name(msg_name);
for (auto& req_var_name : req_var_names) {
VLOG(4) << "req var name: " << req_var_name;
response->add_send_var_names(req_var_name);
auto itr = local_shard.find(req_var_name);
if (itr == local_shard.end()) {
LOG(INFO) << "var: " << req_var_name << " not found in shards";
WaitForVarsProduced(group_id, req_var_name);
}
LOG(INFO) << "var: " << req_var_name << " found in shards";
itr = local_shard.find(req_var_name);
auto& value = itr.value();
response_io_buffer.append(value.data(), value.size() * sizeof(float));
value.resize(0); // 标记位
}
VLOG(4) << "heter server QueryInSwitchWithShard done";
return 0;
}
int32_t HeterService::stop_heter_worker(const PsRequestMessage& request,
PsResponseMessage& response,
int SendAndRecvVariableHandler::SaveInSwitchWithScope(
const MultiVarMsg* request, PsResponseMessage* response,
brpc::Controller* cntl) {
auto client_id = request.client_id();
stop_cpu_worker_set_.insert(client_id);
if (stop_cpu_worker_set_.size() == fan_in_) {
is_exit_ = true;
VLOG(3) << "Stop heter Service done.";
VLOG(4) << "entering SaveInSwitchWithScope";
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::CPUPlace cpu_place;
auto& cpu_dev_ctx = *pool.Get(cpu_place);
auto message_name = request->message_name();
VLOG(4) << "message_name in heter server: " << message_name;
std::unique_lock<std::mutex> lk(scope_mutex_);
auto local_scope = local_scope_ptr.get();
if (!local_scope) {
LOG(ERROR) << "local_scope_ptr is null in SaveInSwitchWithScope";
}
for (int idx = 0; idx < request->send_var_names_size(); idx++) {
const auto& msg = request->var_messages(idx);
std::string var_name = msg.varname();
auto* var_exist_ptr = local_scope->FindVar(var_name);
if (!var_exist_ptr) {
VLOG(4) << "not find var: " << var_name << " in local_scope";
}
vars_table[var_name] += 1;
VLOG(4) << "saved var_name: " << var_name
<< ", cnt = " << vars_table[var_name];
}
auto& request_io_buffer = cntl->request_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(*request, &request_io_buffer,
cpu_dev_ctx, local_scope);
lk.unlock();
while (true) {
int ret = 0;
for (int idx = 0; idx < request->send_var_names_size(); idx++) {
ret |= vars_table[request->var_messages(idx).varname()];
}
if (!ret) {
VLOG(4) << "all saved vars consumed";
break;
}
VLOG(4) << "waiting consume result......";
sleep(1);
}
VLOG(4) << "SaveInSwitchWithScope success";
return 0;
}
int SendAndRecvVariableHandler::QueryInSwitchWithScope(
const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
VLOG(4) << "entering QueryInSwitchWithScope";
auto local_scope = local_scope_ptr.get();
if (!local_scope) {
LOG(INFO) << "local_scope is null";
}
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::CPUPlace cpu_place;
auto& cpu_dev_ctx = *pool.Get(cpu_place);
// get req message_name & req_var_names
auto msg_name = request->message_name();
auto req_var_nums = request->recv_var_names_size();
std::vector<std::string> req_var_names(req_var_nums);
for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
req_var_names[var_idx] = request->recv_var_names(var_idx);
}
auto& response_io_buffer = cntl->response_attachment();
// 1. fill message_name(string)
response->set_message_name(msg_name);
// 2. fill var_names(string)
for (auto& req_var_name : req_var_names) {
response->add_send_var_names(req_var_name);
}
// 3. fill var_messages(VarMessage)
for (auto& req_var_name : req_var_names) {
LOG(INFO) << "query var_name: " << req_var_name;
auto* send_var_msg = response->add_var_messages();
send_var_msg->set_varname(req_var_name);
framework::Variable* var_ptr;
while (true) {
var_ptr = local_scope->FindVar(req_var_name);
if (!var_ptr) {
LOG(INFO) << "local_scope not find var: " << req_var_name;
} else {
break;
}
sleep(1);
}
butil::IOBuf temp_iobuf;
if (var_ptr->IsType<framework::LoDTensor>()) {
SerializeLodTensor(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
} else if (var_ptr->IsType<phi::SelectedRows>()) {
SerializeSelectedRows(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
}
response_io_buffer.append(temp_iobuf);
}
for (auto& req_var_name : req_var_names) {
std::unique_lock<std::mutex> lk(scope_mutex_);
vars_table[req_var_name] -= 1;
VLOG(4) << "remained var: " << req_var_name
<< ", cnt = " << vars_table[req_var_name];
lk.unlock();
}
VLOG(4) << "heter server QueryInSwitchWithScope done";
return 0;
}
} // end namespace distributed
} // end namespace paddle
} // namespace paddle
......@@ -59,6 +59,12 @@ enum PsCmdID {
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG = 40;
PEER_ROLE_IS_WORKER = 41;
PEER_ROLE_IS_SWITCH = 42;
PS_SAVE_WITH_SCOPE = 43;
PS_SAVE_WITH_SHARD = 44;
PS_QUERY_WITH_SCOPE = 45;
PS_QUERY_WITH_SHARD = 46;
}
message PsRequestMessage {
......@@ -117,9 +123,16 @@ message MultiVariableMessage {
repeated string send_var_names = 2;
repeated string recv_var_names = 3;
repeated VariableMessage var_messages = 4;
optional bytes data = 5;
repeated int32 vars_len = 6;
optional int32 group_id = 7;
};
service PsService {
rpc service(PsRequestMessage) returns (PsResponseMessage);
rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
rpc SendToWorker(MultiVariableMessage) returns (PsResponseMessage);
rpc SendToSwitch(MultiVariableMessage) returns (PsResponseMessage);
rpc SendS2S(MultiVariableMessage) returns (PsResponseMessage);
rpc RecvFromSwitch(MultiVariableMessage) returns (MultiVariableMessage);
};
......@@ -300,7 +300,7 @@ if(WITH_DISTRIBUTE)
lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto fleet_executor ${BRPC_DEP})
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
"${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
......@@ -320,7 +320,7 @@ if(WITH_DISTRIBUTE)
index_sampler index_wrapper sampler index_dataset_proto
lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc fleet_executor)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
"${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
......
......@@ -6,9 +6,9 @@ include(operators)
set(DISTRIBUTE_DEPS "")
list(APPEND DISTRIBUTE_DEPS fleet ps_service brpc_utils heter_server heter_client ps_framework_proto framework_proto sendrecv_rpc brpc leveldb ssl crypto protobuf gflags glog zlib snappy device_context)
list(APPEND DISTRIBUTE_DEPS executor fleet ps_service brpc_utils heter_server heter_client ps_framework_proto framework_proto sendrecv_rpc brpc leveldb ssl crypto protobuf gflags glog zlib snappy device_context)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
......@@ -37,3 +37,6 @@ cc_test(send_and_recv_gpu_test SRCS send_and_recv_op_gpu_test.cc DEPS executor s
set_source_files_properties(heter_listen_and_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(heter_listen_and_server_test SRCS heter_listen_and_server_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
#set_source_files_properties(heter_cloud_comm_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined PADDLE_WITH_PSCORE
#include <stdlib.h>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/heter_server.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::distributed;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
void CreateVarsOnScope(framework::Scope* scope) {
auto var1 = scope->Var("w");
var1->GetMutable<phi::SelectedRows>();
auto var2 = scope->Var("x");
var2->GetMutable<framework::LoDTensor>();
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope);
auto w = scope->Var("w")->GetMutable<phi::SelectedRows>();
auto w_value = w->mutable_value();
w_value->Resize({rows_numel, 10});
for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true);
auto ptr = w_value->mutable_data<float>(*place);
for (int64_t i = 0; i < w_value->numel(); ++i) {
ptr[i] = static_cast<float>(i / 10);
}
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) {
x_ptr[i] = 1.0;
}
}
void StartSwitchServer(
std::shared_ptr<distributed::HeterServer>& switch_server_ptr, // NOLINT
std::vector<std::string> endpoints,
std::vector<std::string> peer_endpoints) {
switch_server_ptr->SetPeerEndPoints(peer_endpoints);
switch_server_ptr->SetEndPoint(endpoints[0]);
/*
std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
switch_server_ptr->SetServiceHandler(b_req_handler);
switch_server_ptr->SetLocalScope();
switch_server_ptr->RegisterServiceHandler(
std::to_string(distributed::PS_SAVE_WITH_SCOPE),
[&](const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) -> int {
return b_req_handler->SaveInSwitchWithScope(request, response, cntl);
});
switch_server_ptr->RegisterServiceHandler(std::to_string(distributed::PS_SAVE_WITH_SHARD),
[&](const MultiVarMsg* request, MultiVarMsg*
response,
brpc::Controller* cntl) -> int {
return b_req_handler->SaveInSwitchWithShard(
request, response, cntl);
});
switch_server_ptr->RegisterServiceHandler(std::to_string(distributed::PS_QUERY_WITH_SCOPE),
[&](const MultiVarMsg* request, MultiVarMsg*
response,
brpc::Controller* cntl) -> int {
return b_req_handler->QueryInSwitchWithScope(
request, response, cntl);
});
switch_server_ptr->RegisterServiceHandler(std::to_string(distributed::PS_QUERY_WITH_SHARD),
[&](const MultiVarMsg* request, MultiVarMsg*
response,
brpc::Controller* cntl) -> int {
return b_req_handler->QueryInSwitchWithShard(
request, response, cntl);
});
*/
switch_server_ptr->StartHeterService(false);
}
void StartSwitchInterServer(
std::shared_ptr<distributed::HeterServer>& switch_server_ptr, // NOLINT
std::vector<std::string> endpoints,
std::vector<std::string> peer_endpoints) {
switch_server_ptr->SetPeerEndPoints(peer_endpoints);
switch_server_ptr->SetInterEndpoint(endpoints[1]);
switch_server_ptr->StartHeterInterService(false);
}
TEST(HETERSENDANDRECV, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
// 启动 switch server A & B
std::string switch_a_endpoint("127.0.0.1:6000");
std::string switch_a_endpoint_inter("127.0.0.1:6100");
std::string switch_b_endpoint_inter("127.0.0.1:7100");
std::string switch_b_endpoint("127.0.0.1:7000");
std::shared_ptr<distributed::HeterServer> switch_server_ptr_a =
std::make_shared<distributed::HeterServer>();
std::vector<std::string> end_points{switch_a_endpoint};
std::vector<std::string> peer_endpoints{switch_b_endpoint_inter};
std::thread switch_server_a_thread(StartSwitchServer,
std::ref(switch_server_ptr_a), end_points,
peer_endpoints);
switch_server_ptr_a->WaitServerReady();
std::shared_ptr<distributed::HeterServer> switch_server_ptr_b =
std::make_shared<distributed::HeterServer>();
end_points = {switch_b_endpoint, switch_b_endpoint_inter};
peer_endpoints = {};
std::thread switch_server_b_thread(StartSwitchServer,
std::ref(switch_server_ptr_b), end_points,
peer_endpoints);
switch_server_ptr_b->WaitServerReady();
end_points = {switch_b_endpoint, switch_b_endpoint_inter};
peer_endpoints = {};
std::thread switch_server_b_thread_inter(StartSwitchInterServer,
std::ref(switch_server_ptr_b),
end_points, peer_endpoints);
switch_server_ptr_b->WaitServerReady();
// 获取 client 实例
std::shared_ptr<distributed::HeterClient> heter_client_ptr_ =
distributed::HeterClient::GetInstance(
{switch_a_endpoint, switch_b_endpoint}, {}, 0);
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
framework::Executor exe(place);
framework::ProgramDesc program;
exe.Prepare(program, 0); // solve undefined symbol: tensor_table.cc
std::shared_ptr<framework::Scope> send_scope_ptr =
std::make_shared<framework::Scope>();
int64_t rows_numel = 10;
InitTensorsOnClient(send_scope_ptr.get(), &place, rows_numel);
LOG(INFO) << "InitTensorsOnClient done";
auto send_async = [&]() -> void {
/*
//std::string message_name =
std::to_string(distributed::PS_SAVE_WITH_SCOPE);
std::string message_name = "send and save";
std::vector<std::string> send_var_names{"w", "x"};
int ret = heter_client_ptr_->Send(ctx, *send_scope_ptr, message_name,
send_var_names);
if (!ret) {
LOG(ERROR) << ">>>> worker send success";
}
*/
///*
std::vector<int> vars_len{2, 4};
std::vector<float> values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
int64_t data_size = 6;
std::vector<std::string> send_var_names{"w", "x"};
int group_id = 0;
int ret = heter_client_ptr_->Send(group_id, send_var_names, vars_len,
values.data(), data_size);
if (!ret) {
LOG(INFO) << ">>>> worker send success";
}
//*/
};
std::thread send_thread(send_async);
/*
std::string message_name = std::to_string(distributed::PS_QUERY_WITH_SCOPE);
std::vector<std::string> recv_var_names{"w", "x"};
std::shared_ptr<framework::Scope> recv_scope_ptr =
std::make_shared<framework::Scope>();
int ret = heter_client_ptr_->Recv(ctx, *recv_scope_ptr, message_name,
recv_var_names);
if (!ret && recv_scope_ptr->FindVar("w") && recv_scope_ptr->FindVar("x")) {
LOG(INFO) << ">>>> worker recv success";
} else {
LOG(INFO) << "worker recv failed";
}
*/
///*
int group_id = 0;
std::vector<std::string> recv_var_names{"w", "x"};
std::vector<float> values;
int data_size = 6;
values.resize(data_size);
int ret = heter_client_ptr_->Recv(group_id, recv_var_names, values.data(),
data_size);
if (!ret) {
VLOG(4) << "queried data is: ";
for (auto f : values) {
VLOG(4) << f << " ";
}
LOG(INFO) << ">>>> worker recv success";
}
//*/
send_thread.join();
switch_server_ptr_a->Stop();
LOG(INFO) << "switch server A stopped";
switch_server_ptr_b->Stop();
LOG(INFO) << "switch server B stopped";
switch_server_a_thread.join();
LOG(INFO) << "switch_server_a_thread joined";
switch_server_b_thread.join();
LOG(INFO) << "switch_server_b_thread joined";
switch_server_b_thread_inter.join();
LOG(INFO) << "switch_server_b_thread_inter joined";
}
#endif
......@@ -88,21 +88,20 @@ void HeterListenAndServOp::RunAsyncLoop(framework::ProgramDesc *program) const {
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
}
for (size_t i = 0; i < block_list.size(); ++i) {
auto blkid = block_list[i];
auto it = message_to_block_id.find_value(blkid);
rpc_service_->RegisterServiceHandler(
heter_server_->RegisterServiceHandler(
it->first, [&](const MultiVarMsg *request, MultiVarMsg *response,
brpc::Controller *cntl) -> int {
return request_send_and_recv_handler_->Handle(request, response,
return send_and_recv_variable_handler_->Handle(request, response,
cntl);
});
}
while (true) {
if (rpc_service_->IsExit() || rpc_service_->IsStop()) {
rpc_service_->Stop();
if (heter_server_->IsExit() || heter_server_->IsStop()) {
heter_server_->Stop();
VLOG(0) << "get exit. rpc_processor stop!";
break;
}
......@@ -110,8 +109,9 @@ void HeterListenAndServOp::RunAsyncLoop(framework::ProgramDesc *program) const {
} // while(true)
}
void RunServer(std::shared_ptr<paddle::distributed::HeterServer> service) {
service->StartHeterService();
void RunServer(
std::shared_ptr<paddle::distributed::HeterServer> heter_server_ptr) {
heter_server_ptr->StartHeterService();
}
void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
......@@ -126,16 +126,16 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
auto fan_in = Attr<int>("fanin");
auto inputs = Inputs("X");
PADDLE_ENFORCE_EQ(rpc_service_, nullptr,
PADDLE_ENFORCE_EQ(heter_server_, nullptr,
platform::errors::PreconditionNotMet(
"RPC service has been created unexpectedly."));
std::string endpoint = Attr<std::string>("endpoint");
VLOG(4) << "pserver_id: " << pserver_id << ", end_point:" << endpoint;
rpc_service_ = distributed::HeterServer::GetInstance();
rpc_service_->SetEndPoint(endpoint);
rpc_service_->SetFanin(fan_in);
heter_server_ = distributed::HeterServer::GetInstance();
heter_server_->SetEndPoint(endpoint);
heter_server_->SetFanin(fan_in);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>("optimize_blocks");
......@@ -146,20 +146,18 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
auto *program = optimize_blocks[0]->Program();
request_send_and_recv_handler_.reset(
new distributed::RequestSendAndRecvHandler());
request_send_and_recv_handler_->SetScope(&scope);
request_send_and_recv_handler_->SetDevCtx(&dev_ctx);
rpc_service_->SetRequestHandler(request_send_and_recv_handler_);
send_and_recv_variable_handler_.reset(
new distributed::SendAndRecvVariableHandler());
send_and_recv_variable_handler_->SetScope(&scope);
send_and_recv_variable_handler_->SetDevCtx(&dev_ctx);
heter_server_->SetServiceHandler(send_and_recv_variable_handler_);
VLOG(2) << "RunAsyncLoop";
auto message_to_block_id_str =
Attr<std::vector<std::string>>("message_to_block_id");
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
server_thread_.reset(new std::thread(RunServer, heter_server_));
VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
heter_server_->WaitServerReady();
RunAsyncLoop(program);
VLOG(3) << "Wait for Server_thread_ stop";
(server_thread_.get())->join();
......
......@@ -34,7 +34,7 @@ limitations under the License. */
namespace paddle {
namespace distributed {
class HeterRequestHandler;
class ServiceHandlerBase;
class HeterServer;
} // namespace distributed
} // namespace paddle
......@@ -82,10 +82,10 @@ class HeterListenAndServOp : public framework::OperatorBase {
const platform::Place& dev_place) const override;
protected:
mutable std::shared_ptr<paddle::distributed::HeterServer> rpc_service_;
mutable std::shared_ptr<paddle::distributed::HeterServer> heter_server_;
mutable std::shared_ptr<std::thread> server_thread_;
mutable std::shared_ptr<paddle::distributed::RequestSendAndRecvHandler>
request_send_and_recv_handler_;
mutable std::shared_ptr<paddle::distributed::SendAndRecvVariableHandler>
send_and_recv_variable_handler_;
};
} // namespace operators
......
......@@ -142,7 +142,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
CreateVarsOnScope(scope, place);
}
void StartHeterServer(std::string endpoint) {
void RunHeterServerOp(std::string endpoint) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
......@@ -167,10 +167,10 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer";
FLAGS_eager_delete_tensor_gb = -1;
std::thread server_thread(StartHeterServer, endpoint);
std::thread server_thread(RunHeterServerOp, endpoint);
sleep(1);
auto b_rpc_service = distributed::HeterServer::GetInstance();
b_rpc_service->WaitServerReady();
auto heter_server_ptr_ = distributed::HeterServer::GetInstance();
heter_server_ptr_->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
using MiniScope = std::unordered_map<int, framework::Scope*>;
......@@ -185,8 +185,8 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
(*micro_scope).push_back(micro_scope_0);
(*micro_scope).push_back(micro_scope_1);
(*micro_scopes)[0] = micro_scope;
b_rpc_service->SetMicroBatchScopes(micro_scopes);
b_rpc_service->SetMiniBatchScopes(mini_scopes);
heter_server_ptr_->SetMicroBatchScopes(micro_scopes);
heter_server_ptr_->SetMiniBatchScopes(mini_scopes);
using TaskQueue =
std::unordered_map<int,
......@@ -198,17 +198,13 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
SharedTaskQueue task_queue_(new TaskQueue{});
(*task_queue_)[0] = std::make_shared<
::paddle::framework::BlockingQueue<std::pair<std::string, int>>>();
b_rpc_service->SetTaskQueue(task_queue_);
heter_server_ptr_->SetTaskQueue(task_queue_);
LOG(INFO) << "before HeterClient::GetInstance";
distributed::HeterClient* rpc_client =
distributed::HeterClient* heter_client_ptr_ =
distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
.get();
PADDLE_ENFORCE_NE(rpc_client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
framework::Scope* scope = (*micro_scope)[0];
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
......@@ -224,8 +220,8 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
std::vector<std::string> recv_var = {};
LOG(INFO) << "before SendAndRecvAsync";
rpc_client->SendAndRecvAsync(ctx, *scope, in_var_name, send_var, recv_var,
"forward");
heter_client_ptr_->SendAndRecvAsync(ctx, *scope, in_var_name, send_var,
recv_var, "forward");
auto task = (*task_queue_)[0]->Pop();
PADDLE_ENFORCE_EQ(
task.first, "x",
......@@ -234,15 +230,15 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
InitTensorsOnClient2((*micro_scope)[1], &place, rows_numel);
LOG(INFO) << "before SendAndRecvAsync 2";
rpc_client->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name, send_var,
recv_var, "backward");
heter_client_ptr_->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name,
send_var, recv_var, "backward");
auto task2 = (*task_queue_)[0]->Pop();
PADDLE_ENFORCE_EQ(
task2.first, "x",
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->Stop();
heter_client_ptr_->Stop();
LOG(INFO) << "end server Stop";
server_thread.join();
LOG(INFO) << "end server thread join";
......
......@@ -34,8 +34,6 @@ using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP_ITSELF(scale);
std::shared_ptr<distributed::HeterServer> b_rpc_service;
std::string get_ip_port() {
std::mt19937 rng;
rng.seed(std::random_device()());
......@@ -171,31 +169,32 @@ void StartSendAndRecvServer(std::string endpoint) {
InitTensorsOnServer(&scope, &place, 10);
LOG(INFO) << "end InitTensorsOnServer";
std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
LOG(INFO) << "before SetDevCtx";
b_req_handler->SetDevCtx(&ctx);
LOG(INFO) << "before SetScope";
b_req_handler->SetScope(&scope);
LOG(INFO) << "before HeterServer::GetInstance";
b_rpc_service = distributed::HeterServer::GetInstance();
b_rpc_service->SetEndPoint(endpoint);
std::shared_ptr<distributed::HeterServer> heter_server_ptr_ =
distributed::HeterServer::GetInstance();
heter_server_ptr_->SetEndPoint(endpoint);
LOG(INFO) << "before HeterServer::RegisterServiceHandler";
b_rpc_service->RegisterServiceHandler(
heter_server_ptr_->RegisterServiceHandler(
in_var_name, [&](const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) -> int {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service->RegisterServiceHandler(
heter_server_ptr_->RegisterServiceHandler(
in_var_name2, [&](const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) -> int {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service->SetRequestHandler(b_req_handler);
heter_server_ptr_->SetServiceHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
RunServer(b_rpc_service);
// std::thread server_thread(std::bind(RunServer, b_rpc_service));
RunServer(heter_server_ptr_);
// std::thread server_thread(std::bind(RunServer, heter_server_ptr_));
// server_thread.join();
}
......@@ -206,9 +205,10 @@ TEST(SENDANDRECV, CPU) {
std::string endpoint = get_ip_port();
std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service = distributed::HeterServer::GetInstance();
std::shared_ptr<distributed::HeterServer> heter_server_ptr_ =
distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint);
b_rpc_service->WaitServerReady();
heter_server_ptr_->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
using MiniScope = std::unordered_map<int, framework::Scope*>;
......@@ -223,8 +223,8 @@ TEST(SENDANDRECV, CPU) {
(*micro_scope).push_back(micro_scope_0);
(*micro_scope).push_back(micro_scope_1);
(*micro_scopes)[0] = micro_scope;
b_rpc_service->SetMicroBatchScopes(micro_scopes);
b_rpc_service->SetMiniBatchScopes(mini_scopes);
heter_server_ptr_->SetMicroBatchScopes(micro_scopes);
heter_server_ptr_->SetMiniBatchScopes(mini_scopes);
using TaskQueue =
std::unordered_map<int,
......@@ -236,17 +236,13 @@ TEST(SENDANDRECV, CPU) {
SharedTaskQueue task_queue_(new TaskQueue{});
(*task_queue_)[0] = std::make_shared<
::paddle::framework::BlockingQueue<std::pair<std::string, int>>>();
b_rpc_service->SetTaskQueue(task_queue_);
heter_server_ptr_->SetTaskQueue(task_queue_);
LOG(INFO) << "before HeterClient::GetInstance";
distributed::HeterClient* rpc_client =
distributed::HeterClient* heter_client_ptr_ =
distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
.get();
PADDLE_ENFORCE_NE(rpc_client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
framework::Scope* scope = (*micro_scope)[0];
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
......@@ -262,8 +258,8 @@ TEST(SENDANDRECV, CPU) {
std::vector<std::string> recv_var = {};
LOG(INFO) << "before SendAndRecvAsync";
rpc_client->SendAndRecvAsync(ctx, *scope, in_var_name, send_var, recv_var,
"forward");
heter_client_ptr_->SendAndRecvAsync(ctx, *scope, in_var_name, send_var,
recv_var, "forward");
LOG(INFO) << "client wait for Pop";
auto task = (*task_queue_)[0]->Pop();
......@@ -276,7 +272,7 @@ TEST(SENDANDRECV, CPU) {
InitTensorsOnClient2((*micro_scope)[1], &place, rows_numel);
LOG(INFO) << "before SendAndRecvAsync 2";
std::string in_var_name2("y");
rpc_client->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name2,
heter_client_ptr_->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name2,
send_var, recv_var, "backward");
LOG(INFO) << "after SendAndRecvAsync 2";
......@@ -286,8 +282,7 @@ TEST(SENDANDRECV, CPU) {
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->FinalizeWorker();
b_rpc_service->Stop();
heter_server_ptr_->Stop();
LOG(INFO) << "end server Stop";
server_thread.join();
LOG(INFO) << "end server thread join";
......
......@@ -36,8 +36,6 @@ using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP_ITSELF(scale);
USE_OP(send_and_recv);
std::shared_ptr<distributed::HeterServer> b_rpc_service;
std::string get_ip_port() {
std::mt19937 rng;
rng.seed(std::random_device()());
......@@ -148,14 +146,15 @@ void StartSendAndRecvServer(std::string endpoint) {
InitTensorsOnServer(&scope, &place, 10);
LOG(INFO) << "end InitTensorsOnServer";
std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
LOG(INFO) << "before SetDevCtx";
b_req_handler->SetDevCtx(&ctx);
LOG(INFO) << "before SetScope";
b_req_handler->SetScope(&scope);
LOG(INFO) << "before HeterServer::GetInstance";
b_rpc_service = distributed::HeterServer::GetInstance();
std::shared_ptr<distributed::HeterServer> b_rpc_service =
distributed::HeterServer::GetInstance();
b_rpc_service->SetEndPoint(endpoint);
LOG(INFO) << "before HeterServer::RegisterServiceHandler";
b_rpc_service->RegisterServiceHandler(
......@@ -164,7 +163,7 @@ void StartSendAndRecvServer(std::string endpoint) {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service->SetRequestHandler(b_req_handler);
b_rpc_service->SetServiceHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
RunServer(b_rpc_service);
......@@ -179,7 +178,8 @@ TEST(SENDANDRECV, CPU) {
std::string endpoint = get_ip_port();
std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service = distributed::HeterServer::GetInstance();
std::shared_ptr<distributed::HeterServer> b_rpc_service =
distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint);
b_rpc_service->WaitServerReady();
using MicroScope =
......@@ -292,7 +292,6 @@ TEST(SENDANDRECV, CPU) {
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->FinalizeWorker();
b_rpc_service->Stop();
LOG(INFO) << "end server Stop";
server_thread.join();
......
......@@ -167,8 +167,8 @@ void StartSendAndRecvServer(std::string endpoint) {
InitTensorsOnServer(&scope, &place, 10);
LOG(INFO) << "end InitTensorsOnServer";
std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
LOG(INFO) << "before SetDevCtx";
b_req_handler->SetDevCtx(&ctx);
LOG(INFO) << "before SetScope";
......@@ -183,7 +183,7 @@ void StartSendAndRecvServer(std::string endpoint) {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service2->SetRequestHandler(b_req_handler);
b_rpc_service2->SetServiceHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
RunServer(b_rpc_service2);
......@@ -228,13 +228,11 @@ TEST(SENDANDRECV, GPU) {
b_rpc_service2->SetTaskQueue(task_queue_);
LOG(INFO) << "before HeterClient::GetInstance";
distributed::HeterClient* rpc_client =
distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
.get();
PADDLE_ENFORCE_NE(rpc_client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
std::shared_ptr<distributed::HeterClient> heter_client_ptr_ =
distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0);
if (heter_client_ptr_ == nullptr) {
LOG(ERROR) << "heter_client_ptr_ is null";
}
framework::Scope* scope = (*micro_scope)[0];
platform::CUDAPlace place;
......@@ -316,7 +314,6 @@ TEST(SENDANDRECV, GPU) {
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->FinalizeWorker();
b_rpc_service2->Stop();
LOG(INFO) << "end server Stop";
server_thread.join();
......
......@@ -1174,6 +1174,7 @@ SIXTH_PARALLEL_JOB_NEW = [
]
LOWEST_PARALLEL_JOB_NEW = [
'heter_cloud_comm_cpu_test',
'heter_server_test',
'test_scatter_op',
'test_trt_convert_hard_sigmoid',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册