未验证 提交 2f41f389 编写于 作者: Z ziyoujiyi 提交者: GitHub

heter & multi-cloud brpc communication (#40965)

* back fl

* delete ssl cert

* .

* make warning

* .

* unittest paral degree

* solve unittest

* heter & multi cloud comm ready

* .

* .
上级 3a7761a0
......@@ -39,8 +39,8 @@ cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
cc_library(communicator SRCS communicator/communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS})
cc_library(ps_service SRCS ps_service/service.cc DEPS communicator client server boost ${RPC_DEPS})
cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_server SRCS heter_server.cc DEPS heter_client brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
set_source_files_properties(ps_service/graph_py_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_py_service SRCS ps_service/graph_py_service.cc DEPS ps_service)
......
......@@ -55,6 +55,8 @@ DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
DEFINE_int32(pserver_sparse_table_shard_num, 1000,
"sparse table shard for save & load");
DEFINE_int32(heter_world_size, 100, "group size"); // 可配置
namespace paddle {
namespace framework {
class Scope;
......
......@@ -13,18 +13,14 @@
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/split.h"
DECLARE_int32(rpc_deadline);
DECLARE_int32(pserver_timeout_ms);
namespace paddle {
namespace distributed {
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
bool HeterClient::is_initialized_ = false;
std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
int GetMicroId(const platform::DeviceContext& ctx,
const framework::Scope* scope) {
......@@ -54,58 +50,21 @@ int GetMicroId(const platform::DeviceContext& ctx,
return micro_id;
}
// Client daemon loop: repeatedly polls the profiler on/off state until
// Stop()/FinalizeWorker() clears running_.
// NOTE(review): this loop has no sleep, so it busy-spins a full core for the
// lifetime of the client — confirm that is intentional.
void HeterClient::MainThread() {
while (running_) {
RpcProfilerControl();
}
}
// Shuts the client down: stops the daemon loop, asks the remote heter
// workers to stop, and joins the background thread. No-op (besides clearing
// running_) when the client was never initialized.
void HeterClient::Stop() {
  running_ = false;
  if (!is_initialized_) {
    VLOG(3) << "HeterClient is not inited, do nothing";
    return;
  }
  if (main_thread_) {
    // Tell the remote side to stop before joining our own thread.
    auto status = StopHeterWorker();
    status.wait();
    main_thread_->join();
    main_thread_.reset(nullptr);
  }
  VLOG(3) << "HeterClient Stop Done";
}
// Finalizes the worker: stops the daemon loop, joins the background thread
// if the client was initialized, then (always) signals the remote heter
// workers to stop and waits for the acknowledgement.
void HeterClient::FinalizeWorker() {
  running_ = false;
  if (is_initialized_) {
    if (main_thread_) {
      main_thread_->join();
      main_thread_.reset(nullptr);
    }
    VLOG(3) << "HeterClient Stop Done";
  } else {
    VLOG(3) << "HeterClient is not inited, do nothing";
  }
  // Unlike Stop(), the stop command is sent even when uninitialized.
  auto status = StopHeterWorker();
  status.wait();
}
// Broadcasts a PS_STOP_SERVER command (table id -1) to the connected heter
// workers; returns a future the caller can wait on for completion.
std::future<int32_t> HeterClient::StopHeterWorker() {
return SendCmd(-1, PS_STOP_SERVER, {});
}
void HeterClient::RpcProfilerControl() {
if (trainer_id_ == 0) {
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
// send profiler start flag
do_server_profiler_ = true;
auto start_status = StartProfiler();
start_status.wait();
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
// send profiler end flag
auto stop_status = StopProfiler();
stop_status.wait();
do_server_profiler_ = false;
}
}
// Sends a PS_START_PROFILER command (table id -1) to the remote side;
// returns a future that completes when the command has been acknowledged.
std::future<int32_t> HeterClient::StartProfiler() {
return SendCmd(-1, PS_START_PROFILER, {});
}
// Sends a PS_STOP_PROFILER command (table id -1) to the remote side;
// returns a future that completes when the command has been acknowledged.
std::future<int32_t> HeterClient::StopProfiler() {
return SendCmd(-1, PS_STOP_PROFILER, {});
}
void HeterClient::CreateClient2XpuConnection() {
......@@ -156,27 +115,24 @@ void HeterClient::SendAndRecvAsync(
1);
const platform::DeviceContext* p_ctx = &ctx;
const framework::Scope* p_scope = &scope;
const std::string message_name_val = message_name;
const std::vector<std::string> send_var_name_val = send_var_name;
const std::vector<std::string> recv_var_name_val = recv_var_name;
VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: "
<< message_name_val;
VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: " << message_name;
brpc::Channel* channel = nullptr;
distributed::MultiVarMsg request;
OnHeterRpcDone* closure = new OnHeterRpcDone([p_ctx, p_scope](void* done) {
OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
PADDLE_ENFORCE_NE(
closure->cntl.Failed(), true,
platform::errors::Unimplemented(
"HeterClient::SendAndRecv meets brpc error, error message is %s",
closure->cntl.ErrorText()));
VLOG(4) << "call heter_worker success";
});
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
auto& request_io_buffer = closure->cntl.request_attachment();
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
message_name, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer);
int micro_id = GetMicroId(ctx, p_scope);
......@@ -188,6 +144,19 @@ void HeterClient::SendAndRecvAsync(
} else if (mode == "backward") {
int num = minibatch_id % previous_xpu_channels_.size();
channel = previous_xpu_channels_[num].get();
} else if (mode == "send_to_switch") {
VLOG(4) << "calling switch service";
// auto promise = std::make_shared<std::promise<int32_t>>();
// closure->add_promise(promise);
// std::future<int> fut = promise->get_future();
// int idx = 1; // for test
// LOG(INFO) << "xpu_channels_ size: " << xpu_channels_.size();
// channel = xpu_channels_[idx].get(); // 为了适配 send_and_recv op
// ::paddle::distributed::PsService_Stub stub(channel);
// stub.SendToSwitch(&closure->cntl, &request, &closure->response,
// closure); fut.wait();
VLOG(4) << "calling switch service done";
return;
}
::paddle::distributed::PsService_Stub stub(channel);
stub.SendAndRecvVariable(&closure->cntl, &request, &closure->response,
......@@ -229,13 +198,209 @@ std::future<int32_t> HeterClient::SendCmd(
return fut;
}
std::future<int32_t> HeterClient::StartProfiler() {
return SendCmd(-1, PS_START_PROFILER, {});
// Serializes the named variables out of `scope` and pushes them to the
// switch service (SendToSwitch RPC).
//
// Args:
//   ctx:            device context used when serializing tensors.
//   scope:          source scope holding the variables (read-only).
//   message_name:   logical message tag placed in the request.
//   send_var_names: variables to look up in `scope` and serialize.
// Returns 0 on success, -1 when no usable channel exists.
int HeterClient::Send(const platform::DeviceContext& ctx,
                      const framework::Scope& scope,
                      const std::string& message_name,
                      const std::vector<std::string>& send_var_names) {
  const framework::Scope* p_scope = &scope;  // scope is deliberately const
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    // BUGFIX: report the real RPC status to waiters. The original signalled
    // 0 unconditionally, so a failed call still looked successful.
    int ret = closure->cntl.Failed() ? -1 : 0;
    closure->set_promise_value(ret);
    PADDLE_ENFORCE_NE(
        closure->cntl.Failed(), true,
        platform::errors::Unimplemented(
            "HeterClient::SendToSwitch meets brpc error, error message is %s",
            closure->cntl.ErrorText()));
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  auto& request_io_buffer = closure->cntl.request_attachment();
  distributed::MultiVarMsg request;
  // 1. set req message_name(string)
  request.set_message_name(message_name);
  // 2. set req send_var_names(<string>)
  for (auto& send_var_name : send_var_names) {
    request.add_send_var_names(send_var_name);
  }
  // 3. set req var_messages(<VarMessage>) and append the serialized tensor
  // bytes to the request attachment.
  for (auto& send_var_name : send_var_names) {
    auto* send_var_msg = request.add_var_messages();
    send_var_msg->set_varname(send_var_name);
    framework::Variable* var = p_scope->FindVar(send_var_name);
    butil::IOBuf temp_iobuf;
    if (var->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
    } else if (var->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
    }
    request_io_buffer.append(temp_iobuf);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (send_switch_channels_.empty()) {
    LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
    if (xpu_channels_.empty()) {
      // BUGFIX: fail fast instead of indexing an empty vector (UB in the
      // original, which logged and then read xpu_channels_[0] anyway).
      LOG(ERROR) << "xpu_channels_ is null";
      delete closure;
      return -1;
    }
    send_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel* channel = send_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
  VLOG(4) << "waiting SendToSwitch response result......";
  fut.wait();
  VLOG(4) << "Send done";
  return 0;
}
std::future<int32_t> HeterClient::StopProfiler() {
return SendCmd(-1, PS_STOP_PROFILER, {});
int HeterClient::Send(int group_id, const std::vector<std::string>& var_names,
const std::vector<int>& vars_len, void* data_ptr,
int64_t data_size) {
OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
int ret = 0;
closure->set_promise_value(ret);
if (closure->cntl.Failed()) {
LOG(ERROR) << "Send meets brpc error, err msg is %s"
<< closure->cntl.ErrorText();
}
});
distributed::MultiVarMsg request;
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
std::string message_name = "send and save";
request.set_message_name(message_name);
request.set_group_id(group_id);
for (auto& send_var_name : var_names) {
request.add_send_var_names(send_var_name);
}
for (auto var_len : vars_len) {
request.add_vars_len(var_len);
}
auto& request_buffer = closure->cntl.request_attachment();
request_buffer.append(reinterpret_cast<void*>(data_ptr),
data_size * sizeof(float));
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
if (send_switch_channels_.empty()) {
LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
if (xpu_channels_.empty()) {
LOG(ERROR) << "xpu_channels_ is null";
}
send_switch_channels_.push_back(xpu_channels_[0]);
}
brpc::Channel* channel = send_switch_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel);
stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
fut.wait();
return 0;
}
} // end namespace distributed
// Pulls the named variables from the switch (RecvFromSwitch RPC) and
// deserializes them into recv_scope on the CPU.
//
// Args:
//   ctx:            device context (unused here; deserialization is CPU-side).
//   recv_scope:     destination scope, mutated in place.
//   message_name:   logical message tag placed in the request.
//   recv_var_names: variables requested from the switch.
// Returns 0 on success, -1 when no usable channel exists.
int HeterClient::Recv(const platform::DeviceContext& ctx,
                      framework::Scope& recv_scope,  // NOLINT
                      const std::string& message_name,
                      const std::vector<std::string>& recv_var_names) {
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    VLOG(4) << "Recv service call done";
    // BUGFIX: compute the status first so waiters observe failures
    // (original always signalled 0), and stream the error text instead of
    // a literal "%s".
    int ret = closure->cntl.Failed() ? -1 : 0;
    if (closure->cntl.Failed()) {
      VLOG(4) << "HeterClient::RecvFromSwitch meets "
                 "brpc error, error message is "
              << closure->cntl.ErrorText();
    }
    closure->set_promise_value(ret);
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  distributed::MultiVarMsg request;
  // 1. set req message_name(string)
  request.set_message_name(message_name);
  // 2. set req recv_var_names(<string>)
  for (auto& recv_var_name : recv_var_names) {
    request.add_recv_var_names(recv_var_name);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (recv_switch_channels_.empty()) {
    LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
    if (xpu_channels_.size() < 2) {
      // BUGFIX: fail fast instead of reading xpu_channels_[1] out of range.
      LOG(ERROR) << "xpu_channels_ is null";
      delete closure;
      return -1;
    }
    recv_switch_channels_.push_back(xpu_channels_[1]);
  }
  brpc::Channel* channel = recv_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
  fut.wait();
  VLOG(4) << "RecvFromSwitch done";
  // Deserialize the response attachment into the worker-side scope (CPU).
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  auto& res_io_buffer = closure->cntl.response_attachment();
  VLOG(4) << "entering DeserializeFromMultiVarMsgAndIOBuf";
  distributed::DeserializeFromMultiVarMsgAndIOBuf(
      closure->response, &res_io_buffer, cpu_dev_ctx, &recv_scope);
  VLOG(4) << "Recv done";
  return 0;
}
int HeterClient::Recv(int group_id, const std::vector<std::string>& var_names,
void* data_ptr, int64_t data_size) {
OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
int ret = 0;
closure->set_promise_value(ret);
if (closure->cntl.Failed()) {
LOG(ERROR) << "Recv meets brpc error, err msg is %s"
<< closure->cntl.ErrorText();
}
});
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request;
std::string message_name = "query and recv";
request.set_message_name(message_name);
request.set_group_id(group_id);
for (auto& recv_var_name : var_names) {
request.add_recv_var_names(recv_var_name);
}
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
if (recv_switch_channels_.empty()) {
LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
if (xpu_channels_.size() < 2) {
LOG(ERROR) << "xpu_channels_ is null";
}
recv_switch_channels_.push_back(xpu_channels_[1]);
}
brpc::Channel* channel = recv_switch_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel);
stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
fut.wait();
VLOG(4) << "RecvFromSwitch done";
// save in worker
auto& res_io_buffer = closure->cntl.response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(data_ptr),
data_size * sizeof(float));
VLOG(4) << "Recv done";
return 0;
}
} // namespace distributed
} // end namespace paddle
......@@ -32,13 +32,14 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
DECLARE_int32(pserver_timeout_ms);
namespace paddle {
namespace distributed {
......@@ -51,24 +52,72 @@ class OnHeterRpcDone : public google::protobuf::Closure {
public:
explicit OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
virtual ~OnHeterRpcDone() {}
void Run() {
std::unique_ptr<OnHeterRpcDone> self_guard(this);
handler_(this);
void Run() { handler_(this); }
// Registers a promise that set_promise_value() will fulfill when the RPC
// completes; the caller keeps the matching future and waits on it.
void add_promise(std::shared_ptr<std::promise<int32_t>>& promise) { // NOLINT
_promises.push_back(promise);
}
// Fulfills every registered promise with `value`, unblocking all waiters.
// NOTE(review): std::promise::set_value throws if called twice on the same
// promise — presumably each closure fires exactly once; confirm.
void set_promise_value(int value) {
for (auto& promise : _promises) {
promise->set_value(value);
}
}
int CheckResponse() { return 0; }
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
HeterRpcCallbackFunc handler_;
MultiVariableMessage request;
MultiVariableMessage response;
PsResponseMessage ps_response;
brpc::Controller cntl;
// PsRequestMessage *request(size_t i) { return &_requests[i]; }
// PsResponseMessage *response(size_t i) { return &_responses[i]; }
// std::vector<PsRequestMessage> _requests;
// std::vector<PsResponseMessage> _responses;
// std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
class HeterClient {
public:
virtual ~HeterClient() {}
HeterClient() {
running_ = true;
main_thread_.reset(
new std::thread(std::bind(&HeterClient::MainThread, this)));
// Builds one baidu_std brpc channel per endpoint in node_list and stores
// them in the member vector matching peer_role (switch or worker peers).
// SSL is enabled only for switch peers when need_encrypt is set. On a
// failed Init, retries once with the resolved numeric ip:port form.
void InitClientChannels(bool need_encrypt,
                        const std::vector<std::string>& node_list,
                        int32_t peer_role) {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.connection_type = "single";
  options.timeout_ms = FLAGS_pserver_timeout_ms;
  std::vector<std::shared_ptr<brpc::Channel>>* client_channels = nullptr;
  if (peer_role == PEER_ROLE_IS_SWITCH) {
    options.ssl_options.enable = need_encrypt;
    client_channels = &peer_switch_channels_;
  } else if (peer_role == PEER_ROLE_IS_WORKER) {
    client_channels = &peer_worker_channels_;
  } else {
    LOG(ERROR) << "init switch client failed, peer_role not valid";
    // BUGFIX: the original fell through with client_channels == nullptr and
    // dereferenced it below; bail out instead.
    return;
  }
  (*client_channels).resize(node_list.size());
  for (size_t i = 0; i < node_list.size(); ++i) {
    (*client_channels)[i].reset(new brpc::Channel());
    if ((*client_channels)[i]->Init(node_list[i].c_str(), "", &options) !=
        0) {
      VLOG(0) << "client channel init failed! try again";
      // Retry with the resolved numeric ip:port endpoint.
      auto ip_port = paddle::string::Split(node_list[i], ':');
      std::string ip = ip_port[0];
      int port = std::stoi(ip_port[1]);
      std::string int_ip_port = GetIntTypeEndpoint(ip, port);
      if ((*client_channels)[i]->Init(int_ip_port.c_str(), "", &options) !=
          0) {
        LOG(ERROR) << "client channel init failed! peer ip_port = "
                   << int_ip_port;
      }
    }
  }
  VLOG(4) << "InitClientChannels success";
}
void CreateClient2XpuConnection();
......@@ -80,14 +129,28 @@ class HeterClient {
const std::vector<std::string>& recv_var_name,
const std::string& mode = "forward");
int Send(int group_id, const std::vector<std::string>& var_names,
const std::vector<int>& vars_len, void* data_ptr, int64_t data_size);
int Send(const platform::DeviceContext& ctx, const framework::Scope& scope,
const std::string& message_name,
const std::vector<std::string>& send_var_names);
int Recv(int group_id, const std::vector<std::string>& var_names,
void* data_ptr, int64_t data_size);
int Recv(const platform::DeviceContext& ctx,
framework::Scope& recv_scope, // NOLINT
const std::string& message_name,
const std::vector<std::string>& recv_var_names);
// HeterClient singleton
static std::shared_ptr<HeterClient> GetInstance(
const std::vector<std::string>& endpoint,
const std::vector<std::string>& previous_endpoint,
const int& trainer_id) {
if (NULL == s_instance_) {
is_initialized_ = true;
s_instance_.reset(new paddle::distributed::HeterClient());
s_instance_.reset(new HeterClient());
s_instance_->SetXpuList(endpoint);
s_instance_->SetPreviousXpuList(previous_endpoint);
s_instance_->SetTrainerID(trainer_id);
......@@ -96,13 +159,29 @@ class HeterClient {
return s_instance_;
}
void Stop();
// switch client singleton
// Returns the process-wide switch-client singleton, (re)initializing its
// peer list and channels from peer_endpoints on every call.
static HeterClient& GetSwitchInstance(
    const std::vector<std::string>& peer_endpoints, int32_t peer_role) {
  static HeterClient switch_s_instance_;
  if (peer_endpoints.empty()) {
    VLOG(4) << "init switch client failed, null peer_endpoints";
  } else {
    // BUGFIX: only read peer_endpoints[0] when the list is non-empty; the
    // original logged the warning and then indexed the empty vector anyway.
    VLOG(4) << "peer role is: " << peer_role
            << ", addr is: " << peer_endpoints[0];
  }
  switch_s_instance_.SetPeerSwitchList(peer_endpoints);
  switch_s_instance_.InitClientChannels(false, peer_endpoints, peer_role);
  return switch_s_instance_;
}
void FinalizeWorker();
// Records the switch-peer endpoint list used by GetSwitchInstance().
void SetPeerSwitchList(const std::vector<std::string>& peer_endpoints) {
peer_switch_list_ = peer_endpoints;
}
void MainThread();
// Records the worker-peer endpoint list (used when peer_role is worker).
void SetPeerWorkerList(const std::vector<std::string>& worker_endpoints) {
peer_worker_list_ = worker_endpoints;
}
void RpcProfilerControl();
void Stop();
std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
const std::vector<std::string>& params);
......@@ -124,20 +203,32 @@ class HeterClient {
void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }
public:
std::vector<std::string> send_switch_list_;
std::vector<std::string> recv_switch_list_;
std::vector<std::string> peer_switch_list_;
std::vector<std::string> peer_worker_list_;
std::vector<std::shared_ptr<brpc::Channel>> send_switch_channels_;
std::vector<std::shared_ptr<brpc::Channel>> recv_switch_channels_;
std::vector<std::shared_ptr<brpc::Channel>> peer_switch_channels_;
std::vector<std::shared_ptr<brpc::Channel>> peer_worker_channels_;
private:
HeterClient() {}
HeterClient& operator=(const HeterClient&);
HeterClient(const HeterClient&);
static std::shared_ptr<HeterClient> s_instance_;
static bool is_initialized_;
std::unique_ptr<std::thread> main_thread_{nullptr};
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;
DISABLE_COPY_AND_ASSIGN(HeterClient);
// DISABLE_COPY_AND_ASSIGN(HeterClient);
std::vector<std::string> xpu_list_;
std::vector<std::string> previous_xpu_list_;
bool running_ = false;
int trainer_id_;
bool do_server_profiler_ = false;
};
} // end namespace distributed
......
......@@ -13,21 +13,28 @@
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_server.h"
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace distributed {
// DEFINE_string(cert_path, "./cert.pem", "cert.pem path");
// DEFINE_string(key_path, "./key.pem", "key.pem path");
std::shared_ptr<HeterServer> HeterServer::s_instance_ = NULL;
std::shared_ptr<HeterServer> HeterServer::s_instance_ = nullptr;
// Registers `func` as the handler for requests carrying `message_name`;
// simply forwards to the underlying service_ registry.
void HeterServer::RegisterServiceHandler(std::string message_name,
HeterServiceHandler func) {
service_.RegisterServiceHandler(message_name, func);
}
void HeterServer::StartHeterService() {
void HeterServer::StartHeterService(bool neeed_encrypt) {
server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
if (neeed_encrypt) {
options.ssl_options.default_cert.certificate = "/cert.pem";
options.ssl_options.default_cert.private_key = "/key.pem";
}
if (server_.Start(endpoint_.c_str(), &options) != 0) {
VLOG(0) << "HeterServer start fail. Try again.";
auto ip_port = paddle::string::Split(endpoint_, ':');
......@@ -47,16 +54,50 @@ void HeterServer::StartHeterService() {
ready_ = 1;
}
condition_ready_.notify_all();
VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_;
std::unique_lock<std::mutex> running_lock(mutex_);
cv_.wait(running_lock, [&] {
VLOG(1) << "Heter Server is Stop? " << stoped_;
VLOG(4) << "Heter Server is Stop? " << stoped_;
return stoped_;
});
VLOG(4) << "start service done";
}
void HeterServer::SetEndPoint(const std::string& endpoint) {
endpoint_ = endpoint;
service_.SetEndpoint(endpoint);
// Starts the switch-to-switch (inter) brpc server on endpoint_inter_,
// signals readiness, and blocks until stoped_ becomes true. Mirrors
// StartHeterService().
// NOTE(review): the cert/key paths are hard-coded absolute paths — confirm
// they exist in deployment when neeed_encrypt is on.
void HeterServer::StartHeterInterService(bool neeed_encrypt) {
  server_inter_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  if (neeed_encrypt) {
    options.ssl_options.default_cert.certificate = "/cert.pem";
    options.ssl_options.default_cert.private_key = "/key.pem";
  }
  if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) {
    VLOG(4) << "switch inter server start fail. Try again.";
    auto ip_port = paddle::string::Split(endpoint_inter_, ':');
    std::string ip = ip_port[0];
    int port = std::stoi(ip_port[1]);
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    // BUGFIX: retry with the resolved int_ip_port. The original retried
    // endpoint_inter_ again, making the computed fallback a no-op.
    if (server_inter_.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "switch inter server start failed, ip_port= "
                 << int_ip_port;
    }
  } else {
    VLOG(4) << "switch inter server server start success! listen on "
            << endpoint_inter_;
  }
  // Publish readiness to WaitServerReady() waiters.
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    stoped_ = false;
    ready_ = 1;
  }
  condition_ready_.notify_all();
  VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_;
  // Park this thread until the server is asked to stop.
  std::unique_lock<std::mutex> running_lock(mutex_);
  cv_.wait(running_lock, [&] {
    VLOG(4) << "Heter Server is Stop? " << stoped_;
    return stoped_;
  });
  VLOG(4) << "start service done";
}
void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
......@@ -64,35 +105,180 @@ void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
// Blocks the caller until StartHeterService()/StartHeterInterService() has
// set ready_ to 1 and notified condition_ready_.
// NOTE(review): the trailing poll loop is redundant — the predicate wait
// already guarantees ready_ == 1 under the lock when it returns.
void HeterServer::WaitServerReady() {
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
while (!this->ready_) {
sleep(1);
}
}
int32_t HeterService::stop_profiler(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl) {
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("heter_worker_%s_profile", endpoint_));
// Stores the raw float payload from the request attachment into the
// per-group shard table, one slice per send_var_name, sized by vars_len.
// If a slot still holds unconsumed data, blocks (WaitForVarsConsumed)
// until the consumer drains it before overwriting. Always returns 0.
int SendAndRecvVariableHandler::SaveInSwitchWithShard(
const MultiVarMsg* request, PsResponseMessage* response,
brpc::Controller* cntl) {
VLOG(4) << "entering SaveInSwitchWithShard";
int32_t group_id = request->group_id();
auto& local_shard = _local_shards[group_id];
auto& request_io_buffer = cntl->request_attachment();
// Iterator walks the attachment sequentially; slices must arrive in the
// same order as send_var_names.
butil::IOBufBytesIterator io_buffer_itr(request_io_buffer);
for (int idx = 0; idx < request->send_var_names_size(); idx++) {
const auto& var_name = request->send_var_names(idx);
const auto& var_len = request->vars_len(idx);
auto itr = local_shard.find(var_name);
if (itr != local_shard.end()) {
// Slot already populated: wait until the reader consumes it.
LOG(INFO) << "var: " << var_name << "has not been consumed!"
<< "check again";
WaitForVarsConsumed(group_id, var_name);
}
auto& value = local_shard[var_name];
value.resize(var_len);
io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(value.data()),
var_len * sizeof(float));
VLOG(4) << "saved data in shards: ";
for (uint32_t i = 0; i < local_shard[var_name].size(); i++) {
VLOG(4) << *(local_shard[var_name].data() + i);
}
}
VLOG(4) << "SaveInSwitchWithShard success";
return 0;
}
int32_t HeterService::start_profiler(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl) {
platform::EnableProfiler(platform::ProfilerState::kAll);
// Serves a "query and recv" request from the per-group shard table: for
// each requested variable, blocks (WaitForVarsProduced) until the producer
// has filled the slot, appends its float payload to the response
// attachment, then empties the slot to mark it consumed. Always returns 0.
int SendAndRecvVariableHandler::QueryInSwitchWithShard(
const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
VLOG(4) << "entering QueryInSwitchWithShard";
int32_t group_id = request->group_id();
VLOG(4) << "group id: " << group_id;
auto& local_shard = _local_shards[group_id];
auto& response_io_buffer = cntl->response_attachment();
auto req_var_nums = request->recv_var_names_size();
std::vector<std::string> req_var_names(req_var_nums);
for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
req_var_names[var_idx] = request->recv_var_names(var_idx);
}
auto msg_name = request->message_name();
response->set_message_name(msg_name);
for (auto& req_var_name : req_var_names) {
VLOG(4) << "req var name: " << req_var_name;
response->add_send_var_names(req_var_name);
auto itr = local_shard.find(req_var_name);
if (itr == local_shard.end()) {
LOG(INFO) << "var: " << req_var_name << " not found in shards";
WaitForVarsProduced(group_id, req_var_name);
}
LOG(INFO) << "var: " << req_var_name << " found in shards";
// Re-find: the wait above may have (re)inserted the entry.
itr = local_shard.find(req_var_name);
auto& value = itr.value();
response_io_buffer.append(value.data(), value.size() * sizeof(float));
value.resize(0);  // marker: empty size flags the slot as consumed
}
VLOG(4) << "heter server QueryInSwitchWithShard done";
return 0;
}
int32_t HeterService::stop_heter_worker(const PsRequestMessage& request,
PsResponseMessage& response,
brpc::Controller* cntl) {
auto client_id = request.client_id();
stop_cpu_worker_set_.insert(client_id);
if (stop_cpu_worker_set_.size() == fan_in_) {
is_exit_ = true;
VLOG(3) << "Stop heter Service done.";
// Deserializes the request's variables into the switch-local scope (CPU)
// under scope_mutex_, bumps a per-variable reference count in vars_table,
// then busy-waits (1 s polls) until QueryInSwitchWithScope has decremented
// every count back to zero, i.e. all saved vars were consumed. Returns 0.
// NOTE(review): the post-unlock polling loop reads vars_table without the
// lock — presumably benign for int reads here, but confirm.
int SendAndRecvVariableHandler::SaveInSwitchWithScope(
const MultiVarMsg* request, PsResponseMessage* response,
brpc::Controller* cntl) {
VLOG(4) << "entering SaveInSwitchWithScope";
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::CPUPlace cpu_place;
auto& cpu_dev_ctx = *pool.Get(cpu_place);
auto message_name = request->message_name();
VLOG(4) << "message_name in heter server: " << message_name;
std::unique_lock<std::mutex> lk(scope_mutex_);
auto local_scope = local_scope_ptr.get();
if (!local_scope) {
LOG(ERROR) << "local_scope_ptr is null in SaveInSwitchWithScope";
}
for (int idx = 0; idx < request->send_var_names_size(); idx++) {
const auto& msg = request->var_messages(idx);
std::string var_name = msg.varname();
auto* var_exist_ptr = local_scope->FindVar(var_name);
if (!var_exist_ptr) {
VLOG(4) << "not find var: " << var_name << " in local_scope";
}
// Count one pending consumer per saved variable.
vars_table[var_name] += 1;
VLOG(4) << "saved var_name: " << var_name
<< ", cnt = " << vars_table[var_name];
}
auto& request_io_buffer = cntl->request_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(*request, &request_io_buffer,
cpu_dev_ctx, local_scope);
lk.unlock();
// Block until every saved variable's count drops to zero (consumed).
while (true) {
int ret = 0;
for (int idx = 0; idx < request->send_var_names_size(); idx++) {
ret |= vars_table[request->var_messages(idx).varname()];
}
if (!ret) {
VLOG(4) << "all saved vars consumed";
break;
}
VLOG(4) << "waiting consume result......";
sleep(1);
}
VLOG(4) << "SaveInSwitchWithScope success";
return 0;
}
// Serves a query against the switch-local scope: waits (1 s polls) until
// each requested variable exists in the scope, serializes it into the
// response (message + attachment), then decrements the vars_table count
// under scope_mutex_ so SaveInSwitchWithScope can unblock. Returns 0.
int SendAndRecvVariableHandler::QueryInSwitchWithScope(
const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
VLOG(4) << "entering QueryInSwitchWithScope";
auto local_scope = local_scope_ptr.get();
if (!local_scope) {
LOG(INFO) << "local_scope is null";
}
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::CPUPlace cpu_place;
auto& cpu_dev_ctx = *pool.Get(cpu_place);
// get req message_name & req_var_names
auto msg_name = request->message_name();
auto req_var_nums = request->recv_var_names_size();
std::vector<std::string> req_var_names(req_var_nums);
for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
req_var_names[var_idx] = request->recv_var_names(var_idx);
}
auto& response_io_buffer = cntl->response_attachment();
// 1. fill message_name(string)
response->set_message_name(msg_name);
// 2. fill var_names(string)
for (auto& req_var_name : req_var_names) {
response->add_send_var_names(req_var_name);
}
// 3. fill var_messages(VarMessage)
for (auto& req_var_name : req_var_names) {
LOG(INFO) << "query var_name: " << req_var_name;
auto* send_var_msg = response->add_var_messages();
send_var_msg->set_varname(req_var_name);
framework::Variable* var_ptr;
// Poll until the producer has written the variable into the scope.
while (true) {
var_ptr = local_scope->FindVar(req_var_name);
if (!var_ptr) {
LOG(INFO) << "local_scope not find var: " << req_var_name;
} else {
break;
}
sleep(1);
}
butil::IOBuf temp_iobuf;
if (var_ptr->IsType<framework::LoDTensor>()) {
SerializeLodTensor(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
} else if (var_ptr->IsType<phi::SelectedRows>()) {
SerializeSelectedRows(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
}
response_io_buffer.append(temp_iobuf);
}
// Mark each variable consumed so SaveInSwitchWithScope's wait can finish.
for (auto& req_var_name : req_var_names) {
std::unique_lock<std::mutex> lk(scope_mutex_);
vars_table[req_var_name] -= 1;
VLOG(4) << "remained var: " << req_var_name
<< ", cnt = " << vars_table[req_var_name];
lk.unlock();
}
VLOG(4) << "heter server QueryInSwitchWithScope done";
return 0;
}
} // end namespace distributed
} // end namespace paddle
} // namespace paddle
......@@ -22,11 +22,14 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -51,108 +54,37 @@ class Scope;
} // namespace paddle
DECLARE_double(eager_delete_tensor_gb);
DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(heter_world_size);
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
class HeterService;
using MultiVarMsg = MultiVariableMessage;
using VarMsg = VariableMessage;
typedef int32_t (HeterService::*serviceHandlerFunc)(
using serviceHandler = std::function<int32_t(
const PsRequestMessage& request, PsResponseMessage& response, // NOLINT
brpc::Controller* cntl);
brpc::Controller* cntl)>;
using HeterServiceHandler =
std::function<int32_t(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>;
typedef std::function<void(void*)> HeterRpcCallbackFunc;
typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>
HeterServiceHandler;
using HeterRpcCallbackFunc = std::function<void(void*)>;
class HeterService : public ::paddle::distributed::PsService {
class ServiceHandlerBase {
public:
HeterService() {
_service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker;
_service_handler_map[PS_START_PROFILER] = &HeterService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &HeterService::stop_profiler;
}
ServiceHandlerBase() : dev_ctx_(nullptr), scope_(nullptr) {}
virtual ~HeterService() {}
virtual ~ServiceHandlerBase() {}
virtual void service(::google::protobuf::RpcController* controller,
const PsRequestMessage* request,
PsResponseMessage* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-");
response->set_err_code(0);
response->set_err_msg("");
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
return;
}
serviceHandlerFunc handler_func = itr->second;
int service_ret = (this->*handler_func)(*request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
}
}
void SendAndRecvVariable(::google::protobuf::RpcController* controller,
const MultiVarMsg* request, MultiVarMsg* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
std::string message_name = request->message_name();
auto itr = handler_map_.find(message_name);
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
PADDLE_ENFORCE_NE(
itr, handler_map_.end(),
platform::errors::InvalidArgument(
"HeterService::SendAndRecvVariable Get illegal message_name: %s "
"which is not in HeterService::handler_map_",
message_name));
itr->second(request, response, cntl);
}
void RegisterServiceHandler(std::string message_name,
HeterServiceHandler func) {
handler_map_[message_name] = func;
}
int32_t ForceExit() {
VLOG(3) << "heter service force exit";
is_exit_ = true;
return 0;
}
void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
void SetFanin(const int& fan_in) { fan_in_ = fan_in; }
bool IsExit() { return is_exit_; }
private:
int32_t stop_profiler(const PsRequestMessage& request,
PsResponseMessage& response, // NOLINT
brpc::Controller* cntl);
int32_t start_profiler(const PsRequestMessage& request,
PsResponseMessage& response, // NOLINT
brpc::Controller* cntl);
void SetScope(const framework::Scope* scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
int32_t stop_heter_worker(const PsRequestMessage& request,
PsResponseMessage& response, // NOLINT
brpc::Controller* cntl);
virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) = 0;
private:
std::string endpoint_;
std::unordered_map<std::string, HeterServiceHandler> handler_map_;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
std::unordered_set<int> stop_cpu_worker_set_;
int fan_in_;
bool is_exit_ = false;
protected:
const platform::DeviceContext* dev_ctx_;
const framework::Scope* scope_;
};
using SharedMiniScope =
......@@ -163,31 +95,15 @@ using SharedTaskQueue = std::shared_ptr<
std::unordered_map<int, std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>>;
class HeterRequestHandler {
public:
HeterRequestHandler() : dev_ctx_(nullptr), scope_(nullptr) {}
virtual ~HeterRequestHandler() {}
void SetScope(const framework::Scope* scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) = 0;
protected:
const platform::DeviceContext* dev_ctx_;
const framework::Scope* scope_;
};
class RequestSendAndRecvHandler final : public HeterRequestHandler {
class SendAndRecvVariableHandler final : public ServiceHandlerBase {
public:
RequestSendAndRecvHandler() {
SendAndRecvVariableHandler() {
this->num_microbatch_ = 0;
this->num_minibatch_ = 0;
_local_shards.reset(new shard_type[FLAGS_heter_world_size]);
}
virtual ~RequestSendAndRecvHandler() {}
virtual ~SendAndRecvVariableHandler() {}
void SetMiniScopes(SharedMiniScope mini_scopes) {
mini_scopes_ = mini_scopes;
......@@ -209,11 +125,47 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
return (*task_queue_).size();
}
int SaveInSwitchWithScope(const MultiVarMsg* request,
PsResponseMessage* response,
brpc::Controller* cntl);
void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
auto& local_shard = _local_shards[group_id];
while (local_shard.find(var_name) != local_shard.end()) {
if (local_shard[var_name].size() == 0) {
break;
}
VLOG(4) << "waiting consume result......";
sleep(1);
}
return;
}
void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
auto& local_shard = _local_shards[group_id];
while (local_shard.find(var_name) == local_shard.end()) {
VLOG(4) << "waiting produce result......";
sleep(1);
}
return;
}
int SaveInSwitchWithShard(const MultiVarMsg* request,
PsResponseMessage* response,
brpc::Controller* cntl);
int QueryInSwitchWithShard(const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl);
int QueryInSwitchWithScope(const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl);
void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }
int Handle(const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) override {
platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle",
LOG(INFO) << "entered Handle";
platform::RecordEvent record_event("SendAndRecvVariableHandler->Handle",
platform::TracerEventType::Communication,
1);
FLAGS_eager_delete_tensor_gb = -1;
......@@ -241,7 +193,6 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
auto* tensor = var->GetMutable<framework::LoDTensor>();
auto data = reinterpret_cast<const float*>(tensor->data());
auto micro_id = static_cast<int>(data[0]);
int minibatch_index = micro_id / 10;
int microbatch_index = micro_id % 10;
......@@ -249,10 +200,7 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
std::unique_lock<std::mutex> lk(scope_mutex_);
if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
lk.unlock();
// PADDLE_ENFORCE_EQ(
// (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1,
// platform::errors::InvalidArgument(
// "minibatch index should in current trainer"));
PADDLE_ENFORCE_EQ(
(*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1,
platform::errors::InvalidArgument(
......@@ -282,6 +230,7 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
// blocking queue handles multi thread
(*task_queue_)[minibatch_index]->Push(
std::make_pair(message_name, microbatch_index));
auto response_var_nums = request->recv_var_names_size();
std::vector<std::string> response_var_names(response_var_nums),
empty_var_names{};
......@@ -295,6 +244,12 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
return 0;
}
public:
using shard_type = SparseTableShard<std::string, FixedFeatureValue>;
std::shared_ptr<paddle::framework::Scope> local_scope_ptr; // for switch
std::unordered_map<std::string, uint32_t> vars_table;
std::unique_ptr<shard_type[]> _local_shards;
private:
// share with HeterPipelineTrainer
SharedMiniScope mini_scopes_{nullptr};
......@@ -310,15 +265,254 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler {
SharedTaskQueue task_queue_;
};
class HeterService : public PsService {
 public:
  HeterService() {
    // Dispatch table for the generic service() entry point: maps a PS
    // control command id onto its member-function handler.
    _service_handler_map[PS_STOP_SERVER] =
        std::bind(&HeterService::stop_heter_worker, this, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3);
    _service_handler_map[PS_START_PROFILER] =
        std::bind(&HeterService::start_profiler, this, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3);
    _service_handler_map[PS_STOP_PROFILER] =
        std::bind(&HeterService::stop_profiler, this, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3);
    service_handler_.local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }

  virtual ~HeterService() {}

  // Generic PS control RPC: looks up request->cmd_id() in
  // _service_handler_map and forwards request/response to that handler.
  virtual void service(::google::protobuf::RpcController* controller,
                       const PsRequestMessage* request,
                       PsResponseMessage* response,
                       ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    response->set_err_code(0);
    response->set_err_msg("");
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      // BUGFIX: err_msg used to be built but never attached to the response,
      // so callers saw err_code == 0 for an unknown command.
      response->set_err_code(-1);
      response->set_err_msg(err_msg);
      return;
    }
    serviceHandler handler = itr->second;
    int service_ret = handler(*request, *response, cntl);
    VLOG(4) << "handler in service ret: " << service_ret;
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  }

  // Routes a variable send/recv RPC to the handler registered under the
  // request's message_name (see RegisterServiceHandler).
  virtual void SendAndRecvVariable(
      ::google::protobuf::RpcController* controller, const MultiVarMsg* request,
      MultiVarMsg* response, ::google::protobuf::Closure* done) {
    // This object helps you to call done->Run() in RAII style. If you need
    // to process the request asynchronously, pass done_guard.release().
    brpc::ClosureGuard done_guard(done);
    std::string message_name = request->message_name();
    VLOG(0) << "SendAndRecvVariable message_name: " << message_name;
    auto itr = handler_map_.find(message_name);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    LOG(INFO) << "SendAndRecvVariable(client addr) =" << cntl->remote_side();
    PADDLE_ENFORCE_NE(
        itr, handler_map_.end(),
        platform::errors::InvalidArgument(
            "HeterService::SendAndRecvVariable Get illegal message_name: %s "
            "which is not in HeterService::handler_map_",
            message_name));
    itr->second(request, response, cntl);
  }

  // Serves a worker's pull request with data previously buffered on this
  // switch (shard-based storage).
  virtual void RecvFromSwitch(::google::protobuf::RpcController* controller,
                              const MultiVarMsg* request, MultiVarMsg* response,
                              ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    int ret = service_handler_.QueryInSwitchWithShard(request, response, cntl);
    if (ret != 0) {
      // BUGFIX: log message previously named the scope-based variant while
      // the shard-based query is what actually ran.
      LOG(ERROR) << "QueryInSwitchWithShard failed!";
    }
  }

  // Proxies an upload from a worker to the peer switch via SendS2S, then
  // copies the peer's response attachment back to the caller.
  virtual void SendToSwitch(::google::protobuf::RpcController* controller,
                            const MultiVarMsg* request,
                            PsResponseMessage* response,
                            ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendToSwitch";
    brpc::ClosureGuard done_guard(done);
    auto& switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH);
    if (switch_client_ptr_.peer_switch_channels_.empty()) {
      LOG(ERROR) << "switch_client_ptr_.peer_switch_channels_ null";
    }
    brpc::Channel* channel = switch_client_ptr_.peer_switch_channels_[0].get();
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // Proxy: build a fresh OnHeterRpcDone closure that validates the peer's
    // response and publishes the result through a promise.
    OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) {
      auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
      int ret = closure->CheckResponse();
      closure->set_promise_value(ret);
      if (closure->cntl.Failed()) {
        PADDLE_ENFORCE_NE(
            closure->cntl.Failed(), true,
            platform::errors::Unimplemented(
                "HeterClient::SendS2S meets brpc error, error message is %s",
                closure->cntl.ErrorText()));
      }
    });
    auto& std_cntl = closure2->cntl;
    std_cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
    std_cntl.request_attachment().append(cntl->request_attachment().movable());
    auto promise = std::make_shared<std::promise<int32_t>>();
    closure2->add_promise(promise);
    std::future<int> fut = promise->get_future();
    PsService_Stub stub(channel);
    stub.SendS2S(&std_cntl, request, response, closure2);
    // NOTE(review): the response attachment is moved here while the async
    // SendS2S RPC may still be in flight (fut.wait() runs afterwards), so it
    // can be empty; whether wait-then-move is safe depends on OnHeterRpcDone
    // deleting itself in Run() -- confirm its lifetime before reordering.
    cntl->response_attachment().append(
        std_cntl.response_attachment().movable());
    fut.wait();
    VLOG(4) << "SendToSwitch done";
  }

  // Switch-to-switch receive path: persists the incoming payload into the
  // local shard store.
  void SendS2S(::google::protobuf::RpcController* controller,
               const MultiVarMsg* request, PsResponseMessage* response,
               ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendS2S";
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    int ret = service_handler_.SaveInSwitchWithShard(request, response, cntl);
    if (ret != 0) {
      // BUGFIX: log message previously named the scope-based variant while
      // the shard-based save is what actually ran.
      LOG(ERROR) << "SaveInSwitchWithShard failed";
    }
    std::string err_msg = "ok";
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(ret);
    VLOG(4) << "heter server SendS2S done";
  }

  // Relays a request from this switch to its first peer worker. The caller's
  // `done` is an OnHeterRpcDone and is handed to the forwarded RPC as its
  // completion closure.
  void SendToWorker(::google::protobuf::RpcController* controller,
                    const MultiVarMsg* request, PsResponseMessage* response,
                    ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    VLOG(4) << "SendToWorker(client addr) =" << cntl->remote_side();
    auto& switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER);
    VLOG(4) << "in switch client, peer worker 0: "
            << switch_client_ptr_.peer_worker_list_[0];
    brpc::Channel* channel = switch_client_ptr_.peer_worker_channels_[0].get();
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    // Fill the bookkeeping response before issuing the async call, then
    // release the guard. BUGFIX: previously the guard was kept while `done`
    // was also passed to the async stub, so done->Run() executed twice (once
    // from the brpc completion, once from the guard destructor).
    std::string err_msg("pass to worker");
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(0);
    done_guard.release();
    // NOTE(review): forwarding the *server-side* controller as the client
    // call's controller is unusual for brpc -- confirm this is intended.
    PsService_Stub stub(channel);
    stub.SendAndRecvVariable(controller, request, &closure->response, done);
  }

  // Registers `func` as the handler invoked for SendAndRecvVariable RPCs
  // carrying `message_name`.
  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func) {
    handler_map_[message_name] = func;
  }

  void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
  void SetInterEndpoint(const std::string& end_point) {
    endpoint_inter_ = end_point;
  }
  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    peer_endpoints_ = peer_endpoints;
  }
  void SetFanin(const int& fan_in) { fan_in_ = fan_in; }

  // Unconditionally marks the service as exiting (see IsExit()).
  void ForceExit() {
    VLOG(3) << "heter service force exit";
    is_exit_ = true;
    return;
  }
  bool IsExit() { return is_exit_; }

 private:
  int32_t stop_profiler(const PsRequestMessage& request,
                        PsResponseMessage& response,  // NOLINT
                        brpc::Controller* cntl) {
    platform::DisableProfiler(
        platform::EventSortingKey::kDefault,
        string::Sprintf("heter_worker_%s_profile", endpoint_));
    return 0;
  }

  int32_t start_profiler(const PsRequestMessage& request,
                         PsResponseMessage& response,  // NOLINT
                         brpc::Controller* cntl) {
    platform::EnableProfiler(platform::ProfilerState::kAll);
    return 0;
  }

  // Records one stop request per CPU worker; flips is_exit_ once every
  // expected worker (fan_in_) has asked to stop.
  int32_t stop_heter_worker(const PsRequestMessage& request,
                            PsResponseMessage& response,  // NOLINT
                            brpc::Controller* cntl) {
    auto client_id = request.client_id();
    stop_cpu_worker_set_.insert(client_id);
    if (stop_cpu_worker_set_.size() == fan_in_) {
      is_exit_ = true;
    }
    return 0;
  }

 private:
  SendAndRecvVariableHandler service_handler_;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;
  std::unordered_map<int32_t, serviceHandler> _service_handler_map;
  std::unordered_map<std::string, HeterServiceHandler> handler_map_;
  std::unordered_set<int> stop_cpu_worker_set_;
  uint32_t fan_in_;
  bool is_exit_ = false;
};
class HeterServer {
public:
HeterServer() : ready_(0) {}
virtual ~HeterServer() {}
void Stop() {
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_ == true) return;
if (!IsExit()) service_.ForceExit();
VLOG(3) << "HeterServer Stop()";
if (!IsExit()) {
service_.ForceExit();
}
stoped_ = true;
cv_.notify_all();
server_.Stop(1000);
......@@ -327,26 +521,42 @@ class HeterServer {
bool IsStop() {
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_ == true)
return true;
else
return false;
return stoped_;
}
bool IsExit() { return service_.IsExit(); }
HeterServer() : service_(), ready_(0) {}
void RegisterServiceHandler(std::string message_name,
HeterServiceHandler func);
void StartHeterService();
void StartHeterService(bool need_encrypt = false);
void StartHeterInterService(bool need_encrypt = false);
void SetEndPoint(const std::string& endpoint) {
this->endpoint_ = endpoint;
service_.SetEndpoint(endpoint);
}
void SetLocalScope() {
request_handler_->local_scope_ptr =
std::make_shared<paddle::framework::Scope>();
}
void SetInterEndpoint(const std::string& endpoint) {
this->endpoint_inter_ = endpoint;
service_.SetInterEndpoint(endpoint);
}
void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
this->peer_endpoints_ = peer_endpoints;
service_.SetPeerEndPoints(peer_endpoints);
}
void SetEndPoint(const std::string& endpoint);
void SetFanin(const int& fan_in);
void SetRequestHandler(
std::shared_ptr<RequestSendAndRecvHandler> request_handler) {
void SetServiceHandler(
std::shared_ptr<SendAndRecvVariableHandler> request_handler) {
request_handler_ = request_handler;
}
......@@ -381,11 +591,15 @@ class HeterServer {
std::condition_variable condition_ready_;
bool stoped_ = true;
std::string endpoint_;
std::string endpoint_inter_;
// for switch
std::vector<std::string> peer_endpoints_;
protected:
brpc::Server server_;
brpc::Server server_inter_;
HeterService service_;
std::shared_ptr<RequestSendAndRecvHandler> request_handler_;
std::shared_ptr<SendAndRecvVariableHandler> request_handler_;
DISABLE_COPY_AND_ASSIGN(HeterServer);
std::mutex mutex_ready_;
......
......@@ -59,6 +59,12 @@ enum PsCmdID {
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG = 40;
PEER_ROLE_IS_WORKER = 41;
PEER_ROLE_IS_SWITCH = 42;
PS_SAVE_WITH_SCOPE = 43;
PS_SAVE_WITH_SHARD = 44;
PS_QUERY_WITH_SCOPE = 45;
PS_QUERY_WITH_SHARD = 46;
}
message PsRequestMessage {
......@@ -117,9 +123,16 @@ message MultiVariableMessage {
repeated string send_var_names = 2;
repeated string recv_var_names = 3;
repeated VariableMessage var_messages = 4;
optional bytes data = 5;
repeated int32 vars_len = 6;
optional int32 group_id = 7;
};
// RPC surface shared by PS workers, heter workers and switch nodes.
service PsService {
  // Generic control entry; dispatched on PsRequestMessage.cmd_id
  // (e.g. PS_STOP_SERVER, PS_START_PROFILER, PS_STOP_PROFILER).
  rpc service(PsRequestMessage) returns (PsResponseMessage);
  // Variable exchange between pipeline stages, routed by message_name.
  rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
  // Switch -> peer-worker relay of a variable message.
  rpc SendToWorker(MultiVariableMessage) returns (PsResponseMessage);
  // Worker -> switch upload; the switch proxies it onward via SendS2S.
  rpc SendToSwitch(MultiVariableMessage) returns (PsResponseMessage);
  // Switch -> switch transfer; payload is stored in the receiver's shards.
  rpc SendS2S(MultiVariableMessage) returns (PsResponseMessage);
  // Worker pull of data previously buffered on the switch.
  rpc RecvFromSwitch(MultiVariableMessage) returns (MultiVariableMessage);
};
......@@ -300,7 +300,7 @@ if(WITH_DISTRIBUTE)
lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto fleet_executor ${BRPC_DEP})
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
"${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
......@@ -320,7 +320,7 @@ if(WITH_DISTRIBUTE)
index_sampler index_wrapper sampler index_dataset_proto
lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc fleet_executor)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
"${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
......
......@@ -6,9 +6,9 @@ include(operators)
set(DISTRIBUTE_DEPS "")
list(APPEND DISTRIBUTE_DEPS fleet ps_service brpc_utils heter_server heter_client ps_framework_proto framework_proto sendrecv_rpc brpc leveldb ssl crypto protobuf gflags glog zlib snappy device_context)
list(APPEND DISTRIBUTE_DEPS executor fleet ps_service brpc_utils heter_server heter_client ps_framework_proto framework_proto sendrecv_rpc brpc leveldb ssl crypto protobuf gflags glog zlib snappy device_context)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
......@@ -37,3 +37,6 @@ cc_test(send_and_recv_gpu_test SRCS send_and_recv_op_gpu_test.cc DEPS executor s
set_source_files_properties(heter_listen_and_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(heter_listen_and_server_test SRCS heter_listen_and_server_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
#set_source_files_properties(heter_cloud_comm_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(heter_cloud_comm_cpu_test SRCS heter_cloud_comm_cpu_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined PADDLE_WITH_PSCORE
#include <stdlib.h>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/heter_server.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::distributed;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
// Declares the two variables the test exchanges: "w" as a sparse
// SelectedRows parameter and "x" as a dense LoDTensor input.
void CreateVarsOnScope(framework::Scope* scope) {
  scope->Var("w")->GetMutable<phi::SelectedRows>();
  scope->Var("x")->GetMutable<framework::LoDTensor>();
}
// Populates the client-side scope: "w" becomes a rows_numel x 10
// SelectedRows filled with (flat_index / 10), "x" a [1, rows_numel]
// tensor of ones.
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
                         int64_t rows_numel) {
  CreateVarsOnScope(scope);

  // Sparse parameter "w": grow one row per index, then fill values.
  auto* w = scope->Var("w")->GetMutable<phi::SelectedRows>();
  auto* w_value = w->mutable_value();
  w_value->Resize({rows_numel, 10});
  for (int64_t row = 0; row < rows_numel; ++row) {
    w->AutoGrownIndex(row, true);
  }
  auto* w_data = w_value->mutable_data<float>(*place);
  for (int64_t idx = 0; idx < w_value->numel(); ++idx) {
    // Integer division: every 10 consecutive entries share one value.
    w_data[idx] = static_cast<float>(idx / 10);
  }

  // Dense input "x": all ones.
  auto* x_tensor = scope->Var("x")->GetMutable<framework::LoDTensor>();
  float* x_data =
      x_tensor->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
  for (int64_t idx = 0; idx < rows_numel; ++idx) {
    x_data[idx] = 1.0;
  }
}
// Brings up a switch server on endpoints[0] after registering its peer
// endpoints; blocks inside StartHeterService (need_encrypt == false, i.e.
// no SSL). The commented-out section below documents how scope/shard
// save & query handlers could be registered explicitly; as written, those
// operations are served by the SendAndRecvVariableHandler built into
// HeterService instead.
void StartSwitchServer(
    std::shared_ptr<distributed::HeterServer>& switch_server_ptr,  // NOLINT
    std::vector<std::string> endpoints,
    std::vector<std::string> peer_endpoints) {
  switch_server_ptr->SetPeerEndPoints(peer_endpoints);
  switch_server_ptr->SetEndPoint(endpoints[0]);
  /*
  std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
  b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
  switch_server_ptr->SetServiceHandler(b_req_handler);

  switch_server_ptr->SetLocalScope();

  switch_server_ptr->RegisterServiceHandler(
      std::to_string(distributed::PS_SAVE_WITH_SCOPE),
      [&](const MultiVarMsg* request, MultiVarMsg* response,
          brpc::Controller* cntl) -> int {
        return b_req_handler->SaveInSwitchWithScope(request, response, cntl);
      });

  switch_server_ptr->RegisterServiceHandler(std::to_string(distributed::PS_SAVE_WITH_SHARD),
      [&](const MultiVarMsg* request, MultiVarMsg*
      response,
          brpc::Controller* cntl) -> int {
            return b_req_handler->SaveInSwitchWithShard(
                request, response, cntl);
          });

  switch_server_ptr->RegisterServiceHandler(std::to_string(distributed::PS_QUERY_WITH_SCOPE),
      [&](const MultiVarMsg* request, MultiVarMsg*
      response,
          brpc::Controller* cntl) -> int {
            return b_req_handler->QueryInSwitchWithScope(
                request, response, cntl);
          });

  switch_server_ptr->RegisterServiceHandler(std::to_string(distributed::PS_QUERY_WITH_SHARD),
      [&](const MultiVarMsg* request, MultiVarMsg*
      response,
          brpc::Controller* cntl) -> int {
            return b_req_handler->QueryInSwitchWithShard(
                request, response, cntl);
          });
  */
  // Blocks until the server is stopped from another thread.
  switch_server_ptr->StartHeterService(false);
}
// Brings up the inter-switch service on endpoints[1] (the endpoint used for
// switch-to-switch traffic); blocks inside StartHeterInterService with
// encryption disabled.
void StartSwitchInterServer(
    std::shared_ptr<distributed::HeterServer>& switch_server_ptr,  // NOLINT
    std::vector<std::string> endpoints,
    std::vector<std::string> peer_endpoints) {
  switch_server_ptr->SetInterEndpoint(endpoints[1]);
  switch_server_ptr->SetPeerEndPoints(peer_endpoints);
  switch_server_ptr->StartHeterInterService(false);
}
// End-to-end test of worker <-> switch communication on CPU: starts switch
// servers A and B (plus B's inter-switch service), sends a flat float
// buffer through the client's shard-based Send, reads it back with Recv,
// then shuts everything down. The commented-out sections keep the
// scope-based Send/Recv variants for reference.
TEST(HETERSENDANDRECV, CPU) {
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);

  // Start switch servers A & B.
  std::string switch_a_endpoint("127.0.0.1:6000");
  std::string switch_a_endpoint_inter("127.0.0.1:6100");
  std::string switch_b_endpoint_inter("127.0.0.1:7100");
  std::string switch_b_endpoint("127.0.0.1:7000");

  std::shared_ptr<distributed::HeterServer> switch_server_ptr_a =
      std::make_shared<distributed::HeterServer>();
  std::vector<std::string> end_points{switch_a_endpoint};
  // Switch A peers with B's inter endpoint.
  std::vector<std::string> peer_endpoints{switch_b_endpoint_inter};
  std::thread switch_server_a_thread(StartSwitchServer,
                                     std::ref(switch_server_ptr_a), end_points,
                                     peer_endpoints);
  switch_server_ptr_a->WaitServerReady();

  std::shared_ptr<distributed::HeterServer> switch_server_ptr_b =
      std::make_shared<distributed::HeterServer>();
  end_points = {switch_b_endpoint, switch_b_endpoint_inter};
  peer_endpoints = {};
  std::thread switch_server_b_thread(StartSwitchServer,
                                     std::ref(switch_server_ptr_b), end_points,
                                     peer_endpoints);
  switch_server_ptr_b->WaitServerReady();

  // Switch B additionally serves the inter-switch endpoint (endpoints[1]).
  end_points = {switch_b_endpoint, switch_b_endpoint_inter};
  peer_endpoints = {};
  std::thread switch_server_b_thread_inter(StartSwitchInterServer,
                                           std::ref(switch_server_ptr_b),
                                           end_points, peer_endpoints);
  switch_server_ptr_b->WaitServerReady();

  // Obtain the client instance, connected to both switches.
  std::shared_ptr<distributed::HeterClient> heter_client_ptr_ =
      distributed::HeterClient::GetInstance(
          {switch_a_endpoint, switch_b_endpoint}, {}, 0);

  platform::CPUPlace place;
  platform::CPUDeviceContext ctx(place);
  framework::Executor exe(place);

  framework::ProgramDesc program;
  exe.Prepare(program, 0);  // solve undefined symbol: tensor_table.cc
  std::shared_ptr<framework::Scope> send_scope_ptr =
      std::make_shared<framework::Scope>();
  int64_t rows_numel = 10;
  InitTensorsOnClient(send_scope_ptr.get(), &place, rows_numel);
  LOG(INFO) << "InitTensorsOnClient done";

  // Sender runs on its own thread: pushes 6 floats split into two vars
  // ("w": 2 values, "x": 4 values) under group 0.
  auto send_async = [&]() -> void {
    /*
    //std::string message_name =
    std::to_string(distributed::PS_SAVE_WITH_SCOPE);
    std::string message_name = "send and save";
    std::vector<std::string> send_var_names{"w", "x"};
    int ret = heter_client_ptr_->Send(ctx, *send_scope_ptr, message_name,
                                      send_var_names);
    if (!ret) {
      LOG(ERROR) << ">>>> worker send success";
    }
    */
    ///*
    std::vector<int> vars_len{2, 4};
    std::vector<float> values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
    int64_t data_size = 6;
    std::vector<std::string> send_var_names{"w", "x"};
    int group_id = 0;
    int ret = heter_client_ptr_->Send(group_id, send_var_names, vars_len,
                                      values.data(), data_size);
    if (!ret) {
      LOG(INFO) << ">>>> worker send success";
    }
    //*/
  };
  std::thread send_thread(send_async);

  /*
  std::string message_name = std::to_string(distributed::PS_QUERY_WITH_SCOPE);
  std::vector<std::string> recv_var_names{"w", "x"};
  std::shared_ptr<framework::Scope> recv_scope_ptr =
      std::make_shared<framework::Scope>();
  int ret = heter_client_ptr_->Recv(ctx, *recv_scope_ptr, message_name,
                                    recv_var_names);
  if (!ret && recv_scope_ptr->FindVar("w") && recv_scope_ptr->FindVar("x")) {
    LOG(INFO) << ">>>> worker recv success";
  } else {
    LOG(INFO) << "worker recv failed";
  }
  */
  ///*
  // Receiver side: pull the same group's variables back into a flat buffer.
  int group_id = 0;
  std::vector<std::string> recv_var_names{"w", "x"};
  std::vector<float> values;
  int data_size = 6;
  values.resize(data_size);
  int ret = heter_client_ptr_->Recv(group_id, recv_var_names, values.data(),
                                    data_size);
  if (!ret) {
    VLOG(4) << "queried data is: ";
    for (auto f : values) {
      VLOG(4) << f << " ";
    }
    LOG(INFO) << ">>>> worker recv success";
  }
  //*/
  send_thread.join();

  // Orderly shutdown: stop both servers, then join all server threads.
  switch_server_ptr_a->Stop();
  LOG(INFO) << "switch server A stopped";
  switch_server_ptr_b->Stop();
  LOG(INFO) << "switch server B stopped";
  switch_server_a_thread.join();
  LOG(INFO) << "switch_server_a_thread joined";
  switch_server_b_thread.join();
  LOG(INFO) << "switch_server_b_thread joined";
  switch_server_b_thread_inter.join();
  LOG(INFO) << "switch_server_b_thread_inter joined";
}
#endif
......@@ -88,21 +88,20 @@ void HeterListenAndServOp::RunAsyncLoop(framework::ProgramDesc *program) const {
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
}
for (size_t i = 0; i < block_list.size(); ++i) {
auto blkid = block_list[i];
auto it = message_to_block_id.find_value(blkid);
rpc_service_->RegisterServiceHandler(
heter_server_->RegisterServiceHandler(
it->first, [&](const MultiVarMsg *request, MultiVarMsg *response,
brpc::Controller *cntl) -> int {
return request_send_and_recv_handler_->Handle(request, response,
cntl);
return send_and_recv_variable_handler_->Handle(request, response,
cntl);
});
}
while (true) {
if (rpc_service_->IsExit() || rpc_service_->IsStop()) {
rpc_service_->Stop();
if (heter_server_->IsExit() || heter_server_->IsStop()) {
heter_server_->Stop();
VLOG(0) << "get exit. rpc_processor stop!";
break;
}
......@@ -110,8 +109,9 @@ void HeterListenAndServOp::RunAsyncLoop(framework::ProgramDesc *program) const {
} // while(true)
}
void RunServer(std::shared_ptr<paddle::distributed::HeterServer> service) {
service->StartHeterService();
void RunServer(
std::shared_ptr<paddle::distributed::HeterServer> heter_server_ptr) {
heter_server_ptr->StartHeterService();
}
void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
......@@ -126,16 +126,16 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
auto fan_in = Attr<int>("fanin");
auto inputs = Inputs("X");
PADDLE_ENFORCE_EQ(rpc_service_, nullptr,
PADDLE_ENFORCE_EQ(heter_server_, nullptr,
platform::errors::PreconditionNotMet(
"RPC service has been created unexpectedly."));
std::string endpoint = Attr<std::string>("endpoint");
VLOG(4) << "pserver_id: " << pserver_id << ", end_point:" << endpoint;
rpc_service_ = distributed::HeterServer::GetInstance();
rpc_service_->SetEndPoint(endpoint);
rpc_service_->SetFanin(fan_in);
heter_server_ = distributed::HeterServer::GetInstance();
heter_server_->SetEndPoint(endpoint);
heter_server_->SetFanin(fan_in);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>("optimize_blocks");
......@@ -146,20 +146,18 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope,
auto *program = optimize_blocks[0]->Program();
request_send_and_recv_handler_.reset(
new distributed::RequestSendAndRecvHandler());
request_send_and_recv_handler_->SetScope(&scope);
request_send_and_recv_handler_->SetDevCtx(&dev_ctx);
rpc_service_->SetRequestHandler(request_send_and_recv_handler_);
send_and_recv_variable_handler_.reset(
new distributed::SendAndRecvVariableHandler());
send_and_recv_variable_handler_->SetScope(&scope);
send_and_recv_variable_handler_->SetDevCtx(&dev_ctx);
heter_server_->SetServiceHandler(send_and_recv_variable_handler_);
VLOG(2) << "RunAsyncLoop";
auto message_to_block_id_str =
Attr<std::vector<std::string>>("message_to_block_id");
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
server_thread_.reset(new std::thread(RunServer, heter_server_));
VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
heter_server_->WaitServerReady();
RunAsyncLoop(program);
VLOG(3) << "Wait for Server_thread_ stop";
(server_thread_.get())->join();
......
......@@ -34,7 +34,7 @@ limitations under the License. */
namespace paddle {
namespace distributed {
class HeterRequestHandler;
class ServiceHandlerBase;
class HeterServer;
} // namespace distributed
} // namespace paddle
......@@ -82,10 +82,10 @@ class HeterListenAndServOp : public framework::OperatorBase {
const platform::Place& dev_place) const override;
protected:
mutable std::shared_ptr<paddle::distributed::HeterServer> rpc_service_;
mutable std::shared_ptr<paddle::distributed::HeterServer> heter_server_;
mutable std::shared_ptr<std::thread> server_thread_;
mutable std::shared_ptr<paddle::distributed::RequestSendAndRecvHandler>
request_send_and_recv_handler_;
mutable std::shared_ptr<paddle::distributed::SendAndRecvVariableHandler>
send_and_recv_variable_handler_;
};
} // namespace operators
......
......@@ -142,7 +142,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
CreateVarsOnScope(scope, place);
}
void StartHeterServer(std::string endpoint) {
void RunHeterServerOp(std::string endpoint) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
......@@ -167,10 +167,10 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer";
FLAGS_eager_delete_tensor_gb = -1;
std::thread server_thread(StartHeterServer, endpoint);
std::thread server_thread(RunHeterServerOp, endpoint);
sleep(1);
auto b_rpc_service = distributed::HeterServer::GetInstance();
b_rpc_service->WaitServerReady();
auto heter_server_ptr_ = distributed::HeterServer::GetInstance();
heter_server_ptr_->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
using MiniScope = std::unordered_map<int, framework::Scope*>;
......@@ -185,8 +185,8 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
(*micro_scope).push_back(micro_scope_0);
(*micro_scope).push_back(micro_scope_1);
(*micro_scopes)[0] = micro_scope;
b_rpc_service->SetMicroBatchScopes(micro_scopes);
b_rpc_service->SetMiniBatchScopes(mini_scopes);
heter_server_ptr_->SetMicroBatchScopes(micro_scopes);
heter_server_ptr_->SetMiniBatchScopes(mini_scopes);
using TaskQueue =
std::unordered_map<int,
......@@ -198,17 +198,13 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
SharedTaskQueue task_queue_(new TaskQueue{});
(*task_queue_)[0] = std::make_shared<
::paddle::framework::BlockingQueue<std::pair<std::string, int>>>();
b_rpc_service->SetTaskQueue(task_queue_);
heter_server_ptr_->SetTaskQueue(task_queue_);
LOG(INFO) << "before HeterClient::GetInstance";
distributed::HeterClient* rpc_client =
distributed::HeterClient* heter_client_ptr_ =
distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
.get();
PADDLE_ENFORCE_NE(rpc_client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
framework::Scope* scope = (*micro_scope)[0];
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
......@@ -224,8 +220,8 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
std::vector<std::string> recv_var = {};
LOG(INFO) << "before SendAndRecvAsync";
rpc_client->SendAndRecvAsync(ctx, *scope, in_var_name, send_var, recv_var,
"forward");
heter_client_ptr_->SendAndRecvAsync(ctx, *scope, in_var_name, send_var,
recv_var, "forward");
auto task = (*task_queue_)[0]->Pop();
PADDLE_ENFORCE_EQ(
task.first, "x",
......@@ -234,15 +230,15 @@ TEST(HETER_LISTEN_AND_SERV, CPU) {
InitTensorsOnClient2((*micro_scope)[1], &place, rows_numel);
LOG(INFO) << "before SendAndRecvAsync 2";
rpc_client->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name, send_var,
recv_var, "backward");
heter_client_ptr_->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name,
send_var, recv_var, "backward");
auto task2 = (*task_queue_)[0]->Pop();
PADDLE_ENFORCE_EQ(
task2.first, "x",
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->Stop();
heter_client_ptr_->Stop();
LOG(INFO) << "end server Stop";
server_thread.join();
LOG(INFO) << "end server thread join";
......
......@@ -34,8 +34,6 @@ using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP_ITSELF(scale);
std::shared_ptr<distributed::HeterServer> b_rpc_service;
std::string get_ip_port() {
std::mt19937 rng;
rng.seed(std::random_device()());
......@@ -171,31 +169,32 @@ void StartSendAndRecvServer(std::string endpoint) {
InitTensorsOnServer(&scope, &place, 10);
LOG(INFO) << "end InitTensorsOnServer";
std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
LOG(INFO) << "before SetDevCtx";
b_req_handler->SetDevCtx(&ctx);
LOG(INFO) << "before SetScope";
b_req_handler->SetScope(&scope);
LOG(INFO) << "before HeterServer::GetInstance";
b_rpc_service = distributed::HeterServer::GetInstance();
b_rpc_service->SetEndPoint(endpoint);
std::shared_ptr<distributed::HeterServer> heter_server_ptr_ =
distributed::HeterServer::GetInstance();
heter_server_ptr_->SetEndPoint(endpoint);
LOG(INFO) << "before HeterServer::RegisterServiceHandler";
b_rpc_service->RegisterServiceHandler(
heter_server_ptr_->RegisterServiceHandler(
in_var_name, [&](const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) -> int {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service->RegisterServiceHandler(
heter_server_ptr_->RegisterServiceHandler(
in_var_name2, [&](const MultiVarMsg* request, MultiVarMsg* response,
brpc::Controller* cntl) -> int {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service->SetRequestHandler(b_req_handler);
heter_server_ptr_->SetServiceHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
RunServer(b_rpc_service);
// std::thread server_thread(std::bind(RunServer, b_rpc_service));
RunServer(heter_server_ptr_);
// std::thread server_thread(std::bind(RunServer, heter_server_ptr_));
// server_thread.join();
}
......@@ -206,9 +205,10 @@ TEST(SENDANDRECV, CPU) {
std::string endpoint = get_ip_port();
std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service = distributed::HeterServer::GetInstance();
std::shared_ptr<distributed::HeterServer> heter_server_ptr_ =
distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint);
b_rpc_service->WaitServerReady();
heter_server_ptr_->WaitServerReady();
using MicroScope =
std::unordered_map<int, std::shared_ptr<std::vector<framework::Scope*>>>;
using MiniScope = std::unordered_map<int, framework::Scope*>;
......@@ -223,8 +223,8 @@ TEST(SENDANDRECV, CPU) {
(*micro_scope).push_back(micro_scope_0);
(*micro_scope).push_back(micro_scope_1);
(*micro_scopes)[0] = micro_scope;
b_rpc_service->SetMicroBatchScopes(micro_scopes);
b_rpc_service->SetMiniBatchScopes(mini_scopes);
heter_server_ptr_->SetMicroBatchScopes(micro_scopes);
heter_server_ptr_->SetMiniBatchScopes(mini_scopes);
using TaskQueue =
std::unordered_map<int,
......@@ -236,17 +236,13 @@ TEST(SENDANDRECV, CPU) {
SharedTaskQueue task_queue_(new TaskQueue{});
(*task_queue_)[0] = std::make_shared<
::paddle::framework::BlockingQueue<std::pair<std::string, int>>>();
b_rpc_service->SetTaskQueue(task_queue_);
heter_server_ptr_->SetTaskQueue(task_queue_);
LOG(INFO) << "before HeterClient::GetInstance";
distributed::HeterClient* rpc_client =
distributed::HeterClient* heter_client_ptr_ =
distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
.get();
PADDLE_ENFORCE_NE(rpc_client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
framework::Scope* scope = (*micro_scope)[0];
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
......@@ -262,8 +258,8 @@ TEST(SENDANDRECV, CPU) {
std::vector<std::string> recv_var = {};
LOG(INFO) << "before SendAndRecvAsync";
rpc_client->SendAndRecvAsync(ctx, *scope, in_var_name, send_var, recv_var,
"forward");
heter_client_ptr_->SendAndRecvAsync(ctx, *scope, in_var_name, send_var,
recv_var, "forward");
LOG(INFO) << "client wait for Pop";
auto task = (*task_queue_)[0]->Pop();
......@@ -276,8 +272,8 @@ TEST(SENDANDRECV, CPU) {
InitTensorsOnClient2((*micro_scope)[1], &place, rows_numel);
LOG(INFO) << "before SendAndRecvAsync 2";
std::string in_var_name2("y");
rpc_client->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name2,
send_var, recv_var, "backward");
heter_client_ptr_->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name2,
send_var, recv_var, "backward");
LOG(INFO) << "after SendAndRecvAsync 2";
auto task2 = (*task_queue_)[0]->Pop();
......@@ -286,8 +282,7 @@ TEST(SENDANDRECV, CPU) {
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->FinalizeWorker();
b_rpc_service->Stop();
heter_server_ptr_->Stop();
LOG(INFO) << "end server Stop";
server_thread.join();
LOG(INFO) << "end server thread join";
......
......@@ -36,8 +36,6 @@ using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP_ITSELF(scale);
USE_OP(send_and_recv);
std::shared_ptr<distributed::HeterServer> b_rpc_service;
std::string get_ip_port() {
std::mt19937 rng;
rng.seed(std::random_device()());
......@@ -148,14 +146,15 @@ void StartSendAndRecvServer(std::string endpoint) {
InitTensorsOnServer(&scope, &place, 10);
LOG(INFO) << "end InitTensorsOnServer";
std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
LOG(INFO) << "before SetDevCtx";
b_req_handler->SetDevCtx(&ctx);
LOG(INFO) << "before SetScope";
b_req_handler->SetScope(&scope);
LOG(INFO) << "before HeterServer::GetInstance";
b_rpc_service = distributed::HeterServer::GetInstance();
std::shared_ptr<distributed::HeterServer> b_rpc_service =
distributed::HeterServer::GetInstance();
b_rpc_service->SetEndPoint(endpoint);
LOG(INFO) << "before HeterServer::RegisterServiceHandler";
b_rpc_service->RegisterServiceHandler(
......@@ -164,7 +163,7 @@ void StartSendAndRecvServer(std::string endpoint) {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service->SetRequestHandler(b_req_handler);
b_rpc_service->SetServiceHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
RunServer(b_rpc_service);
......@@ -179,7 +178,8 @@ TEST(SENDANDRECV, CPU) {
std::string endpoint = get_ip_port();
std::string previous_endpoint = endpoint;
LOG(INFO) << "before StartSendAndRecvServer";
b_rpc_service = distributed::HeterServer::GetInstance();
std::shared_ptr<distributed::HeterServer> b_rpc_service =
distributed::HeterServer::GetInstance();
std::thread server_thread(StartSendAndRecvServer, endpoint);
b_rpc_service->WaitServerReady();
using MicroScope =
......@@ -292,7 +292,6 @@ TEST(SENDANDRECV, CPU) {
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->FinalizeWorker();
b_rpc_service->Stop();
LOG(INFO) << "end server Stop";
server_thread.join();
......
......@@ -167,8 +167,8 @@ void StartSendAndRecvServer(std::string endpoint) {
InitTensorsOnServer(&scope, &place, 10);
LOG(INFO) << "end InitTensorsOnServer";
std::shared_ptr<distributed::RequestSendAndRecvHandler> b_req_handler;
b_req_handler.reset(new distributed::RequestSendAndRecvHandler());
std::shared_ptr<distributed::SendAndRecvVariableHandler> b_req_handler;
b_req_handler.reset(new distributed::SendAndRecvVariableHandler());
LOG(INFO) << "before SetDevCtx";
b_req_handler->SetDevCtx(&ctx);
LOG(INFO) << "before SetScope";
......@@ -183,7 +183,7 @@ void StartSendAndRecvServer(std::string endpoint) {
return b_req_handler->Handle(request, response, cntl);
});
b_rpc_service2->SetRequestHandler(b_req_handler);
b_rpc_service2->SetServiceHandler(b_req_handler);
LOG(INFO) << "before HeterServer::RunServer";
RunServer(b_rpc_service2);
......@@ -228,13 +228,11 @@ TEST(SENDANDRECV, GPU) {
b_rpc_service2->SetTaskQueue(task_queue_);
LOG(INFO) << "before HeterClient::GetInstance";
distributed::HeterClient* rpc_client =
distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0)
.get();
PADDLE_ENFORCE_NE(rpc_client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
std::shared_ptr<distributed::HeterClient> heter_client_ptr_ =
distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0);
if (heter_client_ptr_ == nullptr) {
LOG(ERROR) << "heter_client_ptr_ is null";
}
framework::Scope* scope = (*micro_scope)[0];
platform::CUDAPlace place;
......@@ -316,7 +314,6 @@ TEST(SENDANDRECV, GPU) {
platform::errors::InvalidArgument(
"Recv message and Send message name not match, Check your Code"));
rpc_client->FinalizeWorker();
b_rpc_service2->Stop();
LOG(INFO) << "end server Stop";
server_thread.join();
......
......@@ -1174,6 +1174,7 @@ SIXTH_PARALLEL_JOB_NEW = [
]
LOWEST_PARALLEL_JOB_NEW = [
'heter_cloud_comm_cpu_test',
'heter_server_test',
'test_scatter_op',
'test_trt_convert_hard_sigmoid',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册