Unverified commit 4bc22b69 authored by ziyoujiyi, committed by GitHub

add horizontal federated learning ps feature (#44327)

* back fl

* delete ssl cert

* .

* make warning

* .

* unittest parallel degree

* solve unittest

* heter & multi cloud comm ready

* .

* .

* fl-ps v1.0

* .

* support N + N mode

* .

* .

* .

* .

* delete print

* .

* .

* .

* .

* fix bug

* .

* .

* fl-ps with coordinator ready

* merge dev

* update message parse only

* update fl client scheduler

* fix bug

* update multithreads sync

* fix ci errors

* update role_maker.py

* update role_maker.py

* fix ci error: windows py import error

* fix ci error: windows py import error

* fix windows ci pylib import error

* add dump fields & params

* try to fix windows import fleet error

* fix ps FLAGS error
Parent commit: e3ee5103
......@@ -47,7 +47,6 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
# TODO(gongwb): change to the newest repo when they change
GIT_REPOSITORY "https://github.com/wangjiawei04/brpc"
#GIT_REPOSITORY "https://github.com/ziyoujiyi/brpc" # ssl error in the previous repo (can be manually fixed)
GIT_TAG "e203afb794caf027da0f1e0776443e7d20c0c28e"
PREFIX ${BRPC_PREFIX_DIR}
UPDATE_COMMAND ""
......
......@@ -78,6 +78,10 @@ set_source_files_properties(
graph_brpc_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
graph_brpc_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
coordinator_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
brpc_utils
SRCS brpc_utils.cc
......@@ -90,6 +94,7 @@ cc_library(
cc_library(
downpour_client
SRCS graph_brpc_client.cc brpc_ps_client.cc ps_local_client.cc
coordinator_client.cc
DEPS eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})
cc_library(
......
......@@ -18,10 +18,22 @@
#include <sstream>
#include <string>
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/split.h"
static const int max_port = 65535;
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
DEFINE_int32(pserver_push_dense_merge_limit,
12,
"limit max push_dense local merge requests");
......@@ -66,16 +78,6 @@ DEFINE_int32(pserver_sparse_table_shard_num,
1000,
"sparse table shard for save & load");
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
inline size_t get_sparse_shard(uint32_t shard_num,
uint32_t server_num,
uint64_t key) {
......@@ -101,7 +103,7 @@ void DownpourPsClientService::service(
}
}
// Start the client-side RpcService, used for client-to-client data exchange and similar operations
int32_t BrpcPsClient::StartClientService() {
if (_service.Configure(this, _client_id) != 0) {
LOG(ERROR)
......@@ -122,6 +124,35 @@ int32_t BrpcPsClient::StartClientService() {
_server_started = true;
_env->RegistePsClient(
butil::my_ip_cstr(), _server.listen_address().port, _client_id);
VLOG(0) << "BrpcPsClient Service addr: " << butil::my_ip_cstr() << ", "
<< _server.listen_address().port << ", " << _client_id;
return 0;
}
// Start FlClientService, used to receive data from the coordinator
int32_t BrpcPsClient::StartFlClientService(const std::string &self_endpoint) {
_fl_server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
if (self_endpoint.empty()) {
LOG(ERROR) << "fl-ps > fl client endpoint not set";
return -1;
}
if (_fl_server.Start(self_endpoint.c_str(), &options) != 0) {
VLOG(0) << "fl-ps > StartFlClientService failed. Try again.";
auto ip_port = paddle::string::Split(self_endpoint, ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (_fl_server.Start(int_ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "fl-ps > StartFlClientService failed, ip_port= "
<< int_ip_port;
return -1;
}
} else {
VLOG(0) << "fl-ps > StartFlClientService succeed! listen on "
<< self_endpoint;
}
return 0;
}
......@@ -166,6 +197,96 @@ int32_t BrpcPsClient::CreateClient2ClientConnection(
return 0;
}
int32_t BrpcPsClient::InitializeFlWorker(const std::string &self_endpoint) {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = FLAGS_pserver_timeout_ms;
options.connection_type = "pooled";
options.connect_timeout_ms =
paddle::distributed::FLAGS_pserver_connect_timeout_ms;
options.max_retry = 3;
// Get the coordinator list and connect to each coordinator
std::string coordinator_ip_port;
std::vector<PSHost> coordinator_list = _env->GetCoordinators();
_coordinator_channels.resize(coordinator_list.size());
for (size_t i = 0; i < coordinator_list.size(); ++i) {
coordinator_ip_port.assign(coordinator_list[i].ip.c_str());
coordinator_ip_port.append(":");
coordinator_ip_port.append(std::to_string(coordinator_list[i].port));
VLOG(0) << "fl-ps > BrpcFlclient connetcting to coordinator: "
<< coordinator_ip_port;
for (size_t j = 0; j < _coordinator_channels[i].size(); ++j) {
_coordinator_channels[i][j].reset(new brpc::Channel());
if (_coordinator_channels[i][j]->Init(
coordinator_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:"
<< coordinator_ip_port << " Failed! Try again.";
std::string int_ip_port = GetIntTypeEndpoint(coordinator_list[i].ip,
coordinator_list[i].port);
if (_coordinator_channels[i][j]->Init(
int_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:"
<< int_ip_port << " Failed!";
return -1;
}
}
}
}
StartFlClientService(self_endpoint);
VLOG(0) << "fl-ps > InitializeFlWorker finished!";
return 0;
}
void BrpcPsClient::PushFLClientInfoSync(const std::string &fl_client_info) {
size_t request_call_num = _coordinator_channels.size();
FlClientBrpcClosure *closure =
new FlClientBrpcClosure(request_call_num, [request_call_num](void *done) {
auto *closure = reinterpret_cast<FlClientBrpcClosure *>(done);
int ret = 0;
for (size_t i = 0; i < request_call_num; i++) {
if (closure->check_response(i, PUSH_FL_CLIENT_INFO_SYNC) != 0) {
LOG(ERROR) << "fl-ps > PushFLClientInfoSync response from "
"coordinator is failed";
ret = -1;
return;
} else {
VLOG(0) << "fl-ps > rpc service call cost time: "
<< (closure->cntl(i)->latency_us() / 1000) << " ms";
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int32_t> fut = promise->get_future();
closure->add_promise(promise);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PUSH_FL_CLIENT_INFO_SYNC);
closure->request(i)->set_client_id(_client_id);
closure->request(i)->set_str_params(fl_client_info);
brpc::Channel *rpc_channel = _coordinator_channels[0][0].get();
if (rpc_channel == nullptr) {
LOG(ERROR) << "_coordinator_channels is null";
return;
}
PsService_Stub rpc_stub(rpc_channel); // CoordinatorService
rpc_stub.FLService(
closure->cntl(i), closure->request(i), closure->response(i), closure);
fut.wait();
}
VLOG(0) << "fl-ps > PushFLClientInfoSync finished, client id: " << _client_id;
return;
}
std::string BrpcPsClient::PullFlStrategy() {
while (!_service._is_fl_strategy_ready) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
VLOG(0) << "fl-ps > waiting for fl strategy returned from coordinator";
}
_service._is_fl_strategy_ready =
false; // only support single thread, no need for multi-threads
return _service._fl_strategy;
}
int32_t BrpcPsClient::Initialize() {
_async_call_num = 0;
......@@ -300,6 +421,24 @@ std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
return data;
}
int FlClientBrpcClosure::check_response(size_t request_idx, int cmd_id) {
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id
<< " failed, "
"err:"
<< _cntls[request_idx]->ErrorText();
return -1;
}
if (_responses[request_idx].err_code() != 0) {
LOG(ERROR) << "response ret bad, server_idx:" << request_idx
<< "cmd_id:" << cmd_id
<< " err_code:" << _responses[request_idx].err_code()
<< " err_msg:" << _responses[request_idx].err_msg();
return -1;
}
return 0;
}
std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
......
......@@ -25,6 +25,7 @@
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
......@@ -56,16 +57,71 @@ class DownpourPsClientService : public PsService {
_rank = rank_id;
return 0;
}
void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done);
virtual void FLService(::google::protobuf::RpcController *controller,
const CoordinatorReqMessage *request,
CoordinatorResMessage *response,
::google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
size_t client_id = request->client_id();
CHECK(_client->_client_id == client_id)
<< "request client id not matched self";
_fl_strategy = request->str_params();
_is_fl_strategy_ready = true;
response->set_err_code(0);
response->set_err_msg("");
VLOG(0) << "fl-ps > DownpourPsClientService::FLService finished!";
return;
}
public:
std::string _fl_strategy;
bool _is_fl_strategy_ready = false;
protected:
size_t _rank;
PSClient *_client;
};
class FlClientBrpcClosure : public PSClientClosure {
public:
FlClientBrpcClosure(size_t num, PSClientCallBack callback)
: PSClientClosure(callback) {
_waiting_num = num;
_cntls.resize(num);
_requests.resize(num);
_responses.resize(num);
for (size_t i = 0; i < num; ++i) {
_cntls[i].reset(new brpc::Controller());
}
}
virtual ~FlClientBrpcClosure() {}
void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
}
}
CoordinatorReqMessage *request(size_t i) { return &_requests[i]; }
CoordinatorResMessage *response(size_t i) { return &_responses[i]; }
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
int check_response(size_t request_idx, int cmd_id);
int check_save_response(size_t request_idx, int cmd_id);
std::string get_response(size_t request_idx, int cmd_id);
private:
std::atomic<int32_t> _waiting_num;
std::vector<CoordinatorReqMessage> _requests;
std::vector<CoordinatorResMessage> _responses;
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
class DownpourBrpcClosure : public PSClientClosure {
public:
DownpourBrpcClosure(size_t num, PSClientCallBack callback)
......@@ -267,6 +323,14 @@ class BrpcPsClient : public PSClient {
}
int32_t Initialize() override;
// for fl
public:
virtual int32_t InitializeFlWorker(const std::string &self_endpoint);
int32_t StartFlClientService(const std::string &self_endpoint);
virtual void PushFLClientInfoSync(const std::string &fl_client_info);
std::string PullFlStrategy();
// for fl
private:
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
......@@ -320,6 +384,8 @@ class BrpcPsClient : public PSClient {
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
std::vector<std::array<std::shared_ptr<brpc::Channel>, 1>>
_coordinator_channels; // client2coordinator
std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
......@@ -360,6 +426,7 @@ class BrpcPsClient : public PSClient {
float _mse = 0;
uint16_t _push_times = 0;
brpc::Server _server;
brpc::Server _fl_server;
DownpourPsClientService _service;
bool _server_started = false;
std::atomic_uint grad_num_{0};
......
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
set_source_files_properties(
communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
......@@ -89,7 +89,7 @@ int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
int table_id,
Scope *scope) {
Scope *scope) { // pserver_scope_
platform::RecordEvent record_event("Communicator->RpcRecvDense",
platform::TracerEventType::Communication,
1);
......@@ -106,7 +106,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
paddle::distributed::Region reg(temp_data, tensor->numel());
regions.emplace_back(std::move(reg));
VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id "
<< table_id << " Temp_data[0] " << temp_data[0]
<< " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
......@@ -123,11 +123,11 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
for (auto &t : varnames) {
Variable *var = scope->FindVar(t);
LoDTensor *tensor = var->GetMutable<LoDTensor>();
VLOG(3) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
VLOG(3) << "Communicator::RecvNoBarrier Var " << t << " On gpu? "
<< platform::is_gpu_place(tensor->place());
float *temp_recv_data = tensor->mutable_data<float>(platform::CPUPlace());
VLOG(3) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
VLOG(3) << "Communicator::RpcRecvDense Var " << t << " table_id "
<< table_id << " Temp_data[0] " << temp_recv_data[0]
<< " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
if (platform::is_gpu_place(tensor->place())) {
......@@ -136,7 +136,7 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id "
<< table_id << " Temp_data[0] " << temp_data[0]
<< " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
......@@ -187,7 +187,8 @@ void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
return;
}
void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
void Communicator::RpcSendDense(const CommContext &ctx,
const Scope &scope) { // delta_scope_
platform::RecordEvent record_event("Communicator->RpcSendDense",
platform::TracerEventType::Communication,
1);
......@@ -343,21 +344,21 @@ void Communicator::RpcRecvSparse(const std::string &varname,
auto dim = tensor->dims()[1];
uint64_t sparse_num = static_cast<uint64_t>(tensor->dims()[0]);
std::vector<uint64_t> sparse_push_keys(sparse_num);
std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0);
std::vector<uint64_t> sparse_pull_keys(sparse_num);
std::iota(sparse_pull_keys.begin(), sparse_pull_keys.end(), 0);
std::vector<float *> push_g_vec;
for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
push_g_vec.push_back(tensor->data<float>() + i * dim);
std::vector<float *> pull_g_vec;
for (auto i = 0; i < static_cast<int>(sparse_pull_keys.size()); ++i) {
pull_g_vec.push_back(tensor->data<float>() + i * dim);
}
bool training = true;
auto status =
_worker_ptr->PullSparseParam(static_cast<float **>(push_g_vec.data()),
_worker_ptr->PullSparseParam(static_cast<float **>(pull_g_vec.data()),
table_id,
sparse_push_keys.data(),
sparse_push_keys.size(),
sparse_pull_keys.data(),
sparse_pull_keys.size(),
training);
status.wait();
return;
......@@ -1013,8 +1014,9 @@ void SyncCommunicator::BarrierRecv() {
VLOG(4) << "BarrierRecv with SyncCommunicator";
}
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) {
void GeoCommunicator::Send(
const std::vector<std::string> &var_names,
const framework::Scope &scope) { // last op in program
platform::RecordEvent record_event(
"GeoCommunicator->Send", platform::TracerEventType::Communication, 1);
waiting_ = false;
......@@ -1041,10 +1043,13 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
auto &rows = var->Get<phi::SelectedRows>().rows();
// insert ids which has not been record
for (size_t j = 0; j < rows.size(); j++) {
// VLOG(0) << "fl-ps > table_name: " << table_name << " splited_var_nums: " <<
// splited_var_nums << " rows size: " << rows.size();
for (size_t j = 0; j < rows.size(); j++) { // batch_size == rows.size()
auto ep_idx = rows[j] % splited_var_nums;
ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
.insert(rows[j]);
// VLOG(0) << " id: " << rows[j] << " ";
}
for (auto &iter : ids_table) {
......@@ -1143,7 +1148,7 @@ void GeoCommunicator::InitDense(std::vector<std::string> &varnames,
} else {
BarrierWithTable(1);
RpcRecvDense(varnames, table_id, recv_scope_);
VLOG(1) << "pull dense param to table " << table_id
VLOG(1) << "pull dense param from table " << table_id
<< " from 0' trainer done";
}
......@@ -1153,7 +1158,7 @@ void GeoCommunicator::InitDense(std::vector<std::string> &varnames,
global_var->GetMutable<framework::LoDTensor>();
auto *old_var = old_scope_->Var(t);
old_var->GetMutable<framework::LoDTensor>();
framework::CopyVariable(*global_var, old_var);
framework::CopyVariable(*global_var, old_var); // src, dst
// init pserver_scope_
auto *pserver_var = pserver_scope_->Var(t);
pserver_var->GetMutable<framework::LoDTensor>();
......@@ -1218,7 +1223,7 @@ void GeoCommunicator::RecvDense(const CommContext &send_ctx) {
// 1. recv from pserver
RpcRecvDense(varnames, table_id, pserver_scope_.get());
// 2.1 pserver - old => delta; 2.2 latest + old => latest 2.3 old => pserver
// 2.1 pserver - old => delta; 2.2 latest + delta => latest 2.3 old => pserver
phi::CPUContext cpu_ctx;
for (auto &varname : varnames) {
auto *var_latest = recv_scope_->FindVar(varname);
......@@ -1267,7 +1272,7 @@ void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) {
VLOG(1) << "Init Sparse " << var_name << " : table " << table_id << " done.";
auto *global_var = recv_scope_->FindVar(var_name);
auto *var = old_scope_->Var(var_name);
framework::CopyVariable(*global_var, var);
framework::CopyVariable(*global_var, var); // src, dst
return;
}
......@@ -1278,7 +1283,8 @@ std::vector<int64_t> GeoCommunicator::MergeSparseIds(
1);
size_t merge_num = 0, wait_times = 0;
std::unordered_set<int64_t> sparse_ids;
while (merge_num < static_cast<size_t>(max_merge_var_num_)) {
while (merge_num <
static_cast<size_t>(max_merge_var_num_)) { // -> geo_step: 100
VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
if (sparse_id_queues_.at(send_varname)->Size() > 0) {
wait_times = 0;
......@@ -1467,7 +1473,9 @@ void GeoCommunicator::MainThread() {
for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
// varname: emb@GRAD, param_name: emb, splited_varname: emb.delta0
auto send_recv_task = [this, table_id, ep_idx, &ctx] {
auto splited_varname = ctx.splited_varnames[ep_idx];
auto splited_varname =
ctx.splited_varnames[ep_idx]; // embedding_0.w_0.block0
// embedding_1.w_0.block0
auto sparse_ids = MergeSparseIds(splited_varname);
SendSparse(splited_varname, sparse_ids, table_id, ep_idx);
RecvSparse(splited_varname, table_id, ep_idx);
......@@ -1490,5 +1498,83 @@ void GeoCommunicator::MainThread() {
}
}
void FLCommunicator::InitBrpcClient(
const std::string &dist_desc,
const std::vector<std::string> &host_sign_list) {
auto fleet = paddle::distributed::FleetWrapper::GetInstance();
if (_worker_ptr.get() == nullptr) {
VLOG(0) << "fl-ps > FLCommunicator::InitBrpcClient get _worker_ptr";
_worker_ptr =
fleet->worker_ptr_; // FleetWrapper::InitWorker must be executed before,
// but no need for Coordinator
}
if (coordinator_client_ptr_ == nullptr) {
coordinator_client_ptr_.reset(new CoordinatorClient);
}
int16_t servers = host_sign_list.size();
coordinator_client_ptr_->_env = &ps_env_;
coordinator_client_ptr_->_env->SetPsServers(&host_sign_list, servers);
}
void FLCommunicator::StartCoordinatorClient(
const std::vector<std::string> &trainer_endpoints) {
if (coordinator_client_ptr_ == nullptr) {
LOG(ERROR) << "coordinator_client_ptr_ is null";
return;
}
coordinator_client_ptr_->Initialize(trainer_endpoints);
VLOG(0) << "fl-ps > StartCoordinatorClient finish!";
}
void FLCommunicator::StartCoordinatorServer() {
if (coordinator_client_ptr_ == nullptr) {
LOG(ERROR) << "coordinator_client_ptr_ is null";
}
int ret = coordinator_client_ptr_->StartClientService();
if (ret != 0) {
LOG(ERROR) << "coordinator_client_ptr_ StartClientService failed";
}
VLOG(0) << "fl-ps > StartCoordinatorServer finished!";
return;
}
std::unordered_map<uint32_t, std::string> FLCommunicator::QueryFLClientsInfo() {
return coordinator_client_ptr_->QueryFLClientsInfo();
}
void FLCommunicator::SaveFLStrategy(
const std::unordered_map<uint32_t, std::string> &fl_strategy) {
coordinator_client_ptr_->SaveFLStrategy(fl_strategy);
return;
}
void FLCommunicator::SendThreadAsync() {
while (is_running_) {
RpcSendFLStrategy();
}
return;
}
void FLCommunicator::RpcSendFLStrategy() {
std::set<uint32_t> clients = coordinator_client_ptr_->GetFLClientIds();
coordinator_client_ptr_->WaitForFLStrategyReady();
for (auto client_id : clients) {
coordinator_client_ptr_->SendFLStrategy(client_id);
}
coordinator_client_ptr_->ResetFLStrategyFlag();
VLOG(0) << "fl-ps > RpcSendFLStrategy finished!";
return;
}
void FLCommunicator::StartCoordinator(
const std::string &self_endpoint,
const std::vector<std::string> &trainer_endpoints) {
coordinator_client_ptr_->SetEndpoint(self_endpoint);
StartCoordinatorClient(trainer_endpoints);
StartCoordinatorServer();
async_send_thread_.reset(
new std::thread(&FLCommunicator::SendThreadAsync, this));
}
} // namespace distributed
} // namespace paddle
......@@ -31,6 +31,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
......@@ -241,9 +242,11 @@ class Communicator {
envs[iter.first] = iter.second;
VLOG(3) << iter.first << ": " << iter.second;
}
barrier_table_id_ = std::stoi(envs.at("barrier_table_id"));
trainer_id_ = std::stoi(envs.at("trainer_id"));
trainers_ = std::stoi(envs.at("trainers"));
if (!envs.empty()) {
barrier_table_id_ = std::stoi(envs.at("barrier_table_id"));
trainer_id_ = std::stoi(envs.at("trainer_id"));
trainers_ = std::stoi(envs.at("trainers"));
}
}
virtual void InitBrpcClient(const std::string &dist_desc,
......@@ -280,6 +283,15 @@ class Communicator {
int batches,
Scope *send_scope);
virtual std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
return {};
}
virtual void SaveFLStrategy(
const std::unordered_map<uint32_t, std::string> &fl_strategy) {}
virtual void StartCoordinator(
const std::string &self_endpoint,
const std::vector<std::string> &trainer_endpoints) {}
virtual ~Communicator() {}
virtual void RpcProfilerControl();
......@@ -376,10 +388,6 @@ class Communicator {
PSClient *GetPsClient() { return _worker_ptr.get(); }
std::shared_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
return std::move(_worker_ptr);
}
RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }
std::shared_ptr<PSClient> _worker_ptr; // pointer to worker
......@@ -657,5 +665,50 @@ class GeoCommunicator : public AsyncCommunicator {
sparse_id_queues_;
};
class FLCommunicator : public GeoCommunicator {
public:
FLCommunicator() : GeoCommunicator() {}
~FLCommunicator() {
is_running_ = false;
async_send_thread_->join();
}
explicit FLCommunicator(const std::map<std::string, std::string> &envs)
: GeoCommunicator(envs) {}
void InitEnvs() override {}
virtual void InitBrpcClient(const std::string &dist_desc,
const std::vector<std::string> &host_sign_list);
void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override {}
void StartCoordinatorClient(
const std::vector<std::string> &trainer_endpoints);
void StartCoordinatorServer();
void StartCoordinator(
const std::string &self_endpoint,
const std::vector<std::string> &trainer_endpoints) override;
std::unordered_map<uint32_t, std::string> QueryFLClientsInfo();
void SaveFLStrategy(
const std::unordered_map<uint32_t, std::string> &fl_strategy);
void SendThreadAsync();
void RpcSendFLStrategy();
private:
int thread_pool_size_ = 1;
bool is_running_ = true;
PaddlePSEnvironment ps_env_;
std::shared_ptr<CoordinatorClient> coordinator_client_ptr_{nullptr};
std::unique_ptr<std::thread> async_send_thread_{nullptr};
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include <memory>
#include <sstream>
#include <string>
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/split.h"
static const int MIN_PORT = 8500;
static const int MAX_PORT = 65535;
namespace paddle {
namespace distributed {
DEFINE_uint64(total_fl_client_size, 100, "supported total fl client size");
DEFINE_uint32(coordinator_wait_all_clients_max_time, 60, "uint32: s");
void CoordinatorService::FLService(
::google::protobuf::RpcController* controller,
const CoordinatorReqMessage* request,
CoordinatorResMessage* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
response->set_err_code(0);
response->set_err_msg("");
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
int32_t msg_type = request->cmd_id();
uint32_t from_client_id = request->client_id();
VLOG(0) << "fl-ps > recv from client id: " << from_client_id
<< ", msg_type: " << msg_type;
// TODO(ziyoujiyi): find is not thread safe, because of RB-Tree traversal
auto itr = _service_handle_map.find(msg_type);
if (itr == _service_handle_map.end()) {
LOG(ERROR) << "fl-ps > unknown flClient2Coordinator msg type: " << msg_type;
return;
}
int ret = itr->second(*request, response, cntl); // SaveFLClientInfo
if (ret != 0) {
response->set_err_code(-1);
response->set_err_msg("fl-ps > handle flClient2Coordinator msg failed");
}
return;
}
int32_t CoordinatorClient::Initialize(
const std::vector<std::string>& trainer_endpoints) {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = paddle::distributed::FLAGS_pserver_timeout_ms;
options.connection_type = "pooled";
options.connect_timeout_ms =
paddle::distributed::FLAGS_pserver_connect_timeout_ms;
options.max_retry = 3;
std::string server_ip_port;
// Get the PServer list and connect to each server
if (_env == nullptr) {
LOG(ERROR) << "_env is null in CoordinatorClient::Initialize()";
return -1;
}
std::vector<PSHost> pserver_list = _env->GetPsServers();
_pserver_channels.resize(pserver_list.size());
for (size_t i = 0; i < pserver_list.size(); ++i) {
server_ip_port.assign(pserver_list[i].ip.c_str());
server_ip_port.append(":");
server_ip_port.append(std::to_string(pserver_list[i].port));
for (size_t j = 0; j < _pserver_channels[i].size(); ++j) {
_pserver_channels[i][j].reset(new brpc::Channel());
if (_pserver_channels[i][j]->Init(server_ip_port.c_str(), "", &options) !=
0) {
LOG(ERROR) << "CoordinatorClient connect to PServer:" << server_ip_port
<< " Failed! Try again.";
std::string int_ip_port =
GetIntTypeEndpoint(pserver_list[i].ip, pserver_list[i].port);
if (_pserver_channels[i][j]->Init(int_ip_port.c_str(), "", &options) !=
0) {
LOG(ERROR) << "CoordinatorClient connect to PServer:" << int_ip_port
<< " Failed!";
return -1;
}
}
}
}
// Get the fl_client list and connect to each client
std::vector<PSHost> fl_client_list;
fl_client_list.resize(trainer_endpoints.size());
if (fl_client_list.empty()) {
LOG(ERROR) << ">>> fl clients addr info lost";
return -1;
}
for (size_t i = 0; i < trainer_endpoints.size(); i++) {
std::vector<std::string> addr =
paddle::string::Split(trainer_endpoints[i], ':');
fl_client_list[i].ip = addr[0];
fl_client_list[i].port = std::stol(addr[1]);
fl_client_list[i].rank = i; // TO CHECK
}
std::string fl_client_ip_port;
for (size_t i = 0; i < fl_client_list.size(); ++i) {
fl_client_ip_port.assign(fl_client_list[i].ip);
fl_client_ip_port.append(":");
fl_client_ip_port.append(std::to_string(fl_client_list[i].port));
uint32_t rank = fl_client_list[i].rank;
VLOG(0) << "fl-ps > coordinator connect to fl_client: " << rank;
_fl_client_channels[rank].reset(new brpc::Channel());
if (_fl_client_channels[rank]->Init(
fl_client_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "CoordinatorClient connect to FLClient:"
<< fl_client_ip_port << " Failed! Try again.";
std::string int_ip_port =
GetIntTypeEndpoint(fl_client_list[i].ip, fl_client_list[i].port);
if (_fl_client_channels[rank]->Init(int_ip_port.c_str(), "", &options) !=
0) {
LOG(ERROR) << "CoordinatorClient connect to PSClient:" << int_ip_port
<< " Failed!";
return -1;
}
}
}
SetTotalFLClientsNum(fl_client_list.size());
SetDefaultFLStrategy();
return 0;
}
int32_t CoordinatorClient::StartClientService() {
_service.Initialize();
_server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
options.num_threads = 1;
if (_endpoint.empty()) {
LOG(ERROR) << "fl-ps > coordinator server endpoint not set";
return -1;
}
auto addr = paddle::string::Split(_endpoint, ':');
std::string ip = addr[0];
std::string port = addr[1];
std::string rank = addr[2];
std::string ip_port = ip + ":" + port;
if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "fl-ps > StartClientService failed";
return -1;
}
uint32_t port_ = std::stol(port);
int32_t rank_ = std::stoi(rank);
_env->RegisteCoordinatorClient(ip, port_, rank_);
VLOG(0) << "fl-ps > coordinator service addr: " << ip << ", " << port << ", "
<< _coordinator_id;
return 0;
}
void CoordinatorClient::SendFLStrategy(const uint32_t& client_id) {
size_t request_call_num = 1;
FlClientBrpcClosure* closure =
new FlClientBrpcClosure(request_call_num, [](void* done) {
auto* closure = reinterpret_cast<FlClientBrpcClosure*>(done);
int ret = 0;
if (closure->check_response(0, PUSH_FL_STRATEGY) != 0) {
LOG(ERROR) << "fl-ps > SendFLStrategy failed";
ret = -1;
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int32_t> fut = promise->get_future();
closure->add_promise(promise);
closure->request(0)->set_cmd_id(PUSH_FL_STRATEGY);
closure->request(0)->set_client_id(client_id);
std::string fl_strategy = _fl_strategy_mp[client_id];
closure->request(0)->set_str_params(fl_strategy);
brpc::Channel* rpc_channel = _fl_client_channels[client_id].get();
if (rpc_channel == nullptr) {
LOG(ERROR) << "fl-ps > _fl_client_channels is null";
return;
}
PsService_Stub rpc_stub(rpc_channel); // DownpourPsClientService
rpc_stub.FLService(
closure->cntl(0), closure->request(0), closure->response(0), closure);
fut.wait();
VLOG(0) << "fl-ps > SendFLStrategy to client: " << client_id << " finished";
return;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace distributed {
DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(pserver_connect_timeout_ms);
DECLARE_uint64(total_fl_client_size);
DECLARE_uint32(coordinator_wait_all_clients_max_time);
using CoordinatorServiceFunc =
std::function<int32_t(const CoordinatorReqMessage& request,
CoordinatorResMessage* response,
brpc::Controller* cntl)>;
class ClientReportedInfo {
public:
ClientReportedInfo() {}
~ClientReportedInfo() {}
uint32_t client_id;
uint32_t iteration_idx;
double auc = 0.0;
};
class CoordinatorServiceHandle {
public:
CoordinatorServiceHandle() {}
virtual ~CoordinatorServiceHandle() {}
void SaveFLClientInfo(const CoordinatorReqMessage& request) {
auto client_id = request.client_id();
const std::string& str_params = request.str_params();
// each client is allowed to send an empty message to maintain the heartbeat
// (i.e., reuse a stale message)
std::unique_lock<std::mutex> lck(_mtx);
if (str_params.size() != 0) {
_client_info_mp[client_id] = str_params;
} else {
LOG(INFO) << "fl-ps > content in request from " << client_id
<< " is null";
}
fl_client_ids.insert(client_id);
_fl_clients_count++;
// TODO(ziyoujiyi): how to process when a client loss connection?
if (_fl_clients_count.load() == last_round_total_fl_clients_num) {
_is_all_clients_info_collected = true;
_cv.notify_one();
}
lck.unlock();
VLOG(0) << "last_round_total_fl_clients_num: "
<< last_round_total_fl_clients_num
<< ", has recved fl client num: " << _fl_clients_count.load();
return;
}
std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
platform::Timer timeline;
double query_wait_time = 0.0;
timeline.Start();
auto f = [&]() -> bool {
while (query_wait_time <
paddle::distributed::
FLAGS_coordinator_wait_all_clients_max_time) { // in case some clients are down
if (_is_all_clients_info_collected == true) {
// LOG(INFO) << "fl-ps > _is_all_clients_info_collected";
return true;
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
timeline.Pause();
query_wait_time += timeline.ElapsedSec();
}
// LOG(WARNING) << "fl-ps > query_wait_time exceeded!";
return true;
};
std::unique_lock<std::mutex> lck(_mtx);
_cv.wait(lck, f);
lck.unlock();
_is_all_clients_info_collected = false;
_fl_clients_count.store(0);
return _client_info_mp;
}
public:
std::unordered_map<uint32_t, std::string> _client_info_mp;
std::set<uint32_t> fl_client_ids;
uint32_t last_round_total_fl_clients_num = 0;
bool _is_all_clients_info_collected = false;
private:
std::mutex _mtx;
std::condition_variable _cv;
std::atomic<uint32_t> _fl_clients_count{0};
};
class CoordinatorService : public PsService {
public:
CoordinatorService() {
_coordinator_service_handle = std::make_shared<CoordinatorServiceHandle>();
}
virtual ~CoordinatorService() {}
virtual void Initialize() {
_service_handle_map[PUSH_FL_CLIENT_INFO_SYNC] =
std::bind(&CoordinatorService::SaveFLClientInfo,
this,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3);
}
virtual void FLService(::google::protobuf::RpcController* controller,
const CoordinatorReqMessage* request,
CoordinatorResMessage* response,
::google::protobuf::Closure* done);
int32_t SaveFLClientInfo(const CoordinatorReqMessage& request,
CoordinatorResMessage* response,
brpc::Controller* cntl) {
_coordinator_service_handle->SaveFLClientInfo(request);
return 0;
}
void SetTotalFLClientsNum(uint32_t all_fl_clients_num) {
if (_coordinator_service_handle.get() != nullptr) {
_coordinator_service_handle->last_round_total_fl_clients_num =
all_fl_clients_num;
} else {
LOG(ERROR) << "fl-ps > _coordinator_service_handle is null in "
"CoordinatorService";
}
return;
}
std::set<uint32_t> GetFLClientIds() {
return _coordinator_service_handle->fl_client_ids;
}
std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
return _coordinator_service_handle->QueryFLClientsInfo();
}
private:
std::shared_ptr<CoordinatorServiceHandle> _coordinator_service_handle;
std::unordered_map<int32_t, CoordinatorServiceFunc> _service_handle_map;
std::mutex _mtx;
};
class CoordinatorClient : public BrpcPsClient {
public:
CoordinatorClient() : _coordinator_id(0) {}
virtual ~CoordinatorClient() {}
int32_t Initialize(const std::vector<std::string>& trainer_endpoints);
void SetTotalFLClientsNum(uint32_t all_fl_clients_num) {
_service.SetTotalFLClientsNum(all_fl_clients_num);
this->_total_clients_num = all_fl_clients_num;
return;
}
int32_t StartClientService();
void SaveFLStrategy(
const std::unordered_map<uint32_t, std::string>& fl_strategy) {
for (auto it = fl_strategy.begin(); it != fl_strategy.end(); it++) {
uint32_t client_id = it->first;
_fl_strategy_mp[client_id] = it->second;
}
std::unique_lock<std::mutex> lck(_mtx);
_is_fl_strategy_ready = true;
_cv.notify_all();
return;
}
void WaitForFLStrategyReady() {
std::unique_lock<std::mutex> lck(_mtx);
_cv.wait(lck, [=]() { return _is_fl_strategy_ready; });
}
void SendFLStrategy(const uint32_t& client_id);
void ResetFLStrategyFlag() { _is_fl_strategy_ready = false; }
void SetDefaultFLStrategy() {
for (size_t i = 0; i < _total_clients_num; i++) {
_fl_strategy_mp[i] = "";
}
return;
}
std::set<uint32_t> GetFLClientIds() { return _service.GetFLClientIds(); }
std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
return _service.QueryFLClientsInfo();
}
void SetEndpoint(const std::string& endpoint) {
_endpoint = std::move(endpoint);
}
public:
size_t _coordinator_id;
uint32_t _total_clients_num;
std::string _endpoint;
std::vector<std::array<std::shared_ptr<brpc::Channel>, 1>>
_pserver_channels; // coordinator2pserver
std::unordered_map<uint32_t, std::shared_ptr<brpc::Channel>>
_fl_client_channels; // coordinator2psclient
brpc::Server _server;
CoordinatorService _service;
std::unordered_map<uint32_t, std::string> _fl_strategy_mp;
bool _is_fl_strategy_ready = false;
std::mutex _mtx;
std::condition_variable _cv;
};
} // namespace distributed
} // namespace paddle
......@@ -65,7 +65,7 @@ struct PSHost {
s << "host: " << ip;
s << " port: " << port;
s << " rank: " << rank;
s << " uint: " << SerializeToUint64();
s << " uint64: " << SerializeToUint64();
return s.str();
}
......@@ -130,6 +130,7 @@ class PSEnvironment {
virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) {
return 0;
}
virtual uint64_t GetLocalHostSign() { return 0; }
virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
virtual int32_t RegistePsServer(const std::string &ip,
......@@ -145,6 +146,16 @@ class PSEnvironment {
return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
}
virtual std::vector<PSHost> GetCoordinators() const {
return _coordinator_list;
}
virtual int32_t RegisteCoordinatorClient(const std::string &ip,
uint32_t port,
int32_t rank) {
return RegistePsHost(
ip, port, rank, _coordinator_list, _coordinator_sign_set);
}
virtual std::vector<uint64_t> GetClientInfo() {
std::vector<uint64_t> client_info;
for (auto &i : _ps_client_list) {
......@@ -196,6 +207,9 @@ class PSEnvironment {
std::vector<PSHost> _ps_server_list;
std::unordered_set<uint64_t> _ps_server_sign_set; // for unique filter
std::vector<PSHost> _coordinator_list;
std::unordered_set<uint64_t> _coordinator_sign_set;
};
class PaddlePSEnvironment : public PSEnvironment {
......@@ -278,6 +292,22 @@ class PaddlePSEnvironment : public PSEnvironment {
return 0;
}
virtual void SetCoordinators(const std::vector<std::string> *host_sign_list,
size_t node_num) {
_coordinator_list.clear();
_coordinator_sign_set.clear();
for (size_t i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.ParseFromString(host_sign_list->at(i));
_coordinator_list.push_back(host);
_coordinator_sign_set.insert(host.rank);
VLOG(0) << "fl-ps > coordinator info in env: " << host.ToString();
}
}
return;
}
virtual uint64_t GetLocalHostSign() {
if (_ps_client_list.size() > 0) {
return _ps_client_list[0].SerializeToUint64();
......
......@@ -17,11 +17,11 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {
DEFINE_int32(heter_world_size, 100, "group size"); // group max size
DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s");
namespace paddle {
namespace distributed {
std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
std::mutex HeterClient::mtx_;
std::shared_ptr<HeterClient> HeterClient::switch_s_instance_ = nullptr;
......
......@@ -39,10 +39,10 @@ namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
DECLARE_int32(pserver_timeout_ms);
namespace paddle {
namespace distributed {
DECLARE_int32(pserver_timeout_ms);
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
......
......@@ -52,14 +52,14 @@ class ProgramDesc;
class Scope;
} // namespace framework
} // namespace paddle
DECLARE_double(eager_delete_tensor_gb);
namespace paddle {
namespace distributed {
DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(heter_world_size);
DECLARE_int32(switch_send_recv_timeout_s);
namespace paddle {
namespace distributed {
using MultiVarMsg = MultiVariableMessage;
using VarMsg = VariableMessage;
......
......@@ -16,6 +16,7 @@
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
......@@ -25,8 +26,9 @@ namespace distributed {
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
REGISTER_PSCORE_CLASS(PSClient, CoordinatorClient);
int32_t PSClient::Configure(
int32_t PSClient::Configure( // called in FleetWrapper::InitWorker
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
PSEnvironment &env,
......
......@@ -321,14 +321,16 @@ class PSClient {
protected:
virtual int32_t Initialize() = 0;
size_t _client_id;
PSParameter _config;
std::map<uint64_t, std::vector<paddle::distributed::Region>>
_dense_pull_regions;
PSEnvironment *_env;
std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>> _table_accessors;
std::unordered_map<int32_t, MsgHandlerFunc>
_msg_handler_map; // handles client2client messages
public:
size_t _client_id;
PSEnvironment *_env;
};
template <class T>
......
......@@ -69,6 +69,8 @@ enum PsCmdID {
PS_CHECK_SAVE_PRE_PATCH_DONE = 48;
// pserver2pserver cmd start from 100
PS_S2S_MSG = 101;
PUSH_FL_CLIENT_INFO_SYNC = 200;
PUSH_FL_STRATEGY = 201;
}
message PsRequestMessage {
......@@ -85,6 +87,18 @@ message PsResponseMessage {
optional bytes data = 3;
};
message CoordinatorReqMessage {
required uint32 cmd_id = 1;
optional int32 client_id = 2;
optional string str_params = 3;
};
message CoordinatorResMessage {
required int32 err_code = 1 [ default = 0 ];
required string err_msg = 2 [ default = "" ];
optional string str_params = 3;
};
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
......@@ -134,6 +148,7 @@ message MultiVariableMessage {
service PsService {
rpc service(PsRequestMessage) returns (PsResponseMessage);
rpc FLService(CoordinatorReqMessage) returns (CoordinatorResMessage);
rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
rpc SendToWorker(MultiVariableMessage) returns (PsResponseMessage);
rpc SendToSwitch(MultiVariableMessage) returns (PsResponseMessage);
......
File mode changed from 100644 to 100755
......@@ -146,6 +146,34 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
}
}
void FleetWrapper::InitFlWorker(const std::vector<std::string>& host_list,
int index,
const std::string& self_endpoint) {
assert(worker_ptr_.get() != nullptr);
uint32_t coordinator_num = host_list.size();
ps_env_.SetCoordinators(&host_list, coordinator_num);
auto ptr = dynamic_cast<BrpcPsClient*>(worker_ptr_.get());
ptr->InitializeFlWorker(self_endpoint);
return;
}
void FleetWrapper::PushFLClientInfoSync(const std::string& fl_client_info) {
// FLClientInfo fci;
// google::protobuf::TextFormat::ParseFromString(fl_client_info, &fci);
// InitGFlag(fci.init_gflags());
auto ptr = dynamic_cast<BrpcPsClient*>(worker_ptr_.get());
VLOG(0) << "fl-ps > PushFLClientInfoSync: " << typeid(worker_ptr_).name()
<< ", " << typeid(ptr).name() << ", " << typeid(BrpcPsClient).name();
ptr->PushFLClientInfoSync(fl_client_info);
return;
}
std::string FleetWrapper::PullFlStrategy() {
auto ptr = dynamic_cast<BrpcPsClient*>(worker_ptr_.get());
std::string str = ptr->PullFlStrategy();
return str;
}
void FleetWrapper::StopServer() {
VLOG(3) << "Going to stop server";
auto status = worker_ptr_->StopServer();
......
......@@ -305,6 +305,14 @@ class FleetWrapper {
void Revert();
void CheckSavePrePatchDone();
//********* for fl-coordinator
void InitFlWorker(const std::vector<std::string>& host_list,
int index,
const std::string& self_endpoint);
void PushFLClientInfoSync(const std::string& fl_client_info);
std::string PullFlStrategy();
//**********
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
static std::shared_ptr<paddle::distributed::PSClient> worker_ptr_;
......
......@@ -239,3 +239,29 @@ message GraphFeature {
repeated string dtype = 2;
repeated int32 shape = 3;
}
message FLParameter {
optional FLStrategy fl_strategy = 1;
optional FLClientInfo client_info = 2;
}
message FLStrategy {
optional uint64 iteration_num = 1;
optional uint64 client_id = 2;
optional string next_state = 3 [default = "JOIN"];
optional string init_gflags = 4 [ default = "" ];
}
message FLClientInfo {
optional uint32 client_id = 1;
optional string device_type = 2;
optional int32 compute_capacity = 3;
optional int32 bandwidth = 4;
optional LocalTrainingResult local_training_result = 5;
optional string init_gflags = 6 [ default = "" ];
}
message LocalTrainingResult {
optional double acc = 1;
optional double loss = 2;
}
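The client info reported through CoordinatorReqMessage.str_params and the strategy pushed back via PUSH_FL_STRATEGY are plain strings; the commented-out code in FleetWrapper::PushFLClientInfoSync suggests they carry text-format protobuf. A minimal Python sketch of that round trip, where fl_ps_pb2 is a hypothetical name for the module generated from the messages above:

from google.protobuf import text_format
import fl_ps_pb2  # hypothetical: module generated from the FLClientInfo/FLStrategy proto above

# trainer side: pack locally measured state into the string sent as str_params
info = fl_ps_pb2.FLClientInfo()
info.client_id = 0
info.device_type = "cpu"
info.bandwidth = 100
str_params = text_format.MessageToString(info)

# coordinator side: parse the reported string back into a message
parsed = fl_ps_pb2.FLClientInfo()
text_format.Parse(str_params, parsed)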
......@@ -163,13 +163,13 @@ void DeviceWorker::DumpField(const Scope& scope,
continue;
}
hit[i] = true;
}
} // dump_mode = 0
for (size_t i = 0; i < ins_id_vec.size(); i++) {
if (!hit[i]) {
continue;
}
ars[i] += ins_id_vec[i];
ars[i] = ars[i] + "\t" + ins_content_vec[i];
ars[i] += "\t" + ins_content_vec[i];
}
for (auto& field : *dump_fields_) {
Variable* var = scope.FindVar(field);
......@@ -202,8 +202,7 @@ void DeviceWorker::DumpField(const Scope& scope,
continue;
}
auto bound = GetTensorBound(tensor, i);
ars[i] = ars[i] + "\t" + field + ":" +
std::to_string(bound.second - bound.first);
ars[i] += "\t" + field + ":" + std::to_string(bound.second - bound.first);
ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
}
}
......
......@@ -316,6 +316,7 @@ message DistributedStrategy {
optional bool auto_search = 37 [ default = false ];
optional bool heter_ccl_mode = 38 [ default = false ];
optional bool is_fl_ps_mode = 39 [ default = false ];
optional bool with_coordinator = 40 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......
......@@ -254,7 +254,6 @@ void MultiTrainer::Finalize() {
if (need_dump_field_ || need_dump_param_) {
FinalizeDumpEnv();
}
for (size_t i = 0; i < need_merge_var_names_.size(); i++) {
Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]);
if (root_var == nullptr) {
......@@ -297,8 +296,12 @@ void MultiTrainer::Finalize() {
if (communicator == nullptr) {
VLOG(0) << "MultiTrainer::Finalize communicator is null!";
} else {
communicator->_worker_ptr->Flush();
VLOG(1) << "MultiTrainer::Finalize ps client flush done";
if (communicator->_worker_ptr != nullptr) {
communicator->_worker_ptr->Flush();
VLOG(1) << "MultiTrainer::Finalize ps client flush done";
} else {
VLOG(0) << "communicator->_worker_ptr is null";
}
}
#endif
root_scope_->DropKids();
......
......@@ -76,6 +76,9 @@ void BindDistFleetWrapper(py::module* m) {
.def("get_cache_threshold", &FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &FleetWrapper::CacheShuffle)
.def("save_cache", &FleetWrapper::SaveCache)
.def("init_fl_worker", &FleetWrapper::InitFlWorker)
.def("push_fl_client_info_sync", &FleetWrapper::PushFLClientInfoSync)
.def("pull_fl_strategy", &FleetWrapper::PullFlStrategy)
.def("revert", &FleetWrapper::Revert)
.def("check_save_pre_patch_done", &FleetWrapper::CheckSavePrePatchDone);
}
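These bindings expose the new FL worker calls to Python. A rough per-round flow on a trainer, assuming the object returned by fleet.get_fl_client() (added later in this patch) is the bound FleetWrapper shown here; the surrounding training loop is illustrative:

import paddle.distributed.fleet as fleet

# assumes fleet.init(...) has already run with the FL-PS strategy enabled
worker = fleet.get_fl_client()              # bound worker object, added later in this patch
worker.push_fl_client_info_sync("")         # an empty string is accepted as a heartbeat (see SaveFLClientInfo)
fl_strategy = worker.pull_fl_strategy()     # busy-waits until the coordinator pushes a strategy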
......@@ -131,6 +134,7 @@ void BindCommunicatorContext(py::module* m) {
}
using paddle::distributed::AsyncCommunicator;
using paddle::distributed::FLCommunicator;
using paddle::distributed::GeoCommunicator;
using paddle::distributed::RecvCtxMap;
using paddle::distributed::RpcCtxMap;
......@@ -157,6 +161,9 @@ void BindDistCommunicator(py::module* m) {
} else if (mode == "GEO") {
Communicator::InitInstance<GeoCommunicator>(
send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs);
} else if (mode == "WITH_COORDINATOR") {
Communicator::InitInstance<FLCommunicator>(
send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"unsuported communicator MODE"));
......@@ -172,7 +179,10 @@ void BindDistCommunicator(py::module* m) {
.def("create_client_to_client_connection",
&Communicator::CreateC2CConnection)
.def("get_client_info", &Communicator::GetClientInfo)
.def("set_clients", &Communicator::SetClients);
.def("set_clients", &Communicator::SetClients)
.def("start_coordinator", &Communicator::StartCoordinator)
.def("query_fl_clients_info", &Communicator::QueryFLClientsInfo)
.def("save_fl_strategy", &Communicator::SaveFLStrategy);
}
void BindHeterClient(py::module* m) {
......
......@@ -57,6 +57,10 @@ world_device_ids = fleet.world_device_ids
local_rank = fleet.local_rank
rank_in_node = local_rank
is_worker = fleet.is_worker
is_coordinator = fleet.is_coordinator
init_coordinator = fleet.init_coordinator
make_fl_strategy = fleet.make_fl_strategy
get_fl_client = fleet.get_fl_client
worker_endpoints = fleet.worker_endpoints
server_num = fleet.server_num
server_index = fleet.server_index
......
......@@ -1348,6 +1348,18 @@ class DistributedStrategy(object):
else:
print("WARNING: is_fl_ps_mode should have value of bool type")
@property
def is_with_coordinator(self):
return self.strategy.with_coordinator
@is_with_coordinator.setter
@is_strict_auto
def is_with_coordinator(self, flag):
if isinstance(flag, bool):
self.strategy.with_coordinator = flag
else:
print("WARNING: with_coordinator should have value of bool type")
@pipeline.setter
@is_strict_auto
def pipeline(self, flag):
......
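A short sketch of switching the new flag on from a training script; the other settings shown are assumptions about a typical FL-PS job, not requirements of this patch:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.a_sync = True               # parameter-server (async/geo) mode, assumed here
strategy.is_fl_ps_mode = True        # existing flag
strategy.is_with_coordinator = True  # new flag added in this patch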
......@@ -511,6 +511,9 @@ class Fleet(object):
"""
return self._role_maker._is_worker()
def is_coordinator(self):
return self._role_maker._is_coordinator()
def worker_endpoints(self, to_string=False):
"""
Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].
......@@ -642,6 +645,25 @@ class Fleet(object):
"""
self._runtime_handle._init_worker(scopes)
@is_non_distributed_check
@inited_runtime_handler
def init_coordinator(self, scopes=None):
"""
initialize coordinator node
"""
self._runtime_handle._init_coordinator(scopes)
def make_fl_strategy(self):
self._runtime_handle._make_fl_strategy()
@is_non_distributed_check
@inited_runtime_handler
def get_fl_client(self):
"""
get worker(training node) ptr
"""
return self._runtime_handle._worker
@is_non_distributed_check
@inited_runtime_handler
def init_server(self, *args, **kwargs):
......
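On the coordinator process, a plausible driving sequence built from these new entry points; exactly how the runtime handle loops over query_fl_clients_info / save_fl_strategy is an assumption here, not spelled out by this hunk:

import paddle.distributed.fleet as fleet

# assumes fleet.init(...) ran on a node started with TRAINING_ROLE=COORDINATOR
if fleet.is_coordinator():
    fleet.init_coordinator()   # starts the coordinator client/server added in this patch
    fleet.make_fl_strategy()   # expected to query client info and save/push strategies inside the runtime handle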
......@@ -30,6 +30,7 @@ class Role:
SERVER = 2
HETER_WORKER = 3
ALL = 4
COORDINATOR = 5
class Gloo(object):
......@@ -376,6 +377,7 @@ class RoleMakerBase(object):
def __init__(self):
self._worker_endpoints = []
self._server_endpoints = []
self._cur_endpoint = ""
self._role_is_generated = False
self._role = None
self._current_id = -1
......@@ -544,6 +546,8 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._server_endpoints = []
self._worker_endpoints = []
self._coordinator_endpoints = None
self._with_coordinator = False
self._gloo = Gloo() # gloo instance
......@@ -612,6 +616,11 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._generate_role()
return self._role == Role.SERVER
def _is_coordinator(self):
if not self._role_is_generated:
self._generate_role()
return self._role == Role.COORDINATOR
def _is_first_worker(self):
"""
whether current process is worker of rank 0
......@@ -734,6 +743,11 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._generate_role()
return self._server_endpoints
def _get_coordinator_endpoints(self):
if not self._role_is_generated:
self._generate_role()
return self._coordinator_endpoints
def _get_previous_trainers(self):
"""
invoked by heter worker
......@@ -781,7 +795,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._generate_role()
return self._role == Role.HETER_WORKER
def _ps_env(self):
def _ps_env(self): # each role will execute it
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
# format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)
......@@ -806,6 +820,14 @@ class PaddleCloudRoleMaker(RoleMakerBase):
else:
self._worker_endpoints = []
self._coordinator_endpoints = os.getenv("PADDLE_COORDINATOR_ENDPOINTS",
"")
if self._coordinator_endpoints == "":
print("fl-ps > coordinator address is null!")
else:
self._with_coordinator = True
self._coordinator_endpoints = self._coordinator_endpoints.split(",")
trainers_num = os.getenv("PADDLE_TRAINERS_NUM", None)
if trainers_num == None:
raise ValueError(
......@@ -818,9 +840,11 @@ class PaddleCloudRoleMaker(RoleMakerBase):
raise ValueError(
"Can not find TRAINING_ROLE, please check your environment.")
if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]:
if training_role not in [
"TRAINER", "PSERVER", "HETER_TRAINER", "COORDINATOR"
]:
raise ValueError(
"TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment."
"TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER or COORDINATOR, but get {}, please check your environment."
.format(training_role))
# For Heter Parameter Server env setting
......@@ -862,29 +886,10 @@ class PaddleCloudRoleMaker(RoleMakerBase):
"Can not Find PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
)
#self._is_heter_parameter_server_mode = True
#heter_trainers_num = len(all_heter_trainer_eplist.split(","))
#self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
else:
self._is_heter_parameter_server_mode = False
self._heter_trainers_num = 0
#if previous_heter_trainer_eplist == "":
# self._is_heter_parameter_server_mode = False
# heter_trainers_num = 0
#else: ## for the last heter worker
# try:
# previous_heter_trainer_eplist = os.environ[
# "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"].split(",")
# self._previous_heter_trainer_endpoints = previous_heter_trainer_eplist
# except:
# raise ValueError(
# "Can not Find PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
# )
# self._is_heter_parameter_server_mode = True
# heter_trainers_num = len(all_heter_trainer_eplist.split(","))
# self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
if training_role == "TRAINER":
role = Role.WORKER
current_id = os.getenv("PADDLE_TRAINER_ID", None)
......@@ -922,6 +927,10 @@ class PaddleCloudRoleMaker(RoleMakerBase):
"Can not find POD_IP, please check your environment.")
curr_endpoint = ":".join([cur_ip, cur_port])
self._cur_endpoint = curr_endpoint
elif training_role == "COORDINATOR":
print(">>> curr node is coordinator!")
role = Role.COORDINATOR
current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
elif training_role == "PSERVER":
role = Role.SERVER
cur_port = os.getenv("PADDLE_PORT", None)
......
......@@ -211,6 +211,10 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
type=str,
default="",
help="User defined workers ip:port")
ps_group.add_argument("--coordinators",
type=str,
default="",
help="User defined coordinators ip:port")
ps_group.add_argument(
"--heter_workers",
type=str,
......@@ -223,6 +227,9 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
help="User defined heter devices in each stage cpu;gpu;cpu")
ps_group.add_argument("--worker_num", type=int, help="number of workers")
ps_group.add_argument("--coordinator_num",
type=int,
help="number of coordinators")
ps_group.add_argument("--server_num", type=int, help="number of servers")
ps_group.add_argument("--heter_worker_num",
type=str,
......@@ -474,6 +481,8 @@ def which_distributed_mode(args):
ps_heter_args = ["--heter_worker_num", "--heter_workers", "--heter_devices"]
coordinator_args = ["--coordinator_num", "--coordinators"]
has_ps_args = [
ps_arg for ps_arg in ps_args if ps_arg in " ".join(sys.argv[1:-1])
]
......@@ -503,6 +512,7 @@ def which_distributed_mode(args):
"Run parameter-sever mode. pserver arguments:{}, accelerators count:{}"
.format(has_ps_args, accelerators))
has_ps_heter_args = list(set(has_ps_args) & set(ps_heter_args))
has_coordinator_args = list(set(has_ps_args) & set(coordinator_args))
if len(has_ps_heter_args) > 0:
return DistributeMode.PS_HETER
else:
......
......@@ -189,17 +189,19 @@ class Pod(object):
self.trainers = []
self.servers = []
self.workers = []
self.coordinators = []
self.heter_workers = []
self.accelerators = []
self.device_mode = None
def __str__(self):
return "rank:{} id:{} addr:{} port:{} visible_accelerator:{} trainers:{} servers:{} \
workers:{} heter_workers:{}".format(
workers:{} heter_workers:{} coordinators:{}".format(
self.rank, self.id, self.addr, self.port, self.accelerators,
[str(t) for t in self.trainers], [str(s) for s in self.servers],
[str(w)
for w in self.workers], [str(h) for h in self.heter_workers])
for w in self.workers], [str(h) for h in self.heter_workers],
[str(c) for c in self.coordinators])
def __eq__(self, pod):
if self.rank != pod.rank or \
......@@ -1172,9 +1174,11 @@ class ParameterServerLauncher(object):
def __init__(self, args, distribute_mode):
self.args = args
self.distribute_mode = distribute_mode
self.with_coordinator = False
self.server_num = 0
self.worker_num = 0
self.heter_worker_num = 0
self.coordinator_num = 0
self.server_endpoints = ""
self.server_endpoints_ips = []
......@@ -1188,6 +1192,10 @@ class ParameterServerLauncher(object):
self.heter_worker_endpoints_ips = []
self.heter_worker_endpoints_port = []
self.coordinator_endpoints = ""
self.coordinator_endpoints_ips = []
self.coordinator_endpoints_port = []
self.is_local = True
self.current_node_ip = ""
......@@ -1257,6 +1265,23 @@ class ParameterServerLauncher(object):
else:
self.worker_endpoints = args.workers
# get coordinator envs
if args.coordinator_num:
self.with_coordinator = True
self.coordinator_num = args.coordinator_num
if args.coordinators:
assert len(
args.coordinators.split(",")
) == self.coordinator_num, "The coordinator_num and coordinators don't match. Expect the number of coordinator endpoints to equal coordinator_num, but received coordinator endpoint num: {} and coordinator_num: {}".format(
len(args.coordinators.split(",")), self.coordinator_num)
self.coordinator_endpoints = args.coordinators
else:
ports = get_ports(self.coordinator_num, 1)
self.coordinator_endpoints = ",".join(
["127.0.0.1:" + str(x) for x in ports])
print(">>> use default coordinator addr(only one process)")
# get heter worker envs
if self.distribute_mode == DistributeMode.PS_HETER:
assert args.heter_devices != "", "The setting of Parameter-Server heter mode must has heter_devices."
......@@ -1398,6 +1423,17 @@ class ParameterServerLauncher(object):
self.worker_endpoints_ips = [
x.strip().split(":")[0] for x in self.worker_endpoints.split(",")
]
if self.with_coordinator:
self.coordinator_endpoints_ips = [
x.strip().split(":")[0]
for x in self.coordinator_endpoints.split(",")
]
self.coordinator_endpoints_port = [
x.strip().split(":")[1]
for x in self.coordinator_endpoints.split(",")
]
self.server_endpoints_port = [
x.strip().split(":")[1] for x in self.server_endpoints.split(",")
]
......@@ -1451,6 +1487,7 @@ class ParameterServerLauncher(object):
server_rank = 0
worker_rank = 0
heter_worker_rank = 0
coordinator_rank = 0
for node_rank, ip in enumerate(self.node_ips):
pod = Pod()
pod.rank = node_rank
......@@ -1472,6 +1509,16 @@ class ParameterServerLauncher(object):
worker.stage = 1
worker_rank += 1
pod.workers.append(worker)
for m in range(len(self.coordinator_endpoints_ips)):
if ip == self.coordinator_endpoints_ips[m]:
coordinator = Trainer()
coordinator.endpoint = "%s:%s" % (
ip, self.coordinator_endpoints_port[m])
coordinator.rank = coordinator_rank
coordinator.stage = 1
coordinator_rank += 1
pod.coordinators.append(coordinator)
for k in range(len(self.heter_worker_endpoints_ips)):
if ip == self.heter_worker_endpoints_ips[k]:
heter_worker = Trainer()
......@@ -1488,18 +1535,36 @@ class ParameterServerLauncher(object):
self.gloo_rendezvous_dir = tempfile.mkdtemp()
# 3. subproces start
self.procs = {"worker": [], "server": [], "heter_worker": []}
self.cmds = {"worker": [], "server": [], "heter_worker": []}
self.log_fns = {"worker": [], "server": [], "heter_worker": []}
self.procs = {
"worker": [],
"coordinator": [],
"server": [],
"heter_worker": []
}
self.cmds = {
"worker": [],
"coordinator": [],
"server": [],
"heter_worker": []
}
self.log_fns = {
"worker": [],
"coordinator": [],
"server": [],
"heter_worker": []
}
self.start_pod_server(self.args, pod)
self.start_pod_worker(self.args, pod)
if self.with_coordinator:
self.start_pod_coordinator(self.args, pod)
if self.distribute_mode == DistributeMode.PS_HETER:
self.start_pod_heter_worker(self.args, pod)
logger.info(
"Please check servers, workers and heter_worker logs in {}/workerlog.*, {}/serverlog.* and {}/heterlog.*"
.format(self.args.log_dir, self.args.log_dir, self.args.log_dir))
"Please check servers, workers, coordinator and heter_worker logs in {}/workerlog.*, {}/serverlog.* , {}/coordinatorlog.*, and {}/heterlog.*"
.format(self.args.log_dir, self.args.log_dir, self.args.log_dir,
self.args.log_dir))
# 4. wait for finish training
if len(self.procs["worker"]) > 0:
......@@ -1524,6 +1589,12 @@ class ParameterServerLauncher(object):
self.procs["server"][i].proc.terminate()
logger.info("all parameter server are killed")
if len(self.procs["coordinator"]) > 0:
for i, proc in enumerate(self.procs["coordinator"]):
self.log_fns["coordinator"][i].close()
self.procs["coordinator"][i].proc.terminate()
logger.info("all coordinators are killed")
else:
# if node has not worker procs
# blocking training process
......@@ -1548,6 +1619,7 @@ class ParameterServerLauncher(object):
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_COORDINATOR_ENDPOINTS": self.coordinator_endpoints,
"PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST":
self.heter_worker_endpoints,
"PADDLE_PORT": cur_server.endpoint.split(":")[1],
......@@ -1563,6 +1635,7 @@ class ParameterServerLauncher(object):
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_COORDINATOR_ENDPOINTS": self.coordinator_endpoints,
"PADDLE_PORT": cur_server.endpoint.split(":")[1],
"TRAINING_ROLE": "PSERVER",
"PADDLE_TRAINERS_NUM": str(self.worker_num),
......@@ -1633,6 +1706,8 @@ class ParameterServerLauncher(object):
self.worker_endpoints,
"PADDLE_TRAINERS_NUM":
str(self.worker_num),
"PADDLE_COORDINATOR_ENDPOINTS":
self.coordinator_endpoints,
"PADDLE_STAGE_TRAINERS_NUM":
str(self.stage_trainer_num),
"STAGE_ID":
......@@ -1678,6 +1753,7 @@ class ParameterServerLauncher(object):
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_TRAINERS_NUM": str(self.worker_num),
"TRAINING_ROLE": "TRAINER",
"PADDLE_COORDINATOR_ENDPOINTS": self.coordinator_endpoints,
"POD_IP": cur_worker.endpoint.split(":")[0],
"PADDLE_PORT": cur_worker.endpoint.split(":")[1],
"PADDLE_TRAINER_ID": str(cur_worker.rank),
......@@ -1725,6 +1801,69 @@ class ParameterServerLauncher(object):
self.procs["worker"].append(tp)
def start_pod_coordinator(self, args, pod):
print(">>> entering start_pod_coordinator")
default_env = os.environ.copy()
current_env = copy.copy(default_env)
current_env.pop("http_proxy", None)
current_env.pop("https_proxy", None)
for idx, cur_coordinator in enumerate(pod.coordinators):
device_id = "0"
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_TRAINERS_NUM": str(self.worker_num),
"PADDLE_COORDINATOR_ENDPOINTS": self.coordinator_endpoints,
"PADDLE_COORDINATOR_NUM": str(self.coordinator_num),
"TRAINING_ROLE": "COORDINATOR",
"POD_IP": cur_coordinator.endpoint.split(":")[0],
"PADDLE_PORT": cur_coordinator.endpoint.split(":")[1],
"PADDLE_TRAINER_ID": str(cur_coordinator.rank),
"PADDLE_WITH_GLOO": str(os.getenv("PADDLE_WITH_GLOO", "0")),
"PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": "0",
"FLAGS_selected_xpus": "0",
"CUDA_VISIBLE_DEVICES": device_id,
"XPU_VISIBLE_DEVICES": device_id,
"PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
}
current_env.update(proc_env)
cmd = [sys.executable, "-u", args.training_script
] + args.training_script_args
self.cmds["coordinator"].append(cmd)
if idx == 0:
logger.info(
"Local coordinator start {} processes. First process distributed "
"environment info (Only For Debug): {}".format(
len(pod.coordinators),
pretty_print_envs(proc_env,
("Distributed Envs", "Value"))))
if args.log_dir is not None:
os.system("mkdir -p {}".format(args.log_dir))
fn = open("%s/coordinator.%d" % (args.log_dir, idx), "w")
self.log_fns["coordinator"].append(fn)
proc = subprocess.Popen(cmd,
env=current_env,
stdout=fn,
stderr=fn)
else:
fn = None
proc = subprocess.Popen(cmd, env=current_env)
tp = TrainerProc()
tp.proc = proc
tp.rank = cur_coordinator.rank
tp.local_rank = idx
tp.log_fn = fn
tp.log_offset = fn.tell() if fn else None
tp.cmd = cmd
self.procs["coordinator"].append(tp)
def start_pod_heter_worker(self, args, pod):
default_env = os.environ.copy()
current_env = copy.copy(default_env)
......
......@@ -78,6 +78,8 @@ class ParameterServerOptimizer(MetaOptimizerBase):
attrs['lr_decay_steps'] = self.user_defined_strategy.a_sync_configs[
"lr_decay_steps"]
attrs['is_fl_ps_mode'] = self.user_defined_strategy.is_fl_ps_mode
attrs[
'with_coordinator'] = self.user_defined_strategy.is_with_coordinator
attrs['k_steps'] = self.user_defined_strategy.a_sync_configs["k_steps"]
attrs['launch_barrier'] = self.user_defined_strategy.a_sync_configs[
"launch_barrier"]
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.communicator import FLCommunicator
from paddle.distributed.fleet.proto import the_one_ps_pb2
from google.protobuf import text_format
from paddle.distributed.ps.utils.public import is_distributed_env
from paddle.distributed import fleet
import time
import abc
import os
import logging
logging.basicConfig(
format='%(asctime)s %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s',
level=logging.DEBUG)
logger = logging.getLogger(__name__)
class ClientInfoAttr:
CLIENT_ID = 0
DEVICE_TYPE = 1
COMPUTE_CAPACITY = 2
BANDWIDTH = 3
class FLStrategy:
JOIN = 0
WAIT = 1
FINISH = 2
class ClientSelectorBase(abc.ABC):
def __init__(self, fl_clients_info_mp):
self.fl_clients_info_mp = fl_clients_info_mp
self.clients_info = {}
self.fl_strategy = {}
def parse_from_string(self):
if not self.fl_clients_info_mp:
logger.warning("fl-ps > fl_clients_info_mp is null!")
for client_id, info in self.fl_clients_info_mp.items():
self.fl_client_info_desc = the_one_ps_pb2.FLClientInfo()
text_format.Parse(bytes(info, encoding="utf8"),
self.fl_client_info_desc)
self.clients_info[client_id] = {}
self.clients_info[client_id][
ClientInfoAttr.
DEVICE_TYPE] = self.fl_client_info_desc.device_type
self.clients_info[client_id][
ClientInfoAttr.
COMPUTE_CAPACITY] = self.fl_client_info_desc.compute_capacity
self.clients_info[client_id][
ClientInfoAttr.BANDWIDTH] = self.fl_client_info_desc.bandwidth
@abc.abstractmethod
def select(self):
pass
class ClientSelector(ClientSelectorBase):
def __init__(self, fl_clients_info_mp):
super().__init__(fl_clients_info_mp)
self.__fl_strategy = {}
def select(self):
self.parse_from_string()
for client_id in self.clients_info:
logger.info("fl-ps > client {} info : {}".format(
client_id, self.clients_info[client_id]))
# ......... to implement ...... #
fl_strategy_desc = the_one_ps_pb2.FLStrategy()
fl_strategy_desc.iteration_num = 99
fl_strategy_desc.client_id = 0
fl_strategy_desc.next_state = "JOIN"
str_msg = text_format.MessageToString(fl_strategy_desc)
self.__fl_strategy[client_id] = str_msg
return self.__fl_strategy
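# A minimal sketch (not part of this patch) of an alternative selector built on
# ClientSelectorBase above; the bandwidth threshold and the JOIN/WAIT policy are
# purely illustrative assumptions.
class BandwidthFirstSelector(ClientSelectorBase):

    def __init__(self, fl_clients_info_mp):
        super().__init__(fl_clients_info_mp)

    def select(self):
        self.parse_from_string()
        fl_strategy = {}
        for client_id, info in self.clients_info.items():
            fl_strategy_desc = the_one_ps_pb2.FLStrategy()
            fl_strategy_desc.iteration_num = 10
            fl_strategy_desc.client_id = client_id
            # illustrative policy: well-connected clients join, the rest wait
            if info[ClientInfoAttr.BANDWIDTH] >= 50:
                fl_strategy_desc.next_state = "JOIN"
            else:
                fl_strategy_desc.next_state = "WAIT"
            fl_strategy[client_id] = text_format.MessageToString(fl_strategy_desc)
        return fl_strategy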
class FLClientBase(abc.ABC):
def __init__(self):
pass
def set_basic_config(self, role_maker, config, metrics):
self.role_maker = role_maker
self.config = config
self.total_train_epoch = int(self.config.get("runner.epochs"))
self.train_statical_info = dict()
self.train_statical_info['speed'] = []
self.epoch_idx = 0
self.worker_index = fleet.worker_index()
self.main_program = paddle.static.default_main_program()
self.startup_program = paddle.static.default_startup_program()
self._client_ptr = fleet.get_fl_client()
self._coordinators = self.role_maker._get_coordinator_endpoints()
logger.info("fl-ps > coordinator enpoints: {}".format(
self._coordinators))
self.strategy_handlers = dict()
self.exe = None
self.use_cuda = int(self.config.get("runner.use_gpu"))
self.place = paddle.CUDAPlace(0) if self.use_cuda else paddle.CPUPlace()
self.print_step = int(self.config.get("runner.print_interval"))
self.debug = self.config.get("runner.dataset_debug", False)
self.reader_type = self.config.get("runner.reader_type", "QueueDataset")
self.set_executor()
self.make_save_model_path()
self.set_metrics(metrics)
def set_train_dataset_info(self, train_dataset, train_file_list):
self.train_dataset = train_dataset
self.train_file_list = train_file_list
logger.info("fl-ps > {}, data_feed_desc:\n {}".format(
type(self.train_dataset), self.train_dataset._desc()))
def set_test_dataset_info(self, test_dataset, test_file_list):
self.test_dataset = test_dataset
self.test_file_list = test_file_list
def set_train_example_num(self, num):
self.train_example_nums = num
def load_dataset(self):
if self.reader_type == "InmemoryDataset":
self.train_dataset.load_into_memory()
def release_dataset(self):
if self.reader_type == "InmemoryDataset":
self.train_dataset.release_memory()
def set_executor(self):
self.exe = paddle.static.Executor(self.place)
def make_save_model_path(self):
self.save_model_path = self.config.get("runner.model_save_path")
if self.save_model_path and (not os.path.exists(self.save_model_path)):
os.makedirs(self.save_model_path)
def set_dump_fields(self):
# DumpField
# TrainerDesc -> SetDumpParamVector -> DumpParam -> DumpWork
if self.config.get("runner.need_dump"):
self.debug = True
dump_fields_path = "{}/epoch_{}".format(
self.config.get("runner.dump_fields_path"), self.epoch_idx)
dump_fields = self.config.get("runner.dump_fields", [])
dump_param = self.config.get("runner.dump_param", [])
persist_vars_list = self.main_program.all_parameters()
persist_vars_name = [
str(param).split(":")[0].strip().split()[-1]
for param in persist_vars_list
]
logger.info(
"fl-ps > persist_vars_list: {}".format(persist_vars_name))
if dump_fields_path is not None:
self.main_program._fleet_opt[
'dump_fields_path'] = dump_fields_path
if dump_fields is not None:
self.main_program._fleet_opt["dump_fields"] = dump_fields
if dump_param is not None:
self.main_program._fleet_opt["dump_param"] = dump_param
def set_metrics(self, metrics):
self.metrics = metrics
self.fetch_vars = [var for _, var in self.metrics.items()]
class FLClient(FLClientBase):
def __init__(self):
super(FLClient, self).__init__()
def __build_fl_client_info_desc(self, state_info):
# ......... to implement ...... #
state_info = {
ClientInfoAttr.DEVICE_TYPE: "Andorid",
ClientInfoAttr.COMPUTE_CAPACITY: 10,
ClientInfoAttr.BANDWIDTH: 100
}
client_info = the_one_ps_pb2.FLClientInfo()
client_info.device_type = state_info[ClientInfoAttr.DEVICE_TYPE]
client_info.compute_capacity = state_info[
ClientInfoAttr.COMPUTE_CAPACITY]
client_info.bandwidth = state_info[ClientInfoAttr.BANDWIDTH]
str_msg = text_format.MessageToString(client_info)
return str_msg
def run(self):
self.register_default_handlers()
self.print_program()
self.strategy_handlers['initialize_model_params']()
self.strategy_handlers['init_worker']()
self.load_dataset()
self.train_loop()
self.release_dataset()
self.strategy_handlers['finish']()
def train_loop(self):
while self.epoch_idx < self.total_train_epoch:
logger.info("fl-ps > curr epoch idx: {}".format(self.epoch_idx))
self.strategy_handlers['train']()
self.strategy_handlers['save_model']()
self.barrier()
state_info = {
"client id": self.worker_index,
"auc": 0.9,
"epoch": self.epoch_idx
}
self.push_fl_client_info_sync(state_info)
strategy_dict = self.pull_fl_strategy()
logger.info("fl-ps > recved fl strategy: {}".format(strategy_dict))
# ......... to implement ...... #
if strategy_dict['next_state'] == "JOIN":
self.strategy_handlers['infer']()
elif strategy_dict['next_state'] == "FINISH":
self.strategy_handlers['finish']()
def push_fl_client_info_sync(self, state_info):
str_msg = self.__build_fl_client_info_desc(state_info)
self._client_ptr.push_fl_client_info_sync(str_msg)
return
def pull_fl_strategy(self):
strategy_dict = {}
fl_strategy_str = self._client_ptr.pull_fl_strategy(
) # block: wait for coordinator's strategy arrived
logger.info("fl-ps > fl client recved fl_strategy(str):\n{}".format(
fl_strategy_str))
fl_strategy_desc = the_one_ps_pb2.FLStrategy()
text_format.Parse(bytes(fl_strategy_str, encoding="utf8"),
fl_strategy_desc)
strategy_dict["next_state"] = fl_strategy_desc.next_state
return strategy_dict
def barrier(self):
fleet.barrier_worker()
def register_handlers(self, strategy_type, callback_func):
self.strategy_handlers[strategy_type] = callback_func
def register_default_handlers(self):
self.register_handlers('train', self.callback_train)
self.register_handlers('infer', self.callback_infer)
self.register_handlers('finish', self.callback_finish)
self.register_handlers('initialize_model_params',
self.callback_initialize_model_params)
self.register_handlers('init_worker', self.callback_init_worker)
self.register_handlers('save_model', self.callback_save_model)
def callback_init_worker(self):
fleet.init_worker()
def callback_initialize_model_params(self):
if self.exe is None or self.main_program is None:
raise AssertionError("exe or main_program not set")
self.exe.run(self.startup_program)
def callback_train(self):
epoch_start_time = time.time()
self.set_dump_fields()
fetch_info = [
"Epoch {} Var {}".format(self.epoch_idx, var_name)
for var_name in self.metrics
]
self.exe.train_from_dataset(program=self.main_program,
dataset=self.train_dataset,
fetch_list=self.fetch_vars,
fetch_info=fetch_info,
print_period=self.print_step,
debug=self.debug)
self.epoch_idx += 1
epoch_time = time.time() - epoch_start_time
epoch_speed = self.train_example_nums / epoch_time
self.train_statical_info["speed"].append(epoch_speed)
logger.info("fl-ps > callback_train finished")
def callback_infer(self):
fetch_info = [
"Epoch {} Var {}".format(self.epoch_idx, var_name)
for var_name in self.metrics
]
self.exe.infer_from_dataset(program=self.main_program,
dataset=self.test_dataset,
fetch_list=self.fetch_vars,
fetch_info=fetch_info,
print_period=self.print_step,
debug=self.debug)
def callback_save_model(self):
model_dir = "{}/{}".format(self.save_model_path, self.epoch_idx)
if fleet.is_first_worker() and self.save_model_path:
if is_distributed_env():
fleet.save_persistables(self.exe, model_dir) # save all params
else:
raise ValueError("it is not distributed env")
def callback_finish(self):
fleet.stop_worker()
def print_program(self):
with open("./{}_worker_main_program.prototxt".format(self.worker_index),
'w+') as f:
f.write(str(self.main_program))
with open(
"./{}_worker_startup_program.prototxt".format(
self.worker_index), 'w+') as f:
f.write(str(self.startup_program))
def print_train_statical_info(self):
with open("./train_statical_info.txt", 'w+') as f:
f.write(str(self.train_statical_info))
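# A hypothetical trainer-side driver for the FLClient defined above; the caller
# is assumed to supply the role_maker, a "runner.*" config dict, a metrics dict
# and a prepared dataset, none of which are defined in this patch.
def run_fl_trainer(role_maker, config, metrics, train_dataset, file_list, example_num):
    fl_client = FLClient()
    fl_client.set_basic_config(role_maker, config, metrics)
    fl_client.set_train_dataset_info(train_dataset, file_list)
    fl_client.set_train_example_num(example_num)
    fl_client.run()  # init worker, load dataset, enter the coordinator-driven loop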
class Coordinator(object):
def __init__(self, ps_hosts):
self._communicator = FLCommunicator(ps_hosts)
self._client_selector = None
def start_coordinator(self, self_endpoint, trainer_endpoints):
self._communicator.start_coordinator(self_endpoint, trainer_endpoints)
def make_fl_strategy(self):
logger.info("fl-ps > running make_fl_strategy(loop) in coordinator\n")
while True:
# 1. get all fl clients reported info
str_map = self._communicator.query_fl_clients_info(
) # block: wait for all fl clients info reported
# 2. generate fl strategy
self._client_selector = ClientSelector(str_map)
fl_strategy = self._client_selector.select()
# 3. save fl strategy from python to c++
self._communicator.save_fl_strategy(fl_strategy)
time.sleep(5)
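# Note (not part of this patch): the loop above is the server-side counterpart
# of FLClient.push_fl_client_info_sync()/pull_fl_strategy(); in TheOnePSRuntime
# below it is driven via _init_coordinator() followed by _make_fl_strategy() on
# the node whose TRAINING_ROLE is COORDINATOR.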
......@@ -29,6 +29,7 @@ from paddle.distributed.fleet.base.private_helper_function import wait_server_re
from paddle.distributed.fleet.proto import the_one_ps_pb2
from paddle.fluid.communicator import Communicator, HeterClient
from google.protobuf import text_format
from paddle.distributed.ps.coordinator import Coordinator
__all__ = [
'Table', 'SparseTable', 'GeoSparseTable', 'BarrierTable', 'TensorTable',
......@@ -774,6 +775,7 @@ class PsDescBuilder(object):
self.fs_client = self._get_fs_client()
self.ps_desc = the_one_ps_pb2.PSParameter()
self.fl_desc = the_one_ps_pb2.FLParameter()
def _get_tensor_tables(self):
program_idx = 0
......@@ -812,6 +814,9 @@ class PsDescBuilder(object):
def _get_fs_client(self):
return fsClient(self.context["user_defined_strategy"].fs_client_param)
def build_fl_client_desc(self, client_info):
pass
def build_worker_desc(self):
for table in self.tables:
table_proto = self.ps_desc.worker_param.downpour_worker_param.downpour_table_param.add(
......@@ -850,6 +855,7 @@ class TheOnePSRuntime(RuntimeBase):
self._communicator = None
self._server = None
self._worker = fluid.core.DistFleetWrapper()
self._coordinator = None
self._server_sub_program = []
self._heter_client = None
self._send_ctx = None
......@@ -857,6 +863,8 @@ class TheOnePSRuntime(RuntimeBase):
def _set_basic_info(self, context):
self.context = context
self.role_maker = context["role_maker"]
self.role_id = get_role_id(self.role_maker)
self.debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
self.origin_main_program = context["origin_main_program"]
self.origin_main_programs = context.get("origin_main_programs",
......@@ -878,6 +886,8 @@ class TheOnePSRuntime(RuntimeBase):
self.context['tensor_table'] = {}
build_var_distributed(self.context)
self.trainer_endpoints = get_trainer_endpoints(self.role_maker)
self.endpoints = get_ps_endpoints(self.role_maker)
self.string_hosts = []
for idx, ep in enumerate(self.endpoints):
......@@ -885,6 +895,16 @@ class TheOnePSRuntime(RuntimeBase):
pshost = fluid.core.PSHost(host, int(port), idx)
self.string_hosts.append(pshost.serialize_to_string())
self.with_coordinator = self.role_maker._with_coordinator
self.coordinator_hosts = []
if self.with_coordinator:
print("fl-ps > all ps addrs: {}".format(self.string_hosts))
coordinator_endpoints = self.role_maker._get_coordinator_endpoints()
for idx, ep in enumerate(coordinator_endpoints):
ip, port = ep.split(":")
pshost = fluid.core.PSHost(ip, int(port), idx)
self.coordinator_hosts.append(pshost.serialize_to_string())
self.ps_desc_builder = PsDescBuilder(self.context)
def _init_all_params(self, scopes, send_ctx, recv_map):
......@@ -933,8 +953,6 @@ class TheOnePSRuntime(RuntimeBase):
def _init_worker(self, scopes=None):
worker_desc = self.ps_desc_builder.build_worker_desc()
#with open("test_fl_ps_worker_desc", "w") as f:
# f.write(worker_desc)
if self.context['use_ps_gpu']:
main_program = self.context['loss'].block.program
if not main_program._fleet_opt:
......@@ -963,10 +981,8 @@ class TheOnePSRuntime(RuntimeBase):
self._send_ctx = send_ctx
trainer_config = self.context['trainer']
proto_txt = worker_desc
debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
if debug:
print("worker: \n{}".format(proto_txt))
if self.debug:
print("worker_desc: \n{}".format(worker_desc))
print("communicator send_ctx:")
for key in send_ctx:
print("{}: {}".format(key, send_ctx[key]))
......@@ -986,15 +1002,21 @@ class TheOnePSRuntime(RuntimeBase):
print("communicator config:", trainer_config.get_communicator_flags())
role_id = get_role_id(self.role_maker)
self._worker.init_worker(proto_txt, self.string_hosts, role_id)
self._worker.init_worker(worker_desc, self.string_hosts, self.role_id)
self.trainer_endpoint = get_trainer_endpoint(self.role_maker)
print("fl-ps > trainer_endpoint: {}".format(self.trainer_endpoint))
print("fl-ps > with_coordinator? {}".format(self.with_coordinator))
print("fl-ps > coordinator addr: {}".format(self.coordinator_hosts))
if self.with_coordinator:
self._worker.init_fl_worker(self.coordinator_hosts, self.role_id,
self.trainer_endpoint)
if self.context[
'ps_mode'] == DistributedMode.GEO or self.is_heter_ps_mode:
self._communicator = Communicator(
trainer_config.mode, kwargs,
trainer_config.get_communicator_flags())
self._communicator.init_with_ctx(send_ctx, dense_map, proto_txt,
self._communicator.init_with_ctx(send_ctx, dense_map, worker_desc,
self.string_hosts,
fluid.global_scope())
fleet.util.barrier()
......@@ -1002,7 +1024,8 @@ class TheOnePSRuntime(RuntimeBase):
# info = self._communicator.get_client_info()
info = self._worker.get_client_info()
if isinstance(info, list) and len(info) > 0:
all_info = self.role_maker._all_gather(info[0])
all_info = self.role_maker._all_gather(
info[0])  # gather the service addresses reported by the other clients
# for unittest
if not isinstance(all_info, list):
warnings.warn("gloo may not initialize correctly")
......@@ -1045,7 +1068,7 @@ class TheOnePSRuntime(RuntimeBase):
self._communicator.init_params(init_params)
else:
if not self.context['use_ps_gpu']:
if role_id == 0:
if self.role_id == 0:
print("entering self._init_all_params()")
self._init_all_params(scopes, send_ctx, dense_map)
......@@ -1080,21 +1103,32 @@ class TheOnePSRuntime(RuntimeBase):
next_trainers, previous_trainers,
self.role_maker._role_id()) # --> HeterClient::GetInstance
def _init_coordinator(self, scopes=None):
if self._coordinator is None:
self._coordinator = Coordinator(self.string_hosts)
print(">>> curr node ip: {}".format(self.coordinator_hosts[0]))
print(">>> all trainer endpoints: {}".format(self.trainer_endpoints))
self._coordinator.start_coordinator(self.coordinator_hosts[0],
self.trainer_endpoints)
def _make_fl_strategy(self):
if self._coordinator is None:
raise ValueError("Coordinator py object is null!")
else:
self._coordinator.make_fl_strategy()
def _init_server(self, dirname=None, var_names=None, **kwargs):
server_desc = self.ps_desc_builder.build_server_desc()
#with open("test_fl_ps_server_desc", "w") as f:
# f.write(server_desc)
role_id = get_role_id(self.role_maker)
trainers = get_trainers(self.role_maker)
if self.is_heter_ps_mode:
trainers += len(self.role_maker._get_heter_worker_endpoints())
debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
if debug:
print("server: \n{}".format(server_desc))
if self.debug:
print("server_desc: \n{}".format(server_desc))
self._server = fluid.core.DistFleetWrapper()
self._server.init_server(server_desc, self.string_hosts, role_id,
self._server.init_server(server_desc, self.string_hosts, self.role_id,
trainers, self._server_sub_program)
dist_varnames = get_sparse_tablenames(self.origin_main_programs, True)
......
......@@ -250,6 +250,10 @@ def get_trainer_endpoint(role_maker):
return role_maker._get_trainer_endpoint()
def get_trainer_endpoints(role_maker):
return role_maker._get_trainer_endpoints()
def get_previous_stage_trainers(role_maker):
try:
return role_maker._get_previous_trainers()
......@@ -1591,3 +1595,11 @@ def debug_program(file, program):
os.makedirs(os.path.dirname(file), exist_ok=True)
with open(file, 'w+') as f:
f.write(str(program))
def is_distributed_env():
node_role = os.getenv("TRAINING_ROLE")
if node_role is None:
return False
else:
return True
......@@ -34,7 +34,7 @@ It's a wrapper of a cpp class Communicator and should be used inside fleet API.
from . import core
from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
__all__ = ['Communicator', 'LargeScaleKV']
__all__ = ['Communicator', 'FLCommunicator', 'LargeScaleKV']
class Communicator(object):
......@@ -208,6 +208,37 @@ class Communicator(object):
self.communicator_.push_sparse_param(var_name, table_id, scope)
class FLCommunicator(Communicator):
def __init__(self, ps_hosts, kwargs=None):
mode = None
super(FLCommunicator, self).__init__(mode, kwargs)
send_ctx = {}
dense_map = {}
prototxt = ""
self.mode = "WITH_COORDINATOR"
self.init_with_ctx(send_ctx, dense_map, prototxt, ps_hosts)
def start_coordinator(self, self_endpoint, trainer_endpoints):
if self.communicator_ is not None:
self.communicator_.start_coordinator(self_endpoint,
trainer_endpoints)
return
def save_fl_strategy(self, mp):
if self.communicator_ is not None:
self.communicator_.save_fl_strategy(mp)
else:
raise ValueError("self.communicator_ is null")
return
def query_fl_clients_info(self):
info_mp = {}
if self.communicator_ is not None:
info_mp = self.communicator_.query_fl_clients_info()
return info_mp
class LargeScaleKV(object):
def __init__(self):
......
......@@ -1673,8 +1673,8 @@ class Executor(object):
return res
def _dump_debug_info(self, program=None, trainer=None):
with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
fout.write(str(trainer))
print("program_id: {}, trainer_desc:\n {}".format(
id(program), str(trainer)))
if program._fleet_opt and "fleet_desc" in program._fleet_opt:
with open("fleet_desc.prototxt", "w") as fout:
fout.write(str(program._fleet_opt["fleet_desc"]))
......