From 4bc22b69c8db4123dc7fb0a540940748655723eb Mon Sep 17 00:00:00 2001 From: ziyoujiyi <73728031+ziyoujiyi@users.noreply.github.com> Date: Tue, 26 Jul 2022 18:43:23 +0800 Subject: [PATCH] add horizontal federated learning ps feature (#44327) * back fl * delete ssl cert * . * make warning * . * unittest paral degree * solve unittest * heter & multi cloud comm ready * . * . * fl-ps v1.0 * . * support N + N mode * . * . * . * . * delete print * . * . * . * . * fix bug * . * . * fl-ps with coordinator ready * merge dev * update message parse only * update fl client scheduler * fix bug * update multithreads sync * fix ci errors * update role_maker.py * update role_maker.py * fix ci error: windows py import error * fix ci error: windows py import error * fix windows ci pylib import error * add dump fields & params * try to fix windows import fleet error * fix ps FLAGS error --- cmake/external/brpc.cmake | 1 - .../distributed/ps/service/CMakeLists.txt | 5 + .../distributed/ps/service/brpc_ps_client.cc | 161 +++++++- .../distributed/ps/service/brpc_ps_client.h | 75 +++- .../ps/service/communicator/CMakeLists.txt | 1 + .../ps/service/communicator/communicator.cc | 132 +++++-- .../ps/service/communicator/communicator.h | 67 +++- .../ps/service/coordinator_client.cc | 205 ++++++++++ .../ps/service/coordinator_client.h | 256 +++++++++++++ paddle/fluid/distributed/ps/service/env.h | 32 +- .../distributed/ps/service/heter_client.cc | 4 +- .../distributed/ps/service/heter_client.h | 4 +- .../distributed/ps/service/heter_server.h | 6 +- .../fluid/distributed/ps/service/ps_client.cc | 4 +- .../fluid/distributed/ps/service/ps_client.h | 6 +- .../distributed/ps/service/sendrecv.proto | 15 + paddle/fluid/distributed/ps/service/server.cc | 0 paddle/fluid/distributed/ps/wrapper/fleet.cc | 28 ++ paddle/fluid/distributed/ps/wrapper/fleet.h | 8 + paddle/fluid/distributed/the_one_ps.proto | 26 ++ paddle/fluid/framework/device_worker.cc | 7 +- .../framework/distributed_strategy.proto | 1 + paddle/fluid/framework/multi_trainer.cc | 9 +- paddle/fluid/pybind/fleet_py.cc | 12 +- python/paddle/distributed/fleet/__init__.py | 4 + .../fleet/base/distributed_strategy.py | 12 + .../distributed/fleet/base/fleet_base.py | 22 ++ .../distributed/fleet/base/role_maker.py | 53 +-- python/paddle/distributed/fleet/launch.py | 10 + .../paddle/distributed/fleet/launch_utils.py | 153 +++++++- .../fleet/meta_optimizers/ps_optimizer.py | 2 + python/paddle/distributed/ps/coordinator.py | 351 ++++++++++++++++++ python/paddle/distributed/ps/the_one_ps.py | 70 +++- python/paddle/distributed/ps/utils/public.py | 12 + python/paddle/fluid/communicator.py | 33 +- python/paddle/fluid/executor.py | 4 +- 36 files changed, 1676 insertions(+), 115 deletions(-) mode change 100644 => 100755 paddle/fluid/distributed/ps/service/brpc_ps_client.h mode change 100644 => 100755 paddle/fluid/distributed/ps/service/communicator/CMakeLists.txt mode change 100644 => 100755 paddle/fluid/distributed/ps/service/communicator/communicator.h create mode 100644 paddle/fluid/distributed/ps/service/coordinator_client.cc create mode 100644 paddle/fluid/distributed/ps/service/coordinator_client.h mode change 100644 => 100755 paddle/fluid/distributed/ps/service/env.h mode change 100644 => 100755 paddle/fluid/distributed/ps/service/heter_client.h mode change 100644 => 100755 paddle/fluid/distributed/ps/service/heter_server.h mode change 100644 => 100755 paddle/fluid/distributed/ps/service/server.cc mode change 100644 => 100755 paddle/fluid/distributed/ps/wrapper/fleet.h
mode change 100644 => 100755 paddle/fluid/distributed/the_one_ps.proto mode change 100644 => 100755 paddle/fluid/framework/multi_trainer.cc mode change 100644 => 100755 paddle/fluid/pybind/fleet_py.cc mode change 100644 => 100755 python/paddle/distributed/fleet/__init__.py mode change 100644 => 100755 python/paddle/distributed/fleet/base/role_maker.py mode change 100644 => 100755 python/paddle/distributed/fleet/launch.py mode change 100644 => 100755 python/paddle/distributed/fleet/launch_utils.py create mode 100755 python/paddle/distributed/ps/coordinator.py mode change 100644 => 100755 python/paddle/fluid/communicator.py diff --git a/cmake/external/brpc.cmake b/cmake/external/brpc.cmake index 4434e3fbed1..6ace45e11b8 100755 --- a/cmake/external/brpc.cmake +++ b/cmake/external/brpc.cmake @@ -47,7 +47,6 @@ ExternalProject_Add( ${EXTERNAL_PROJECT_LOG_ARGS} # TODO(gongwb): change to de newst repo when they changed GIT_REPOSITORY "https://github.com/wangjiawei04/brpc" - #GIT_REPOSITORY "https://github.com/ziyoujiyi/brpc" # ssl error in the previous repo(can be mannual fixed) GIT_TAG "e203afb794caf027da0f1e0776443e7d20c0c28e" PREFIX ${BRPC_PREFIX_DIR} UPDATE_COMMAND "" diff --git a/paddle/fluid/distributed/ps/service/CMakeLists.txt b/paddle/fluid/distributed/ps/service/CMakeLists.txt index 709d11f7fbb..9d87e885314 100755 --- a/paddle/fluid/distributed/ps/service/CMakeLists.txt +++ b/paddle/fluid/distributed/ps/service/CMakeLists.txt @@ -78,6 +78,10 @@ set_source_files_properties( graph_brpc_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties( graph_brpc_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + +set_source_files_properties( + coordinator_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + cc_library( brpc_utils SRCS brpc_utils.cc @@ -90,6 +94,7 @@ cc_library( cc_library( downpour_client SRCS graph_brpc_client.cc brpc_ps_client.cc ps_local_client.cc + coordinator_client.cc DEPS eigen3 table brpc_utils simple_threadpool ${RPC_DEPS}) cc_library( diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc index c9135f919cb..942d5077361 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -18,10 +18,22 @@ #include #include +#include "paddle/fluid/distributed/ps/service/coordinator_client.h" #include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/string/split.h" static const int max_port = 65535; +namespace paddle { +namespace framework { +class Scope; +class Variable; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace distributed { + DEFINE_int32(pserver_push_dense_merge_limit, 12, "limit max push_dense local merge requests"); @@ -66,16 +78,6 @@ DEFINE_int32(pserver_sparse_table_shard_num, 1000, "sparse table shard for save & load"); -namespace paddle { -namespace framework { -class Scope; -class Variable; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace distributed { - inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, uint64_t key) { @@ -101,7 +103,7 @@ void DownpourPsClientService::service( } } -// 启动client端RpcService 用于数据互发等操作 +// 启动 client 端 RpcService 用于数据互发等操作 int32_t BrpcPsClient::StartClientService() { if (_service.Configure(this, _client_id) != 0) { LOG(ERROR) @@ -122,6 +124,35 @@ int32_t BrpcPsClient::StartClientService() { _server_started = true; _env->RegistePsClient( 
butil::my_ip_cstr(), _server.listen_address().port, _client_id); + VLOG(0) << "BrpcPsClient Service addr: " << butil::my_ip_cstr() << ", " + << _server.listen_address().port << ", " << _client_id; + return 0; +} + +// Start FlClientService, used to receive data pushed by the coordinator +int32_t BrpcPsClient::StartFlClientService(const std::string &self_endpoint) { + _fl_server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE); + brpc::ServerOptions options; + if (self_endpoint.empty()) { + LOG(ERROR) << "fl-ps > fl client endpoint not set"; + return -1; + } + + if (_fl_server.Start(self_endpoint.c_str(), &options) != 0) { + VLOG(0) << "fl-ps > StartFlClientService failed. Try again."; + auto ip_port = paddle::string::Split(self_endpoint, ':'); + std::string ip = ip_port[0]; + int port = std::stoi(ip_port[1]); + std::string int_ip_port = GetIntTypeEndpoint(ip, port); + if (_fl_server.Start(int_ip_port.c_str(), &options) != 0) { + LOG(ERROR) << "fl-ps > StartFlClientService failed, ip_port= " + << int_ip_port; + return -1; + } + } else { + VLOG(0) << "fl-ps > StartFlClientService succeeded! listening on " + << self_endpoint; + } return 0; } @@ -166,6 +197,96 @@ int32_t BrpcPsClient::CreateClient2ClientConnection( return 0; } +int32_t BrpcPsClient::InitializeFlWorker(const std::string &self_endpoint) { + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.timeout_ms = FLAGS_pserver_timeout_ms; + options.connection_type = "pooled"; + options.connect_timeout_ms = + paddle::distributed::FLAGS_pserver_connect_timeout_ms; + options.max_retry = 3; + // Get the coordinator list from the environment and connect to each one + std::string coordinator_ip_port; + std::vector<PSHost> coordinator_list = _env->GetCoordinators(); + _coordinator_channels.resize(coordinator_list.size()); + for (size_t i = 0; i < coordinator_list.size(); ++i) { + coordinator_ip_port.assign(coordinator_list[i].ip.c_str()); + coordinator_ip_port.append(":"); + coordinator_ip_port.append(std::to_string(coordinator_list[i].port)); + VLOG(0) << "fl-ps > BrpcFlclient connecting to coordinator: " + << coordinator_ip_port; + for (size_t j = 0; j < _coordinator_channels[i].size(); ++j) { + _coordinator_channels[i][j].reset(new brpc::Channel()); + if (_coordinator_channels[i][j]->Init( + coordinator_ip_port.c_str(), "", &options) != 0) { + LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:" + << coordinator_ip_port << " Failed! Try again.";
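+      // brpc could not init the channel on the host:port endpoint; resolve it to a numeric ip:port form with GetIntTypeEndpoint and retry once before giving up.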
Try again."; + std::string int_ip_port = GetIntTypeEndpoint(coordinator_list[i].ip, + coordinator_list[i].port); + if (_coordinator_channels[i][j]->Init( + int_ip_port.c_str(), "", &options) != 0) { + LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:" + << int_ip_port << " Failed!"; + return -1; + } + } + } + } + StartFlClientService(self_endpoint); + VLOG(0) << "fl-ps > InitializeFlWorker finished!"; + return 0; +} + +void BrpcPsClient::PushFLClientInfoSync(const std::string &fl_client_info) { + size_t request_call_num = _coordinator_channels.size(); + FlClientBrpcClosure *closure = + new FlClientBrpcClosure(request_call_num, [request_call_num](void *done) { + auto *closure = reinterpret_cast(done); + int ret = 0; + for (size_t i = 0; i < request_call_num; i++) { + if (closure->check_response(i, PUSH_FL_CLIENT_INFO_SYNC) != 0) { + LOG(ERROR) << "fl-ps > PushFLClientInfoSync response from " + "coordinator is failed"; + ret = -1; + return; + } else { + VLOG(0) << "fl-ps > rpc service call cost time: " + << (closure->cntl(i)->latency_us() / 1000) << " ms"; + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + std::future fut = promise->get_future(); + closure->add_promise(promise); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PUSH_FL_CLIENT_INFO_SYNC); + closure->request(i)->set_client_id(_client_id); + closure->request(i)->set_str_params(fl_client_info); + brpc::Channel *rpc_channel = _coordinator_channels[0][0].get(); + if (rpc_channel == nullptr) { + LOG(ERROR) << "_coordinator_channels is null"; + return; + } + PsService_Stub rpc_stub(rpc_channel); // CoordinatorService + rpc_stub.FLService( + closure->cntl(i), closure->request(i), closure->response(i), closure); + fut.wait(); + } + VLOG(0) << "fl-ps > PushFLClientInfoSync finished, client id: " << _client_id; + return; +} + +std::string BrpcPsClient::PullFlStrategy() { + while (!_service._is_fl_strategy_ready) { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + VLOG(0) << "fl-ps > waiting for fl strategy returned from coordinator"; + } + _service._is_fl_strategy_ready = + false; // only support single thread, no need for multi-threads + return _service._fl_strategy; +} + int32_t BrpcPsClient::Initialize() { _async_call_num = 0; @@ -300,6 +421,24 @@ std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) { return data; } +int FlClientBrpcClosure::check_response(size_t request_idx, int cmd_id) { + if (_cntls[request_idx]->Failed()) { + LOG(ERROR) << "resquest cmd_id:" << cmd_id + << " failed, " + "err:" + << _cntls[request_idx]->ErrorText(); + return -1; + } + if (_responses[request_idx].err_code() != 0) { + LOG(ERROR) << "response ret bad, server_idx:" << request_idx + << "cmd_id:" << cmd_id + << " err_code:" << _responses[request_idx].err_code() + << " err_msg:" << _responses[request_idx].err_msg(); + return -1; + } + return 0; +} + std::future BrpcPsClient::PrintTableStat(uint32_t table_id) { size_t request_call_num = _server_channels.size(); DownpourBrpcClosure *closure = new DownpourBrpcClosure( diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.h b/paddle/fluid/distributed/ps/service/brpc_ps_client.h old mode 100644 new mode 100755 index 3b455a44dc0..bbaecc498a8 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.h @@ -25,6 +25,7 @@ #include "brpc/server.h" #include "paddle/fluid/distributed/ps/service/brpc_utils.h" #include 
"paddle/fluid/distributed/ps/service/ps_client.h" +#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h" #include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" @@ -56,16 +57,71 @@ class DownpourPsClientService : public PsService { _rank = rank_id; return 0; } - void service(::google::protobuf::RpcController *controller, - const PsRequestMessage *request, - PsResponseMessage *response, - ::google::protobuf::Closure *done) override; + + virtual void service(::google::protobuf::RpcController *controller, + const PsRequestMessage *request, + PsResponseMessage *response, + ::google::protobuf::Closure *done); + + virtual void FLService(::google::protobuf::RpcController *controller, + const CoordinatorReqMessage *request, + CoordinatorResMessage *response, + ::google::protobuf::Closure *done) { + brpc::ClosureGuard done_guard(done); + size_t client_id = request->client_id(); + CHECK(_client->_client_id == client_id) + << "request client id not matched self"; + _fl_strategy = request->str_params(); + _is_fl_strategy_ready = true; + response->set_err_code(0); + response->set_err_msg(""); + VLOG(0) << "fl-ps > DownpourPsClientService::FLService finished!"; + return; + } + + public: + std::string _fl_strategy; + bool _is_fl_strategy_ready = false; protected: size_t _rank; PSClient *_client; }; +class FlClientBrpcClosure : public PSClientClosure { + public: + FlClientBrpcClosure(size_t num, PSClientCallBack callback) + : PSClientClosure(callback) { + _waiting_num = num; + + _cntls.resize(num); + _requests.resize(num); + _responses.resize(num); + for (size_t i = 0; i < num; ++i) { + _cntls[i].reset(new brpc::Controller()); + } + } + virtual ~FlClientBrpcClosure() {} + void Run() override { + if (_waiting_num.fetch_sub(1) == 1) { + _callback(this); + delete this; + } + } + CoordinatorReqMessage *request(size_t i) { return &_requests[i]; } + CoordinatorResMessage *response(size_t i) { return &_responses[i]; } + brpc::Controller *cntl(size_t i) { return _cntls[i].get(); } + int check_response(size_t request_idx, int cmd_id); + int check_save_response(size_t request_idx, int cmd_id); + std::string get_response(size_t request_idx, int cmd_id); + + private: + std::atomic _waiting_num; + std::vector _requests; + std::vector _responses; + std::vector> _cntls; +}; + class DownpourBrpcClosure : public PSClientClosure { public: DownpourBrpcClosure(size_t num, PSClientCallBack callback) @@ -267,6 +323,14 @@ class BrpcPsClient : public PSClient { } int32_t Initialize() override; + // for fl + public: + virtual int32_t InitializeFlWorker(const std::string &self_endpoint); + int32_t StartFlClientService(const std::string &self_endpoint); + virtual void PushFLClientInfoSync(const std::string &fl_client_info); + std::string PullFlStrategy(); + // for fl + private: inline uint32_t DenseDimPerShard(uint32_t dense_dim_total, uint32_t shard_num) { @@ -320,6 +384,8 @@ class BrpcPsClient : public PSClient { _client_channels; // client2client std::vector, 3>> _server_channels; // client2server + std::vector, 1>> + _coordinator_channels; // client2coordinator std::future PushDenseRawGradient(int table_id, float *total_send_data, size_t total_send_data_size, @@ -360,6 +426,7 @@ class BrpcPsClient : public PSClient { float _mse = 0; uint16_t _push_times = 0; brpc::Server _server; + brpc::Server _fl_server; DownpourPsClientService _service; bool _server_started = false; std::atomic_uint grad_num_{0}; diff --git 
a/paddle/fluid/distributed/ps/service/communicator/CMakeLists.txt b/paddle/fluid/distributed/ps/service/communicator/CMakeLists.txt old mode 100644 new mode 100755 index 03244ecba7b..d6ef2970963 --- a/paddle/fluid/distributed/ps/service/communicator/CMakeLists.txt +++ b/paddle/fluid/distributed/ps/service/communicator/CMakeLists.txt @@ -1,5 +1,6 @@ get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) +set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") set_source_files_properties( communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.cc b/paddle/fluid/distributed/ps/service/communicator/communicator.cc index 0856c81121f..414bc560772 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.cc +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.cc @@ -89,7 +89,7 @@ int Communicator::SetClients(std::vector &host_sign_list) { void Communicator::RpcRecvDense(const std::vector &varnames, int table_id, - Scope *scope) { + Scope *scope) { // pserver_scope_ platform::RecordEvent record_event("Communicator->RpcRecvDense", platform::TracerEventType::Communication, 1); @@ -106,7 +106,7 @@ void Communicator::RpcRecvDense(const std::vector &varnames, float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); paddle::distributed::Region reg(temp_data, tensor->numel()); regions.emplace_back(std::move(reg)); - VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id " + VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id " << table_id << " Temp_data[0] " << temp_data[0] << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; #endif @@ -123,11 +123,11 @@ void Communicator::RpcRecvDense(const std::vector &varnames, for (auto &t : varnames) { Variable *var = scope->FindVar(t); LoDTensor *tensor = var->GetMutable(); - VLOG(3) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? " + VLOG(3) << "Communicator::RecvNoBarrier Var " << t << " On gpu? 
" << platform::is_gpu_place(tensor->place()); float *temp_recv_data = tensor->mutable_data(platform::CPUPlace()); - VLOG(3) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id " + VLOG(3) << "Communicator::RpcRecvDense Var " << t << " table_id " << table_id << " Temp_data[0] " << temp_recv_data[0] << " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1]; if (platform::is_gpu_place(tensor->place())) { @@ -136,7 +136,7 @@ void Communicator::RpcRecvDense(const std::vector &varnames, xpu_temp_scope_->FindVar(t)->GetMutable(); framework::TensorCopy(*temp_tensor, tensor->place(), tensor); float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); - VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id " + VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id " << table_id << " Temp_data[0] " << temp_data[0] << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; #endif @@ -187,7 +187,8 @@ void Communicator::RpcSendDenseParam(const std::vector &varnames, return; } -void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { +void Communicator::RpcSendDense(const CommContext &ctx, + const Scope &scope) { // delta_scope_ platform::RecordEvent record_event("Communicator->RpcSendDense", platform::TracerEventType::Communication, 1); @@ -343,21 +344,21 @@ void Communicator::RpcRecvSparse(const std::string &varname, auto dim = tensor->dims()[1]; uint64_t sparse_num = static_cast(tensor->dims()[0]); - std::vector sparse_push_keys(sparse_num); - std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0); + std::vector sparse_pull_keys(sparse_num); + std::iota(sparse_pull_keys.begin(), sparse_pull_keys.end(), 0); - std::vector push_g_vec; - for (auto i = 0; i < static_cast(sparse_push_keys.size()); ++i) { - push_g_vec.push_back(tensor->data() + i * dim); + std::vector pull_g_vec; + for (auto i = 0; i < static_cast(sparse_pull_keys.size()); ++i) { + pull_g_vec.push_back(tensor->data() + i * dim); } bool training = true; auto status = - _worker_ptr->PullSparseParam(static_cast(push_g_vec.data()), + _worker_ptr->PullSparseParam(static_cast(pull_g_vec.data()), table_id, - sparse_push_keys.data(), - sparse_push_keys.size(), + sparse_pull_keys.data(), + sparse_pull_keys.size(), training); status.wait(); return; @@ -1013,8 +1014,9 @@ void SyncCommunicator::BarrierRecv() { VLOG(4) << "BarrierRecv with SyncCommunicator"; } -void GeoCommunicator::Send(const std::vector &var_names, - const framework::Scope &scope) { +void GeoCommunicator::Send( + const std::vector &var_names, + const framework::Scope &scope) { // last op in program platform::RecordEvent record_event( "GeoCommunicator->Send", platform::TracerEventType::Communication, 1); waiting_ = false; @@ -1041,10 +1043,13 @@ void GeoCommunicator::Send(const std::vector &var_names, auto &rows = var->Get().rows(); // insert ids which has not been record - for (size_t j = 0; j < rows.size(); j++) { + // VLOG(0) << "fl-ps > table_name: " << table_name << " splited_var_nums: " << + // splited_var_nums << " rows size: " << rows.size(); + for (size_t j = 0; j < rows.size(); j++) { // batch_size == rows.size() auto ep_idx = rows[j] % splited_var_nums; ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx]) .insert(rows[j]); + // VLOG(0) << " id: " << rows[j] << " "; } for (auto &iter : ids_table) { @@ -1143,7 +1148,7 @@ void GeoCommunicator::InitDense(std::vector &varnames, } else { BarrierWithTable(1); RpcRecvDense(varnames, table_id, recv_scope_); - VLOG(1) << "pull dense param to table 
" << table_id + VLOG(1) << "pull dense param from table " << table_id << " from 0' trainer done"; } @@ -1153,7 +1158,7 @@ void GeoCommunicator::InitDense(std::vector &varnames, global_var->GetMutable(); auto *old_var = old_scope_->Var(t); old_var->GetMutable(); - framework::CopyVariable(*global_var, old_var); + framework::CopyVariable(*global_var, old_var); // src, dst // init pserver_scope_ auto *pserver_var = pserver_scope_->Var(t); pserver_var->GetMutable(); @@ -1218,7 +1223,7 @@ void GeoCommunicator::RecvDense(const CommContext &send_ctx) { // 1. recv from pserver RpcRecvDense(varnames, table_id, pserver_scope_.get()); - // 2.1 pserver - old => delta; 2.2 latest + old => latest 2.3 old => pserver + // 2.1 pserver - old => delta; 2.2 latest + delta => latest 2.3 old => pserver phi::CPUContext cpu_ctx; for (auto &varname : varnames) { auto *var_latest = recv_scope_->FindVar(varname); @@ -1267,7 +1272,7 @@ void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) { VLOG(1) << "Init Sparse " << var_name << " : table " << table_id << " done."; auto *global_var = recv_scope_->FindVar(var_name); auto *var = old_scope_->Var(var_name); - framework::CopyVariable(*global_var, var); + framework::CopyVariable(*global_var, var); // src, dst return; } @@ -1278,7 +1283,8 @@ std::vector GeoCommunicator::MergeSparseIds( 1); size_t merge_num = 0, wait_times = 0; std::unordered_set sparse_ids; - while (merge_num < static_cast(max_merge_var_num_)) { + while (merge_num < + static_cast(max_merge_var_num_)) { // -> geo_step: 100 VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num; if (sparse_id_queues_.at(send_varname)->Size() > 0) { wait_times = 0; @@ -1467,7 +1473,9 @@ void GeoCommunicator::MainThread() { for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) { // varname: emb@GRAD, param_name: emb, splited_varname: emb.delta0 auto send_recv_task = [this, table_id, ep_idx, &ctx] { - auto splited_varname = ctx.splited_varnames[ep_idx]; + auto splited_varname = + ctx.splited_varnames[ep_idx]; // embedding_0.w_0.block0 + // embedding_1.w_0.block0 auto sparse_ids = MergeSparseIds(splited_varname); SendSparse(splited_varname, sparse_ids, table_id, ep_idx); RecvSparse(splited_varname, table_id, ep_idx); @@ -1490,5 +1498,83 @@ void GeoCommunicator::MainThread() { } } +void FLCommunicator::InitBrpcClient( + const std::string &dist_desc, + const std::vector &host_sign_list) { + auto fleet = paddle::distributed::FleetWrapper::GetInstance(); + if (_worker_ptr.get() == nullptr) { + VLOG(0) << "fl-ps > FLCommunicator::InitBrpcClient get _worker_ptr"; + _worker_ptr = + fleet->worker_ptr_; // FleetWrapper::InitWorker must be excuted before, + // but no need for Coordinator + } + if (coordinator_client_ptr_ == nullptr) { + coordinator_client_ptr_.reset(new CoordinatorClient); + } + int16_t servers = host_sign_list.size(); + coordinator_client_ptr_->_env = &ps_env_; + coordinator_client_ptr_->_env->SetPsServers(&host_sign_list, servers); +} + +void FLCommunicator::StartCoordinatorClient( + const std::vector &trainer_endpoints) { + if (coordinator_client_ptr_ == nullptr) { + LOG(ERROR) << "coordinator_client_ptr_ is null"; + return; + } + coordinator_client_ptr_->Initialize(trainer_endpoints); + VLOG(0) << "fl-ps > StartCoordinatorClient finish!"; +} + +void FLCommunicator::StartCoordinatorServer() { + if (coordinator_client_ptr_ == nullptr) { + LOG(ERROR) << "coordinator_client_ptr_ is null"; + } + int ret = coordinator_client_ptr_->StartClientService(); + if (ret != 0) { + LOG(ERROR) << 
"coordinator_client_ptr_ StartClientService failed"; + } + VLOG(0) << "fl-ps > StartCoordinatorServer finished!"; + return; +} + +std::unordered_map FLCommunicator::QueryFLClientsInfo() { + return coordinator_client_ptr_->QueryFLClientsInfo(); +} + +void FLCommunicator::SaveFLStrategy( + const std::unordered_map &fl_strategy) { + coordinator_client_ptr_->SaveFLStrategy(fl_strategy); + return; +} + +void FLCommunicator::SendThreadAsync() { + while (is_running_) { + RpcSendFLStrategy(); + } + return; +} + +void FLCommunicator::RpcSendFLStrategy() { + std::set clients = coordinator_client_ptr_->GetFLClientIds(); + coordinator_client_ptr_->WaitForFLStrategyReady(); + for (auto client_id : clients) { + coordinator_client_ptr_->SendFLStrategy(client_id); + } + coordinator_client_ptr_->ResetFLStrategyFlag(); + VLOG(0) << "fl-ps > RpcSendFLStrategy finished!"; + return; +} + +void FLCommunicator::StartCoordinator( + const std::string &self_endpoint, + const std::vector &trainer_endpoints) { + coordinator_client_ptr_->SetEndpoint(self_endpoint); + StartCoordinatorClient(trainer_endpoints); + StartCoordinatorServer(); + async_send_thread_.reset( + new std::thread(&FLCommunicator::SendThreadAsync, this)); +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.h b/paddle/fluid/distributed/ps/service/communicator/communicator.h old mode 100644 new mode 100755 index f08208ed02d..5af035d5dcf --- a/paddle/fluid/distributed/ps/service/communicator/communicator.h +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.h @@ -31,6 +31,7 @@ limitations under the License. */ #include "gflags/gflags.h" #include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" +#include "paddle/fluid/distributed/ps/service/coordinator_client.h" #include "paddle/fluid/distributed/ps/service/ps_client.h" #include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/scope.h" @@ -241,9 +242,11 @@ class Communicator { envs[iter.first] = iter.second; VLOG(3) << iter.first << ": " << iter.second; } - barrier_table_id_ = std::stoi(envs.at("barrier_table_id")); - trainer_id_ = std::stoi(envs.at("trainer_id")); - trainers_ = std::stoi(envs.at("trainers")); + if (!envs.empty()) { + barrier_table_id_ = std::stoi(envs.at("barrier_table_id")); + trainer_id_ = std::stoi(envs.at("trainer_id")); + trainers_ = std::stoi(envs.at("trainers")); + } } virtual void InitBrpcClient(const std::string &dist_desc, @@ -280,6 +283,15 @@ class Communicator { int batches, Scope *send_scope); + virtual std::unordered_map QueryFLClientsInfo() { + return {}; + } + virtual void SaveFLStrategy( + const std::unordered_map &fl_strategy) {} + virtual void StartCoordinator( + const std::string &self_endpoint, + const std::vector &trainer_endpoints) {} + virtual ~Communicator() {} virtual void RpcProfilerControl(); @@ -376,10 +388,6 @@ class Communicator { PSClient *GetPsClient() { return _worker_ptr.get(); } - std::shared_ptr GetPsClientPtr() { - return std::move(_worker_ptr); - } - RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; } std::shared_ptr _worker_ptr; // pointer to worker @@ -657,5 +665,50 @@ class GeoCommunicator : public AsyncCommunicator { sparse_id_queues_; }; +class FLCommunicator : public GeoCommunicator { + public: + FLCommunicator() : GeoCommunicator() {} + + ~FLCommunicator() { + is_running_ = false; + async_send_thread_->join(); + } + + explicit FLCommunicator(const std::map &envs) + : GeoCommunicator(envs) {} + + 
void InitEnvs() override {} + + virtual void InitBrpcClient(const std::string &dist_desc, + const std::vector &host_sign_list); + + void InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) override {} + + void StartCoordinatorClient( + const std::vector &trainer_endpoints); + + void StartCoordinatorServer(); + + void StartCoordinator( + const std::string &self_endpoint, + const std::vector &trainer_endpoints) override; + + std::unordered_map QueryFLClientsInfo(); + void SaveFLStrategy( + const std::unordered_map &fl_strategy); + + void SendThreadAsync(); + void RpcSendFLStrategy(); + + private: + int thread_pool_size_ = 1; + bool is_running_ = true; + PaddlePSEnvironment ps_env_; + std::shared_ptr coordinator_client_ptr_{nullptr}; + std::unique_ptr async_send_thread_{nullptr}; +}; + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/service/coordinator_client.cc b/paddle/fluid/distributed/ps/service/coordinator_client.cc new file mode 100644 index 00000000000..7d48520118d --- /dev/null +++ b/paddle/fluid/distributed/ps/service/coordinator_client.cc @@ -0,0 +1,205 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
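+// coordinator_client.{h,cc} implement the coordinator side of fl-ps: CoordinatorService handles the FLService RPCs pushed by fl clients (PUSH_FL_CLIENT_INFO_SYNC), and CoordinatorClient dials back to the pservers and fl clients so the saved per-client strategies can be pushed out with PUSH_FL_STRATEGY.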
+ +#include "paddle/fluid/distributed/ps/service/coordinator_client.h" + +#include +#include +#include + +#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/string/split.h" + +static const int MIN_PORT = 8500; +static const int MAX_PORT = 65535; + +namespace paddle { +namespace distributed { + +DEFINE_uint64(total_fl_client_size, 100, "supported total fl client size"); +DEFINE_uint32(coordinator_wait_all_clients_max_time, 60, "uint32: s"); + +void CoordinatorService::FLService( + ::google::protobuf::RpcController* controller, + const CoordinatorReqMessage* request, + CoordinatorResMessage* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + response->set_err_code(0); + response->set_err_msg(""); + brpc::Controller* cntl = static_cast(controller); + int32_t msg_type = request->cmd_id(); + uint32_t from_client_id = request->client_id(); + VLOG(0) << "fl-ps > recv from client id: " << from_client_id + << ", msg_type: " << msg_type; + // TODO(ziyoujiyi): find is not thread safe, beacuse of RB_Tree traversal + auto itr = _service_handle_map.find(msg_type); + if (itr == _service_handle_map.end()) { + LOG(ERROR) << "fl-ps > unknown flClient2Coordinator msg type: " << msg_type; + return; + } + int ret = itr->second(*request, response, cntl); // SaveFLClientInfo + if (ret != 0) { + response->set_err_code(-1); + response->set_err_msg("fl-ps > handle flClient2Coordinator msg failed"); + } + return; +} + +int32_t CoordinatorClient::Initialize( + const std::vector& trainer_endpoints) { + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.timeout_ms = paddle::distributed::FLAGS_pserver_timeout_ms; + options.connection_type = "pooled"; + options.connect_timeout_ms = + paddle::distributed::FLAGS_pserver_connect_timeout_ms; + options.max_retry = 3; + + std::string server_ip_port; + + // 获取 Pserver 列表,并连接 + if (_env == nullptr) { + LOG(ERROR) << "_env is null in CoordinatorClient::Initialize()"; + return -1; + } + std::vector pserver_list = _env->GetPsServers(); + + _pserver_channels.resize(pserver_list.size()); + for (size_t i = 0; i < pserver_list.size(); ++i) { + server_ip_port.assign(pserver_list[i].ip.c_str()); + server_ip_port.append(":"); + server_ip_port.append(std::to_string(pserver_list[i].port)); + for (size_t j = 0; j < _pserver_channels[i].size(); ++j) { + _pserver_channels[i][j].reset(new brpc::Channel()); + if (_pserver_channels[i][j]->Init(server_ip_port.c_str(), "", &options) != + 0) { + LOG(ERROR) << "CoordinatorClient connect to PServer:" << server_ip_port + << " Failed! 
Try again."; + std::string int_ip_port = + GetIntTypeEndpoint(pserver_list[i].ip, pserver_list[i].port); + if (_pserver_channels[i][j]->Init(int_ip_port.c_str(), "", &options) != + 0) { + LOG(ERROR) << "CoordinatorClient connect to PServer:" << int_ip_port + << " Failed!"; + return -1; + } + } + } + } + + // 获取 fl_client 列表,并连接 + std::vector fl_client_list; + fl_client_list.resize(trainer_endpoints.size()); + if (fl_client_list.empty()) { + LOG(ERROR) << ">>> fl clients addr info lost"; + return -1; + } + for (size_t i = 0; i < trainer_endpoints.size(); i++) { + std::vector addr = + paddle::string::Split(trainer_endpoints[i], ':'); + fl_client_list[i].ip = addr[0]; + fl_client_list[i].port = std::stol(addr[1]); + fl_client_list[i].rank = i; // TO CHECK + } + std::string fl_client_ip_port; + for (size_t i = 0; i < fl_client_list.size(); ++i) { + fl_client_ip_port.assign(fl_client_list[i].ip); + fl_client_ip_port.append(":"); + fl_client_ip_port.append(std::to_string(fl_client_list[i].port)); + uint32_t rank = fl_client_list[i].rank; + VLOG(0) << "fl-ps > coordinator connect to fl_client: " << rank; + _fl_client_channels[rank].reset(new brpc::Channel()); + if (_fl_client_channels[rank]->Init( + fl_client_ip_port.c_str(), "", &options) != 0) { + LOG(ERROR) << "CoordinatorClient connect to FLClient:" + << fl_client_ip_port << " Failed! Try again."; + std::string int_ip_port = + GetIntTypeEndpoint(fl_client_list[i].ip, fl_client_list[i].port); + if (_fl_client_channels[rank]->Init(int_ip_port.c_str(), "", &options) != + 0) { + LOG(ERROR) << "CoordinatorClient connect to PSClient:" << int_ip_port + << " Failed!"; + return -1; + } + } + } + + SetTotalFLClientsNum(fl_client_list.size()); + SetDefaultFLStrategy(); + return 0; +} + +int32_t CoordinatorClient::StartClientService() { + _service.Initialize(); + + _server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE); + brpc::ServerOptions options; + options.num_threads = 1; + if (_endpoint.empty()) { + LOG(ERROR) << "fl-ps > coordinator server endpoint not set"; + return -1; + } + auto addr = paddle::string::Split(_endpoint, ':'); + std::string ip = addr[0]; + std::string port = addr[1]; + std::string rank = addr[2]; + std::string ip_port = ip + ":" + port; + if (_server.Start(ip_port.c_str(), &options) != 0) { + LOG(ERROR) << "fl-ps > StartClientService failed"; + return -1; + } + uint32_t port_ = std::stol(port); + int32_t rank_ = std::stoi(rank); + _env->RegisteCoordinatorClient(ip, port_, rank_); + VLOG(0) << "fl-ps > coordinator service addr: " << ip << ", " << port << ", " + << _coordinator_id; + return 0; +} + +void CoordinatorClient::SendFLStrategy(const uint32_t& client_id) { + size_t request_call_num = 1; + FlClientBrpcClosure* closure = + new FlClientBrpcClosure(request_call_num, [](void* done) { + auto* closure = reinterpret_cast(done); + int ret = 0; + if (closure->check_response(0, PUSH_FL_STRATEGY) != 0) { + LOG(ERROR) << "fl-ps > SendFLStrategy failed"; + ret = -1; + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + std::future fut = promise->get_future(); + closure->add_promise(promise); + closure->request(0)->set_cmd_id(PUSH_FL_STRATEGY); + closure->request(0)->set_client_id(client_id); + std::string fl_strategy = _fl_strategy_mp[client_id]; + closure->request(0)->set_str_params(fl_strategy); + brpc::Channel* rpc_channel = _fl_client_channels[client_id].get(); + if (rpc_channel == nullptr) { + LOG(ERROR) << "fl-ps > _fl_client_channels is null"; + return; + } + PsService_Stub 
rpc_stub(rpc_channel); // DownpourPsClientService + rpc_stub.FLService( + closure->cntl(0), closure->request(0), closure->response(0), closure); + fut.wait(); + VLOG(0) << "fl-ps > SendFLStrategy to client: " << client_id << " finished"; + return; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/ps/service/coordinator_client.h b/paddle/fluid/distributed/ps/service/coordinator_client.h new file mode 100644 index 00000000000..883799fe500 --- /dev/null +++ b/paddle/fluid/distributed/ps/service/coordinator_client.h @@ -0,0 +1,256 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include +#include + +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/ps/service/brpc_utils.h" +#include "paddle/fluid/distributed/ps/service/ps_client.h" +#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h" +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" + +namespace paddle { +namespace distributed { + +DECLARE_int32(pserver_timeout_ms); +DECLARE_int32(pserver_connect_timeout_ms); +DECLARE_uint64(total_fl_client_size); +DECLARE_uint32(coordinator_wait_all_clients_max_time); + +using CoordinatorServiceFunc = + std::function; + +class ClientReportedInfo { + public: + ClientReportedInfo() {} + ~ClientReportedInfo() {} + uint32_t client_id; + uint32_t iteration_idx; + double auc = 0.0; +}; + +class CoordinatorServiceHandle { + public: + CoordinatorServiceHandle() {} + + virtual ~CoordinatorServiceHandle() {} + + void SaveFLClientInfo(const CoordinatorReqMessage& request) { + auto client_id = request.client_id(); + const std::string& str_params = request.str_params(); + // each client is allowed to send empty message to maintain heartbeat(i.e. + // use staleness msg) + std::unique_lock lck(_mtx); + if (str_params.size() != 0) { + _client_info_mp[client_id] = str_params; + } else { + LOG(INFO) << "fl-ps > content in request from " << client_id + << " is null"; + } + fl_client_ids.insert(client_id); + _fl_clients_count++; + // TODO(ziyoujiyi): how to process when a client loss connection? 
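+    // Once every client counted in the previous round has reported, mark the collection complete and wake the thread blocked in QueryFLClientsInfo().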
+ if (_fl_clients_count.load() == last_round_total_fl_clients_num) { + _is_all_clients_info_collected = true; + _cv.notify_one(); + } + lck.unlock(); + VLOG(0) << "last_round_total_fl_clients_num: " + << last_round_total_fl_clients_num + << ", has recved fl client num: " << _fl_clients_count.load(); + return; + } + + std::unordered_map QueryFLClientsInfo() { + platform::Timer timeline; + double query_wait_time = 0.0; + timeline.Start(); + auto f = [&]() -> bool { + while (query_wait_time < + paddle::distributed:: + FLAGS_coordinator_wait_all_clients_max_time) { // in case that + // some + // clients down + if (_is_all_clients_info_collected == true) { + // LOG(INFO) << "fl-ps > _is_all_clients_info_collected"; + return true; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + timeline.Pause(); + query_wait_time += timeline.ElapsedSec(); + } + // LOG(WARNNING) << "fl-ps > query_wait_time exceed!"; + return true; + }; + + std::unique_lock lck(_mtx); + _cv.wait(lck, f); + lck.unlock(); + + _is_all_clients_info_collected = false; + _fl_clients_count.store(0); + return _client_info_mp; + } + + public: + std::unordered_map _client_info_mp; + std::set fl_client_ids; + uint32_t last_round_total_fl_clients_num = 0; + bool _is_all_clients_info_collected = false; + + private: + std::mutex _mtx; + std::condition_variable _cv; + std::atomic _fl_clients_count{0}; +}; + +class CoordinatorService : public PsService { + public: + CoordinatorService() { + _coordinator_service_handle = std::make_shared(); + } + + virtual ~CoordinatorService() {} + + virtual void Initialize() { + _service_handle_map[PUSH_FL_CLIENT_INFO_SYNC] = + std::bind(&CoordinatorService::SaveFLClientInfo, + this, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3); + } + + virtual void FLService(::google::protobuf::RpcController* controller, + const CoordinatorReqMessage* request, + CoordinatorResMessage* response, + ::google::protobuf::Closure* done); + + int32_t SaveFLClientInfo(const CoordinatorReqMessage& request, + CoordinatorResMessage* response, + brpc::Controller* cntl) { + _coordinator_service_handle->SaveFLClientInfo(request); + return 0; + } + + void SetTotalFLClientsNum(uint32_t all_fl_clients_num) { + if (_coordinator_service_handle.get() != nullptr) { + _coordinator_service_handle->last_round_total_fl_clients_num = + all_fl_clients_num; + } else { + LOG(ERROR) << "fl-ps > _coordinator_service_handle is null in " + "CoordinatorService"; + } + return; + } + + std::set GetFLClientIds() { + return _coordinator_service_handle->fl_client_ids; + } + + std::unordered_map QueryFLClientsInfo() { + return _coordinator_service_handle->QueryFLClientsInfo(); + } + + private: + std::shared_ptr _coordinator_service_handle; + std::unordered_map _service_handle_map; + std::mutex _mtx; +}; + +class CoordinatorClient : public BrpcPsClient { + public: + CoordinatorClient() : _coordinator_id(0) {} + + virtual ~CoordinatorClient() {} + + int32_t Initialize(const std::vector& trainer_endpoints); + + void SetTotalFLClientsNum(uint32_t all_fl_clients_num) { + _service.SetTotalFLClientsNum(all_fl_clients_num); + this->_total_clients_num = all_fl_clients_num; + return; + } + + int32_t StartClientService(); + + void SaveFLStrategy( + const std::unordered_map& fl_strategy) { + for (auto it = fl_strategy.begin(); it != fl_strategy.end(); it++) { + uint32_t client_id = it->first; + _fl_strategy_mp[client_id] = it->second; + } + std::unique_lock lck(_mtx); + _is_fl_strategy_ready = true; + _cv.notify_all(); + 
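+    // notify_all() wakes WaitForFLStrategyReady() in the coordinator's async send thread (SendThreadAsync), which then pushes each per-client strategy via SendFLStrategy().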
return; + } + + void WaitForFLStrategyReady() { + std::unique_lock lck(_mtx); + _cv.wait(lck, [=]() { return _is_fl_strategy_ready; }); + } + + void SendFLStrategy(const uint32_t& client_id); + + void ResetFLStrategyFlag() { _is_fl_strategy_ready = false; } + + void SetDefaultFLStrategy() { + for (size_t i = 0; i < _total_clients_num; i++) { + _fl_strategy_mp[i] = ""; + } + return; + } + + std::set GetFLClientIds() { return _service.GetFLClientIds(); } + + std::unordered_map QueryFLClientsInfo() { + return _service.QueryFLClientsInfo(); + } + + void SetEndpoint(const std::string& endpoint) { + _endpoint = std::move(endpoint); + } + + public: + size_t _coordinator_id; + uint32_t _total_clients_num; + std::string _endpoint; + std::vector, 1>> + _pserver_channels; // coordinator2pserver + std::unordered_map> + _fl_client_channels; // coordinator2psclient + brpc::Server _server; + CoordinatorService _service; + std::unordered_map _fl_strategy_mp; + bool _is_fl_strategy_ready = false; + std::mutex _mtx; + std::condition_variable _cv; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/ps/service/env.h b/paddle/fluid/distributed/ps/service/env.h old mode 100644 new mode 100755 index 7b730c7ffbd..8e97e2126c2 --- a/paddle/fluid/distributed/ps/service/env.h +++ b/paddle/fluid/distributed/ps/service/env.h @@ -65,7 +65,7 @@ struct PSHost { s << "host: " << ip; s << " port: " << port; s << " rank: " << rank; - s << " uint: " << SerializeToUint64(); + s << " uint64: " << SerializeToUint64(); return s.str(); } @@ -130,6 +130,7 @@ class PSEnvironment { virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) { return 0; } + virtual uint64_t GetLocalHostSign() { return 0; } virtual std::vector GetPsServers() const { return _ps_server_list; } virtual int32_t RegistePsServer(const std::string &ip, @@ -145,6 +146,16 @@ class PSEnvironment { return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set); } + virtual std::vector GetCoordinators() const { + return _coordinator_list; + } + virtual int32_t RegisteCoordinatorClient(const std::string &ip, + uint32_t port, + int32_t rank) { + return RegistePsHost( + ip, port, rank, _coordinator_list, _coordinator_sign_set); + } + virtual std::vector GetClientInfo() { std::vector client_info; for (auto &i : _ps_client_list) { @@ -196,6 +207,9 @@ class PSEnvironment { std::vector _ps_server_list; std::unordered_set _ps_server_sign_set; // for unique filter + + std::vector _coordinator_list; + std::unordered_set _coordinator_sign_set; }; class PaddlePSEnvironment : public PSEnvironment { @@ -278,6 +292,22 @@ class PaddlePSEnvironment : public PSEnvironment { return 0; } + virtual void SetCoordinators(const std::vector *host_sign_list, + size_t node_num) { + _coordinator_list.clear(); + _coordinator_sign_set.clear(); + for (size_t i = 0; i < node_num; ++i) { + if (host_sign_list->at(i) != "") { + PSHost host; + host.ParseFromString(host_sign_list->at(i)); + _coordinator_list.push_back(host); + _coordinator_sign_set.insert(host.rank); + VLOG(0) << "fl-ps > coordinator info in env: " << host.ToString(); + } + } + return; + } + virtual uint64_t GetLocalHostSign() { if (_ps_client_list.size() > 0) { return _ps_client_list[0].SerializeToUint64(); diff --git a/paddle/fluid/distributed/ps/service/heter_client.cc b/paddle/fluid/distributed/ps/service/heter_client.cc index 89e267093e2..91a20a432a3 100644 --- a/paddle/fluid/distributed/ps/service/heter_client.cc +++ 
b/paddle/fluid/distributed/ps/service/heter_client.cc @@ -17,11 +17,11 @@ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/platform/profiler.h" +namespace paddle { +namespace distributed { DEFINE_int32(heter_world_size, 100, "group size"); // group max size DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s"); -namespace paddle { -namespace distributed { std::shared_ptr HeterClient::s_instance_ = nullptr; std::mutex HeterClient::mtx_; std::shared_ptr HeterClient::switch_s_instance_ = nullptr; diff --git a/paddle/fluid/distributed/ps/service/heter_client.h b/paddle/fluid/distributed/ps/service/heter_client.h old mode 100644 new mode 100755 index 40423b24cfe..84fbee44043 --- a/paddle/fluid/distributed/ps/service/heter_client.h +++ b/paddle/fluid/distributed/ps/service/heter_client.h @@ -39,10 +39,10 @@ namespace framework { class Scope; } // namespace framework } // namespace paddle -DECLARE_int32(pserver_timeout_ms); + namespace paddle { namespace distributed { - +DECLARE_int32(pserver_timeout_ms); using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; using VarMsg = ::paddle::distributed::VariableMessage; diff --git a/paddle/fluid/distributed/ps/service/heter_server.h b/paddle/fluid/distributed/ps/service/heter_server.h old mode 100644 new mode 100755 index 915a60bbac9..7983d375e6a --- a/paddle/fluid/distributed/ps/service/heter_server.h +++ b/paddle/fluid/distributed/ps/service/heter_server.h @@ -52,14 +52,14 @@ class ProgramDesc; class Scope; } // namespace framework } // namespace paddle - DECLARE_double(eager_delete_tensor_gb); +namespace paddle { +namespace distributed { + DECLARE_int32(pserver_timeout_ms); DECLARE_int32(heter_world_size); DECLARE_int32(switch_send_recv_timeout_s); -namespace paddle { -namespace distributed { using MultiVarMsg = MultiVariableMessage; using VarMsg = VariableMessage; diff --git a/paddle/fluid/distributed/ps/service/ps_client.cc b/paddle/fluid/distributed/ps/service/ps_client.cc index 15fca2b1b64..5da600ab925 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_client.cc @@ -16,6 +16,7 @@ #include "glog/logging.h" #include "paddle/fluid/distributed/ps/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/ps/service/coordinator_client.h" #include "paddle/fluid/distributed/ps/service/graph_brpc_client.h" #include "paddle/fluid/distributed/ps/service/ps_local_client.h" #include "paddle/fluid/distributed/ps/table/table.h" @@ -25,8 +26,9 @@ namespace distributed { REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient); REGISTER_PSCORE_CLASS(PSClient, PsLocalClient); REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient); +REGISTER_PSCORE_CLASS(PSClient, CoordinatorClient); -int32_t PSClient::Configure( +int32_t PSClient::Configure( // called in FleetWrapper::InitWorker const PSParameter &config, const std::map> ®ions, PSEnvironment &env, diff --git a/paddle/fluid/distributed/ps/service/ps_client.h b/paddle/fluid/distributed/ps/service/ps_client.h index b9a6aa0390f..5654669d76f 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.h +++ b/paddle/fluid/distributed/ps/service/ps_client.h @@ -321,14 +321,16 @@ class PSClient { protected: virtual int32_t Initialize() = 0; - size_t _client_id; PSParameter _config; std::map> _dense_pull_regions; - PSEnvironment *_env; std::unordered_map> _table_accessors; std::unordered_map _msg_handler_map; // 处理client2client消息 + + public: + size_t _client_id; + PSEnvironment *_env; }; template diff --git 
a/paddle/fluid/distributed/ps/service/sendrecv.proto b/paddle/fluid/distributed/ps/service/sendrecv.proto index 57919b6a706..076cd2250a8 100755 --- a/paddle/fluid/distributed/ps/service/sendrecv.proto +++ b/paddle/fluid/distributed/ps/service/sendrecv.proto @@ -69,6 +69,8 @@ enum PsCmdID { PS_CHECK_SAVE_PRE_PATCH_DONE = 48; // pserver2pserver cmd start from 100 PS_S2S_MSG = 101; + PUSH_FL_CLIENT_INFO_SYNC = 200; + PUSH_FL_STRATEGY = 201; } message PsRequestMessage { @@ -85,6 +87,18 @@ message PsResponseMessage { optional bytes data = 3; }; +message CoordinatorReqMessage { + required uint32 cmd_id = 1; + optional int32 client_id = 2; + optional string str_params = 3; +}; + +message CoordinatorResMessage { + required int32 err_code = 1 [ default = 0 ]; + required string err_msg = 2 [ default = "" ]; + optional string str_params = 3; +}; + enum VarType { LOD_TENSOR = 0; SELECTED_ROWS = 1; @@ -134,6 +148,7 @@ message MultiVariableMessage { service PsService { rpc service(PsRequestMessage) returns (PsResponseMessage); + rpc FLService(CoordinatorReqMessage) returns (CoordinatorResMessage); rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage); rpc SendToWorker(MultiVariableMessage) returns (PsResponseMessage); rpc SendToSwitch(MultiVariableMessage) returns (PsResponseMessage); diff --git a/paddle/fluid/distributed/ps/service/server.cc b/paddle/fluid/distributed/ps/service/server.cc old mode 100644 new mode 100755 diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc index 375ea05fc32..5df74883f92 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.cc +++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc @@ -146,6 +146,34 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, } } +void FleetWrapper::InitFlWorker(const std::vector& host_list, + int index, + const std::string& self_endpoint) { + assert(worker_ptr_.get() != nullptr); + uint32_t coordinator_num = host_list.size(); + ps_env_.SetCoordinators(&host_list, coordinator_num); + auto ptr = dynamic_cast(worker_ptr_.get()); + ptr->InitializeFlWorker(self_endpoint); + return; +} + +void FleetWrapper::PushFLClientInfoSync(const std::string& fl_client_info) { + // FLClientInfo fci; + // google::protobuf::TextFormat::ParseFromString(fl_client_info, &fci); + // InitGFlag(fci.init_gflags()); + auto ptr = dynamic_cast(worker_ptr_.get()); + VLOG(0) << "fl-ps > PushFLClientInfoSync: " << typeid(worker_ptr_).name() + << ", " << typeid(ptr).name() << ", " << typeid(BrpcPsClient).name(); + ptr->PushFLClientInfoSync(fl_client_info); + return; +} + +std::string FleetWrapper::PullFlStrategy() { + auto ptr = dynamic_cast(worker_ptr_.get()); + std::string str = ptr->PullFlStrategy(); + return str; +} + void FleetWrapper::StopServer() { VLOG(3) << "Going to stop server"; auto status = worker_ptr_->StopServer(); diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.h b/paddle/fluid/distributed/ps/wrapper/fleet.h old mode 100644 new mode 100755 index 2d3a371b4e7..28347b35027 --- a/paddle/fluid/distributed/ps/wrapper/fleet.h +++ b/paddle/fluid/distributed/ps/wrapper/fleet.h @@ -305,6 +305,14 @@ class FleetWrapper { void Revert(); void CheckSavePrePatchDone(); + //********* for fl-coordinator + void InitFlWorker(const std::vector& host_list, + int index, + const std::string& self_endpoint); + void PushFLClientInfoSync(const std::string& fl_client_info); + std::string PullFlStrategy(); + //********** + static std::shared_ptr pserver_ptr_; static std::shared_ptr worker_ptr_; diff --git 
a/paddle/fluid/distributed/the_one_ps.proto b/paddle/fluid/distributed/the_one_ps.proto old mode 100644 new mode 100755 index 38c9f4d5eb3..44c1ef3121c --- a/paddle/fluid/distributed/the_one_ps.proto +++ b/paddle/fluid/distributed/the_one_ps.proto @@ -239,3 +239,29 @@ message GraphFeature { repeated string dtype = 2; repeated int32 shape = 3; } + +message FLParameter { + optional FLStrategy fl_strategy = 1; + optional FLClientInfo client_info = 2; +} + +message FLStrategy { + optional uint64 iteration_num = 1; + optional uint64 client_id = 2; + optional string next_state = 3 [default = "JOIN"]; + optional string init_gflags = 4 [ default = "" ]; +} + +message FLClientInfo { + optional uint32 client_id = 1; + optional string device_type = 2; + optional int32 compute_capacity = 3; + optional int32 bandwidth = 4; + optional LocalTrainingResult local_training_result = 5; + optional string init_gflags = 6 [ default = "" ]; +} + +message LocalTrainingResult { + optional double acc = 1; + optional double loss = 2; +} diff --git a/paddle/fluid/framework/device_worker.cc b/paddle/fluid/framework/device_worker.cc index f1e5eb389b7..ae593542fb7 100644 --- a/paddle/fluid/framework/device_worker.cc +++ b/paddle/fluid/framework/device_worker.cc @@ -163,13 +163,13 @@ void DeviceWorker::DumpField(const Scope& scope, continue; } hit[i] = true; - } + } // dump_mode = 0 for (size_t i = 0; i < ins_id_vec.size(); i++) { if (!hit[i]) { continue; } ars[i] += ins_id_vec[i]; - ars[i] = ars[i] + "\t" + ins_content_vec[i]; + ars[i] += "\t" + ins_content_vec[i]; } for (auto& field : *dump_fields_) { Variable* var = scope.FindVar(field); @@ -202,8 +202,7 @@ void DeviceWorker::DumpField(const Scope& scope, continue; } auto bound = GetTensorBound(tensor, i); - ars[i] = ars[i] + "\t" + field + ":" + - std::to_string(bound.second - bound.first); + ars[i] += "\t" + field + ":" + std::to_string(bound.second - bound.first); ars[i] += PrintLodTensor(tensor, bound.first, bound.second); } } diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 45758389c54..832d91d131a 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -316,6 +316,7 @@ message DistributedStrategy { optional bool auto_search = 37 [ default = false ]; optional bool heter_ccl_mode = 38 [ default = false ]; optional bool is_fl_ps_mode = 39 [ default = false ]; + optional bool with_coordinator = 40 [ default = false ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc old mode 100644 new mode 100755 index c2c05f373c2..11afe6f280e --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -254,7 +254,6 @@ void MultiTrainer::Finalize() { if (need_dump_field_ || need_dump_param_) { FinalizeDumpEnv(); } - for (size_t i = 0; i < need_merge_var_names_.size(); i++) { Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]); if (root_var == nullptr) { @@ -297,8 +296,12 @@ void MultiTrainer::Finalize() { if (communicator == nullptr) { VLOG(0) << "MultiTrainer::Finalize communicator is null!"; } else { - communicator->_worker_ptr->Flush(); - VLOG(1) << "MultiTrainer::Finalize ps client flush done"; + if (communicator->_worker_ptr != nullptr) { + communicator->_worker_ptr->Flush(); + VLOG(1) << "MultiTrainer::Finalize ps client flush done"; + } else { + VLOG(0) << 
"communicator->_worker_ptr is null"; + } } #endif root_scope_->DropKids(); diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc old mode 100644 new mode 100755 index 03de3520959..f8501efde05 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -76,6 +76,9 @@ void BindDistFleetWrapper(py::module* m) { .def("get_cache_threshold", &FleetWrapper::GetCacheThreshold) .def("cache_shuffle", &FleetWrapper::CacheShuffle) .def("save_cache", &FleetWrapper::SaveCache) + .def("init_fl_worker", &FleetWrapper::InitFlWorker) + .def("push_fl_client_info_sync", &FleetWrapper::PushFLClientInfoSync) + .def("pull_fl_strategy", &FleetWrapper::PullFlStrategy) .def("revert", &FleetWrapper::Revert) .def("check_save_pre_patch_done", &FleetWrapper::CheckSavePrePatchDone); } @@ -131,6 +134,7 @@ void BindCommunicatorContext(py::module* m) { } using paddle::distributed::AsyncCommunicator; +using paddle::distributed::FLCommunicator; using paddle::distributed::GeoCommunicator; using paddle::distributed::RecvCtxMap; using paddle::distributed::RpcCtxMap; @@ -157,6 +161,9 @@ void BindDistCommunicator(py::module* m) { } else if (mode == "GEO") { Communicator::InitInstance( send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs); + } else if (mode == "WITH_COORDINATOR") { + Communicator::InitInstance( + send_ctx, recv_ctx, dist_desc, host_sign_list, param_scope, envs); } else { PADDLE_THROW(platform::errors::InvalidArgument( "unsuported communicator MODE")); @@ -172,7 +179,10 @@ void BindDistCommunicator(py::module* m) { .def("create_client_to_client_connection", &Communicator::CreateC2CConnection) .def("get_client_info", &Communicator::GetClientInfo) - .def("set_clients", &Communicator::SetClients); + .def("set_clients", &Communicator::SetClients) + .def("start_coordinator", &Communicator::StartCoordinator) + .def("query_fl_clients_info", &Communicator::QueryFLClientsInfo) + .def("save_fl_strategy", &Communicator::SaveFLStrategy); } void BindHeterClient(py::module* m) { diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py old mode 100644 new mode 100755 index 8c0394c9944..0cfb946d3d8 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -57,6 +57,10 @@ world_device_ids = fleet.world_device_ids local_rank = fleet.local_rank rank_in_node = local_rank is_worker = fleet.is_worker +is_coordinator = fleet.is_coordinator +init_coordinator = fleet.init_coordinator +make_fl_strategy = fleet.make_fl_strategy +get_fl_client = fleet.get_fl_client worker_endpoints = fleet.worker_endpoints server_num = fleet.server_num server_index = fleet.server_index diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index c58b539b687..6f8e2926abe 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -1348,6 +1348,18 @@ class DistributedStrategy(object): else: print("WARNING: is_fl_ps_mode should have value of bool type") + @property + def is_with_coordinator(self): + return self.strategy.with_coordinator + + @is_with_coordinator.setter + @is_strict_auto + def is_with_coordinator(self, flag): + if isinstance(flag, bool): + self.strategy.with_coordinator = flag + else: + print("WARNING: with_coordinator should have value of bool type") + @pipeline.setter @is_strict_auto def pipeline(self, flag): diff --git 
diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index f4f2076cd12..1a9b3f565b7 100755
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -511,6 +511,9 @@ class Fleet(object):
         """
         return self._role_maker._is_worker()

+    def is_coordinator(self):
+        return self._role_maker._is_coordinator()
+
     def worker_endpoints(self, to_string=False):
         """
         Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].
@@ -642,6 +645,25 @@ class Fleet(object):
         """
         self._runtime_handle._init_worker(scopes)

+    @is_non_distributed_check
+    @inited_runtime_handler
+    def init_coordinator(self, scopes=None):
+        """
+        initialize the coordinator node
+        """
+        self._runtime_handle._init_coordinator(scopes)
+
+    def make_fl_strategy(self):
+        self._runtime_handle._make_fl_strategy()
+
+    @is_non_distributed_check
+    @inited_runtime_handler
+    def get_fl_client(self):
+        """
+        get the worker (training node) pointer
+        """
+        return self._runtime_handle._worker
+
     @is_non_distributed_check
     @inited_runtime_handler
     def init_server(self, *args, **kwargs):
diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py
old mode 100644
new mode 100755
index 36155bbf1a2..67350be6210
--- a/python/paddle/distributed/fleet/base/role_maker.py
+++ b/python/paddle/distributed/fleet/base/role_maker.py
@@ -30,6 +30,7 @@ class Role:
     SERVER = 2
     HETER_WORKER = 3
     ALL = 4
+    COORDINATOR = 5

 class Gloo(object):
@@ -376,6 +377,7 @@ class RoleMakerBase(object):
     def __init__(self):
         self._worker_endpoints = []
         self._server_endpoints = []
+        self._cur_endpoint = ""
         self._role_is_generated = False
         self._role = None
         self._current_id = -1
@@ -544,6 +546,8 @@ class PaddleCloudRoleMaker(RoleMakerBase):

         self._server_endpoints = []
         self._worker_endpoints = []
+        self._coordinator_endpoints = None
+        self._with_coordinator = False

         self._gloo = Gloo()  # gloo instance
@@ -612,6 +616,11 @@ class PaddleCloudRoleMaker(RoleMakerBase):
             self._generate_role()
         return self._role == Role.SERVER

+    def _is_coordinator(self):
+        if not self._role_is_generated:
+            self._generate_role()
+        return self._role == Role.COORDINATOR
+
     def _is_first_worker(self):
         """
         whether current process is worker of rank 0
@@ -734,6 +743,11 @@ class PaddleCloudRoleMaker(RoleMakerBase):
             self._generate_role()
         return self._server_endpoints

+    def _get_coordinator_endpoints(self):
+        if not self._role_is_generated:
+            self._generate_role()
+        return self._coordinator_endpoints
+
     def _get_previous_trainers(self):
         """
         invoked by heter worker
@@ -781,7 +795,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
             self._generate_role()
         return self._role == Role.HETER_WORKER

-    def _ps_env(self):
+    def _ps_env(self):  # each role will execute it
         # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
         # format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
        self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)
@@ -806,6 +820,14 @@ class PaddleCloudRoleMaker(RoleMakerBase):
         else:
             self._worker_endpoints = []

+        self._coordinator_endpoints = os.getenv("PADDLE_COORDINATOR_ENDPOINTS",
+                                                "")
+        if self._coordinator_endpoints == "":
+            print("fl-ps > coordinator address is null!")
+        else:
+            self._with_coordinator = True
+            self._coordinator_endpoints = self._coordinator_endpoints.split(",")
+
         trainers_num = os.getenv("PADDLE_TRAINERS_NUM", None)
         if trainers_num == None:
             raise ValueError(
@@ -818,9 +840,11 @@ class PaddleCloudRoleMaker(RoleMakerBase):
             raise ValueError(
                 "Can not find TRAINING_ROLE, please check your environment.")

-        if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]:
+        if training_role not in [
+                "TRAINER", "PSERVER", "HETER_TRAINER", "COORDINATOR"
+        ]:
             raise ValueError(
-                "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment."
+                "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER or COORDINATOR, but got {}, please check your environment."
                 .format(training_role))

         # For Heter Parameter Server env setting
@@ -862,29 +886,10 @@ class PaddleCloudRoleMaker(RoleMakerBase):
                     "Can not Find PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
                 )
-            #self._is_heter_parameter_server_mode = True
-            #heter_trainers_num = len(all_heter_trainer_eplist.split(","))
-            #self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
         else:
             self._is_heter_parameter_server_mode = False
             self._heter_trainers_num = 0

-            #if previous_heter_trainer_eplist == "":
-            #    self._is_heter_parameter_server_mode = False
-            #    heter_trainers_num = 0
-            #else:  ## for the last heter worker
-            #    try:
-            #        previous_heter_trainer_eplist = os.environ[
-            #            "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"].split(",")
-            #        self._previous_heter_trainer_endpoints = previous_heter_trainer_eplist
-            #    except:
-            #        raise ValueError(
-            #            "Can not Find PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
-            #        )
-            #    self._is_heter_parameter_server_mode = True
-            #    heter_trainers_num = len(all_heter_trainer_eplist.split(","))
-            #    self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",")
-
         if training_role == "TRAINER":
             role = Role.WORKER
             current_id = os.getenv("PADDLE_TRAINER_ID", None)
@@ -922,6 +927,10 @@ class PaddleCloudRoleMaker(RoleMakerBase):
                     "Can not find POD_IP, please check your environment.")
             curr_endpoint = ":".join([cur_ip, cur_port])
             self._cur_endpoint = curr_endpoint
+        elif training_role == "COORDINATOR":
+            print(">>> curr node is coordinator!")
+            role = Role.COORDINATOR
+            current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
         elif training_role == "PSERVER":
             role = Role.SERVER
             cur_port = os.getenv("PADDLE_PORT", None)
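
In other words, a coordinator is selected purely from environment variables, like every other PaddleCloud role. A hedged illustration of the minimum environment one coordinator process needs (addresses are placeholders; the launcher changes below export all of this automatically):

    import os

    os.environ["TRAINING_ROLE"] = "COORDINATOR"
    os.environ["PADDLE_COORDINATOR_ENDPOINTS"] = "127.0.0.1:9000"
    os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:6001"
    os.environ["PADDLE_TRAINERS_NUM"] = "2"
    os.environ["PADDLE_TRAINER_ID"] = "0"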
diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py
old mode 100644
new mode 100755
index 78c0d9f1a74..158938b76d0
--- a/python/paddle/distributed/fleet/launch.py
+++ b/python/paddle/distributed/fleet/launch.py
@@ -211,6 +211,10 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
         type=str,
         default="",
         help="User defined workers ip:port")
+    ps_group.add_argument("--coordinators",
+                          type=str,
+                          default="",
+                          help="User defined coordinators ip:port")
     ps_group.add_argument(
         "--heter_workers",
         type=str,
@@ -223,6 +227,9 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
         help="User defined heter devices in each stage cpu;gpu;cpu")

     ps_group.add_argument("--worker_num", type=int, help="number of workers")
+    ps_group.add_argument("--coordinator_num",
+                          type=int,
+                          help="number of coordinators")
     ps_group.add_argument("--server_num", type=int, help="number of servers")
     ps_group.add_argument("--heter_worker_num",
                           type=str,
@@ -474,6 +481,8 @@ def which_distributed_mode(args):

     ps_heter_args = ["--heter_worker_num", "--heter_workers", "--heter_devices"]

+    coordinator_args = ["--coordinator_num", "--coordinators"]
+
     has_ps_args = [
         ps_arg for ps_arg in ps_args if ps_arg in " ".join(sys.argv[1:-1])
     ]
@@ -503,6 +512,7 @@ def which_distributed_mode(args):
             "Run parameter-sever mode. pserver arguments:{}, accelerators count:{}"
             .format(has_ps_args, accelerators))
         has_ps_heter_args = list(set(has_ps_args) & set(ps_heter_args))
+        has_coordinator_args = list(set(has_ps_args) & set(coordinator_args))
         if len(has_ps_heter_args) > 0:
             return DistributeMode.PS_HETER
         else:
diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py
old mode 100644
new mode 100755
index e10709416f8..f2f9b4d87db
--- a/python/paddle/distributed/fleet/launch_utils.py
+++ b/python/paddle/distributed/fleet/launch_utils.py
@@ -189,17 +189,19 @@ class Pod(object):
         self.trainers = []
         self.servers = []
         self.workers = []
+        self.coordinators = []
         self.heter_workers = []
         self.accelerators = []
         self.device_mode = None

     def __str__(self):
         return "rank:{} id:{} addr:{} port:{} visible_accelerator:{} trainers:{} servers:{} \
-            workers:{} heter_workers:{}".format(
+            workers:{} heter_workers:{} coordinators:{}".format(
             self.rank, self.id, self.addr, self.port, self.accelerators,
             [str(t) for t in self.trainers], [str(s) for s in self.servers],
             [str(w)
-             for w in self.workers], [str(h) for h in self.heter_workers])
+             for w in self.workers], [str(h) for h in self.heter_workers],
+            [str(c) for c in self.coordinators])

     def __eq__(self, pod):
         if self.rank != pod.rank or \
@@ -1172,9 +1174,11 @@ class ParameterServerLauncher(object):
     def __init__(self, args, distribute_mode):
         self.args = args
         self.distribute_mode = distribute_mode
+        self.with_coordinator = False
         self.server_num = 0
         self.worker_num = 0
         self.heter_worker_num = 0
+        self.coordinator_num = 0

         self.server_endpoints = ""
         self.server_endpoints_ips = []
@@ -1188,6 +1192,10 @@ class ParameterServerLauncher(object):
         self.heter_worker_endpoints_ips = []
         self.heter_worker_endpoints_port = []

+        self.coordinator_endpoints = ""
+        self.coordinator_endpoints_ips = []
+        self.coordinator_endpoints_port = []
+
         self.is_local = True
         self.current_node_ip = ""

@@ -1257,6 +1265,23 @@ class ParameterServerLauncher(object):
         else:
             self.worker_endpoints = args.workers

+        # get coordinator envs
+        if args.coordinator_num:
+            self.with_coordinator = True
+            self.coordinator_num = args.coordinator_num
+            if args.coordinators:
+                assert len(
+                    args.coordinators.split(",")
+                ) == self.coordinator_num, "The coordinator_num and coordinators don't match. Expected the number of coordinator endpoints to equal coordinator_num, but received coordinator endpoint num: {} and coordinator_num: {}".format(
+                    len(args.coordinators.split(",")), self.coordinator_num)
+
+                self.coordinator_endpoints = args.coordinators
+            else:
+                ports = get_ports(self.coordinator_num, 1)
+                self.coordinator_endpoints = ",".join(
+                    ["127.0.0.1:" + str(x) for x in ports])
+                print(">>> use default coordinator addr (only one process)")
+
         # get heter worker envs
         if self.distribute_mode == DistributeMode.PS_HETER:
             assert args.heter_devices != "", "The setting of Parameter-Server heter mode must has heter_devices."
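
With the new arguments wired in, a single-machine FL-PS job with one coordinator can be launched roughly as follows (hedged: the script name and all counts/addresses are placeholders, not taken from the PR's tests):

    python -m paddle.distributed.fleet.launch \
        --server_num=1 --worker_num=2 \
        --coordinator_num=1 --coordinators="127.0.0.1:9000" \
        fl_train.py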
@@ -1398,6 +1423,17 @@ class ParameterServerLauncher(object):
         self.worker_endpoints_ips = [
             x.strip().split(":")[0] for x in self.worker_endpoints.split(",")
         ]
+
+        if self.with_coordinator:
+            self.coordinator_endpoints_ips = [
+                x.strip().split(":")[0]
+                for x in self.coordinator_endpoints.split(",")
+            ]
+            self.coordinator_endpoints_port = [
+                x.strip().split(":")[1]
+                for x in self.coordinator_endpoints.split(",")
+            ]
+
         self.server_endpoints_port = [
             x.strip().split(":")[1] for x in self.server_endpoints.split(",")
         ]
@@ -1451,6 +1487,7 @@ class ParameterServerLauncher(object):
         server_rank = 0
         worker_rank = 0
         heter_worker_rank = 0
+        coordinator_rank = 0
         for node_rank, ip in enumerate(self.node_ips):
             pod = Pod()
             pod.rank = node_rank
@@ -1472,6 +1509,16 @@ class ParameterServerLauncher(object):
                     worker.stage = 1
                     worker_rank += 1
                     pod.workers.append(worker)
+            for m in range(len(self.coordinator_endpoints_ips)):
+                if ip == self.coordinator_endpoints_ips[m]:
+                    coordinator = Trainer()
+                    coordinator.endpoint = "%s:%s" % (
+                        ip, self.coordinator_endpoints_port[m])
+                    coordinator.rank = coordinator_rank
+                    coordinator.stage = 1
+                    coordinator_rank += 1
+                    pod.coordinators.append(coordinator)
+
             for k in range(len(self.heter_worker_endpoints_ips)):
                 if ip == self.heter_worker_endpoints_ips[k]:
                     heter_worker = Trainer()
@@ -1488,18 +1535,36 @@ class ParameterServerLauncher(object):
         self.gloo_rendezvous_dir = tempfile.mkdtemp()

         # 3. subproces start
-        self.procs = {"worker": [], "server": [], "heter_worker": []}
-        self.cmds = {"worker": [], "server": [], "heter_worker": []}
-        self.log_fns = {"worker": [], "server": [], "heter_worker": []}
+        self.procs = {
+            "worker": [],
+            "coordinator": [],
+            "server": [],
+            "heter_worker": []
+        }
+        self.cmds = {
+            "worker": [],
+            "coordinator": [],
+            "server": [],
+            "heter_worker": []
+        }
+        self.log_fns = {
+            "worker": [],
+            "coordinator": [],
+            "server": [],
+            "heter_worker": []
+        }

         self.start_pod_server(self.args, pod)
         self.start_pod_worker(self.args, pod)
+        if self.with_coordinator:
+            self.start_pod_coordinator(self.args, pod)
         if self.distribute_mode == DistributeMode.PS_HETER:
             self.start_pod_heter_worker(self.args, pod)

         logger.info(
-            "Please check servers, workers and heter_worker logs in {}/workerlog.*, {}/serverlog.* and {}/heterlog.*"
-            .format(self.args.log_dir, self.args.log_dir, self.args.log_dir))
+            "Please check servers, workers, coordinator and heter_worker logs in {}/workerlog.*, {}/serverlog.*, {}/coordinatorlog.* and {}/heterlog.*"
+            .format(self.args.log_dir, self.args.log_dir, self.args.log_dir,
+                    self.args.log_dir))
        # 4. wait for finish training
        if len(self.procs["worker"]) > 0:
@@ -1524,6 +1589,12 @@ class ParameterServerLauncher(object):
                     self.procs["server"][i].proc.terminate()
                 logger.info("all parameter server are killed")

+            if len(self.procs["coordinator"]) > 0:
+                for i, proc in enumerate(self.procs["coordinator"]):
+                    self.log_fns["coordinator"][i].close()
+                    self.procs["coordinator"][i].proc.terminate()
+                logger.info("all coordinators are killed")
+
         else:
             # if node has not worker procs
             # blocking training process
@@ -1548,6 +1619,7 @@ class ParameterServerLauncher(object):
             proc_env = {
                 "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
                 "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
+                "PADDLE_COORDINATOR_ENDPOINTS": self.coordinator_endpoints,
                 "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST":
                 self.heter_worker_endpoints,
                 "PADDLE_PORT": cur_server.endpoint.split(":")[1],
@@ -1563,6 +1635,7 @@ class ParameterServerLauncher(object):
             proc_env = {
                 "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
                 "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
+                "PADDLE_COORDINATOR_ENDPOINTS": self.coordinator_endpoints,
                 "PADDLE_PORT": cur_server.endpoint.split(":")[1],
                 "TRAINING_ROLE": "PSERVER",
                 "PADDLE_TRAINERS_NUM": str(self.worker_num),
@@ -1633,6 +1706,8 @@ class ParameterServerLauncher(object):
                 self.worker_endpoints,
                 "PADDLE_TRAINERS_NUM":
                 str(self.worker_num),
+                "PADDLE_COORDINATOR_ENDPOINTS":
+                self.coordinator_endpoints,
                 "PADDLE_STAGE_TRAINERS_NUM":
                 str(self.stage_trainer_num),
                 "STAGE_ID":
@@ -1678,6 +1753,7 @@ class ParameterServerLauncher(object):
                 "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
                 "PADDLE_TRAINERS_NUM": str(self.worker_num),
                 "TRAINING_ROLE": "TRAINER",
+                "PADDLE_COORDINATOR_ENDPOINTS": self.coordinator_endpoints,
                 "POD_IP": cur_worker.endpoint.split(":")[0],
                 "PADDLE_PORT": cur_worker.endpoint.split(":")[1],
                 "PADDLE_TRAINER_ID": str(cur_worker.rank),
@@ -1725,6 +1801,69 @@ class ParameterServerLauncher(object):

             self.procs["worker"].append(tp)

+    def start_pod_coordinator(self, args, pod):
+        print(">>> entering start_pod_coordinator")
+        default_env = os.environ.copy()
+        current_env = copy.copy(default_env)
+        current_env.pop("http_proxy", None)
+        current_env.pop("https_proxy", None)
+
+        for idx, cur_coordinator in enumerate(pod.coordinators):
+            device_id = "0"
+            proc_env = {
+                "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
+                "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
+                "PADDLE_TRAINERS_NUM": str(self.worker_num),
+                "PADDLE_COORDINATOR_ENDPOINTS": self.coordinator_endpoints,
+                "PADDLE_COORDINATOR_NUM": str(self.coordinator_num),
+                "TRAINING_ROLE": "COORDINATOR",
+                "POD_IP": cur_coordinator.endpoint.split(":")[0],
+                "PADDLE_PORT": cur_coordinator.endpoint.split(":")[1],
+                "PADDLE_TRAINER_ID": str(cur_coordinator.rank),
+                "PADDLE_WITH_GLOO": str(os.getenv("PADDLE_WITH_GLOO", "0")),
+                "PADDLE_GLOO_RENDEZVOUS": "3",
+                "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
+                "FLAGS_selected_gpus": "0",
+                "FLAGS_selected_xpus": "0",
+                "CUDA_VISIBLE_DEVICES": device_id,
+                "XPU_VISIBLE_DEVICES": device_id,
+                "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
+            }
+
+            current_env.update(proc_env)
+            cmd = [sys.executable, "-u", args.training_script
+                   ] + args.training_script_args
+            self.cmds["coordinator"].append(cmd)
+            if idx == 0:
+                logger.info(
+                    "Local coordinator starts {} processes. First process distributed "
+                    "environment info (Only For Debug): {}".format(
+                        len(pod.coordinators),
+                        pretty_print_envs(proc_env,
+                                          ("Distributed Envs", "Value"))))
+
+            fn = None  # keep fn defined when no log_dir is given (avoids a NameError below)
+            if args.log_dir is not None:
+                os.system("mkdir -p {}".format(args.log_dir))
+                fn = open("%s/coordinator.%d" % (args.log_dir, idx), "w")
+                self.log_fns["coordinator"].append(fn)
+                proc = subprocess.Popen(cmd,
+                                        env=current_env,
+                                        stdout=fn,
+                                        stderr=fn)
+            else:
+                proc = subprocess.Popen(cmd, env=current_env)
+
+            tp = TrainerProc()
+            tp.proc = proc
+            tp.rank = cur_coordinator.rank
+            tp.local_rank = idx
+            tp.log_fn = fn
+            tp.log_offset = fn.tell() if fn else None
+            tp.cmd = cmd
+
+            self.procs["coordinator"].append(tp)
+
     def start_pod_heter_worker(self, args, pod):
         default_env = os.environ.copy()
         current_env = copy.copy(default_env)
diff --git a/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py
index cd6bc03a5d5..fb1149dcba3 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py
@@ -78,6 +78,8 @@ class ParameterServerOptimizer(MetaOptimizerBase):
         attrs['lr_decay_steps'] = self.user_defined_strategy.a_sync_configs[
             "lr_decay_steps"]
         attrs['is_fl_ps_mode'] = self.user_defined_strategy.is_fl_ps_mode
+        attrs[
+            'with_coordinator'] = self.user_defined_strategy.is_with_coordinator
         attrs['k_steps'] = self.user_defined_strategy.a_sync_configs["k_steps"]
         attrs['launch_barrier'] = self.user_defined_strategy.a_sync_configs[
             "launch_barrier"]
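
The new coordinator.py module below provides the Python runtime for both ends of the protocol. For orientation, a hedged sketch of how a driver script is expected to dispatch on the new role, using only the fleet APIs this patch exports (model and dataset setup elided):

    import paddle.distributed.fleet as fleet

    fleet.init()
    if fleet.is_coordinator():
        fleet.init_coordinator()
        fleet.make_fl_strategy()  # enters the coordinator's endless strategy loop
    elif fleet.is_worker():
        # build the model, then drive training through the FLClient defined below
        ...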
diff --git a/python/paddle/distributed/ps/coordinator.py b/python/paddle/distributed/ps/coordinator.py
new file mode 100755
index 00000000000..0d7fa87f245
--- /dev/null
+++ b/python/paddle/distributed/ps/coordinator.py
@@ -0,0 +1,351 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.fluid.communicator import FLCommunicator
+from paddle.distributed.fleet.proto import the_one_ps_pb2
+from google.protobuf import text_format
+from paddle.distributed.ps.utils.public import is_distributed_env
+from paddle.distributed import fleet
+import time
+import abc
+import os
+import logging
+
+logging.basicConfig(
+    format='%(asctime)s %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s',
+    level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+
+class ClientInfoAttr:
+    CLIENT_ID = 0
+    DEVICE_TYPE = 1
+    COMPUTE_CAPACITY = 2
+    BANDWIDTH = 3
+
+
+class FLStrategy:
+    JOIN = 0
+    WAIT = 1
+    FINISH = 2
+
+
+class ClientSelectorBase(abc.ABC):
+
+    def __init__(self, fl_clients_info_mp):
+        self.fl_clients_info_mp = fl_clients_info_mp
+        self.clients_info = {}
+        self.fl_strategy = {}
+
+    def parse_from_string(self):
+        if not self.fl_clients_info_mp:
+            logger.warning("fl-ps > fl_clients_info_mp is null!")
+
+        for client_id, info in self.fl_clients_info_mp.items():
+            self.fl_client_info_desc = the_one_ps_pb2.FLClientInfo()
+            text_format.Parse(bytes(info, encoding="utf8"),
+                              self.fl_client_info_desc)
+            self.clients_info[client_id] = {}
+            self.clients_info[client_id][
+                ClientInfoAttr.
+                DEVICE_TYPE] = self.fl_client_info_desc.device_type
+            self.clients_info[client_id][
+                ClientInfoAttr.
+                COMPUTE_CAPACITY] = self.fl_client_info_desc.compute_capacity
+            self.clients_info[client_id][
+                ClientInfoAttr.BANDWIDTH] = self.fl_client_info_desc.bandwidth
+
+    @abc.abstractmethod
+    def select(self):
+        pass
+
+
+class ClientSelector(ClientSelectorBase):
+
+    def __init__(self, fl_clients_info_mp):
+        super().__init__(fl_clients_info_mp)
+        self.__fl_strategy = {}
+
+    def select(self):
+        self.parse_from_string()
+        for client_id in self.clients_info:
+            logger.info("fl-ps > client {} info : {}".format(
+                client_id, self.clients_info[client_id]))
+            # ......... to implement ...... #
+            fl_strategy_desc = the_one_ps_pb2.FLStrategy()
+            fl_strategy_desc.iteration_num = 99
+            fl_strategy_desc.client_id = 0
+            fl_strategy_desc.next_state = "JOIN"
+            str_msg = text_format.MessageToString(fl_strategy_desc)
+            self.__fl_strategy[client_id] = str_msg
+        return self.__fl_strategy
+
+
+class FLClientBase(abc.ABC):
+
+    def __init__(self):
+        pass
+
+    def set_basic_config(self, role_maker, config, metrics):
+        self.role_maker = role_maker
+        self.config = config
+        self.total_train_epoch = int(self.config.get("runner.epochs"))
+        self.train_statical_info = dict()
+        self.train_statical_info['speed'] = []
+        self.epoch_idx = 0
+        self.worker_index = fleet.worker_index()
+        self.main_program = paddle.static.default_main_program()
+        self.startup_program = paddle.static.default_startup_program()
+        self._client_ptr = fleet.get_fl_client()
+        self._coordinators = self.role_maker._get_coordinator_endpoints()
+        logger.info("fl-ps > coordinator endpoints: {}".format(
+            self._coordinators))
+        self.strategy_handlers = dict()
+        self.exe = None
+        self.use_cuda = int(self.config.get("runner.use_gpu"))
+        self.place = paddle.CUDAPlace(0) if self.use_cuda else paddle.CPUPlace()
+        self.print_step = int(self.config.get("runner.print_interval"))
+        self.debug = self.config.get("runner.dataset_debug", False)
+        self.reader_type = self.config.get("runner.reader_type", "QueueDataset")
+        self.set_executor()
+        self.make_save_model_path()
+        self.set_metrics(metrics)
+
+    def set_train_dataset_info(self, train_dataset, train_file_list):
+        self.train_dataset = train_dataset
+        self.train_file_list = train_file_list
+        logger.info("fl-ps > {}, data_feed_desc:\n {}".format(
+            type(self.train_dataset), self.train_dataset._desc()))
+
+    def set_test_dataset_info(self, test_dataset, test_file_list):
+        self.test_dataset = test_dataset
+        self.test_file_list = test_file_list
+
+    def set_train_example_num(self, num):
+        self.train_example_nums = num
+
+    def load_dataset(self):
+        if self.reader_type == "InmemoryDataset":
+            self.train_dataset.load_into_memory()
+
+    def release_dataset(self):
+        if self.reader_type == "InmemoryDataset":
+            self.train_dataset.release_memory()
+
+    def set_executor(self):
+        self.exe = paddle.static.Executor(self.place)
+
+    def make_save_model_path(self):
+        self.save_model_path = self.config.get("runner.model_save_path")
+        if self.save_model_path and (not os.path.exists(self.save_model_path)):
+            os.makedirs(self.save_model_path)
+
+    def set_dump_fields(self):
+        # DumpField
+        # TrainerDesc -> SetDumpParamVector -> DumpParam -> DumpWork
+        if self.config.get("runner.need_dump"):
+            self.debug = True
+            dump_fields_path = "{}/epoch_{}".format(
+                self.config.get("runner.dump_fields_path"), self.epoch_idx)
+            dump_fields = self.config.get("runner.dump_fields", [])
+            dump_param = self.config.get("runner.dump_param", [])
+            persist_vars_list = self.main_program.all_parameters()
+            persist_vars_name = [
+                str(param).split(":")[0].strip().split()[-1]
+                for param in persist_vars_list
+            ]
+            logger.info(
+                "fl-ps > persist_vars_list: {}".format(persist_vars_name))
+
+            if dump_fields_path is not None:
+                self.main_program._fleet_opt[
+                    'dump_fields_path'] = dump_fields_path
+            if dump_fields is not None:
+                self.main_program._fleet_opt["dump_fields"] = dump_fields
+            if dump_param is not None:
+                self.main_program._fleet_opt["dump_param"] = dump_param
+
+    def set_metrics(self, metrics):
+        self.metrics = metrics
+        self.fetch_vars = [var for _, var in self.metrics.items()]
+
+
+class FLClient(FLClientBase):
+
+    def __init__(self):
+        super(FLClient, self).__init__()
+
+    def __build_fl_client_info_desc(self, state_info):
+        # ......... to implement ...... #
+        state_info = {
+            ClientInfoAttr.DEVICE_TYPE: "Android",
+            ClientInfoAttr.COMPUTE_CAPACITY: 10,
+            ClientInfoAttr.BANDWIDTH: 100
+        }
+        client_info = the_one_ps_pb2.FLClientInfo()
+        client_info.device_type = state_info[ClientInfoAttr.DEVICE_TYPE]
+        client_info.compute_capacity = state_info[
+            ClientInfoAttr.COMPUTE_CAPACITY]
+        client_info.bandwidth = state_info[ClientInfoAttr.BANDWIDTH]
+        str_msg = text_format.MessageToString(client_info)
+        return str_msg
+
+    def run(self):
+        self.register_default_handlers()
+        self.print_program()
+        self.strategy_handlers['initialize_model_params']()
+        self.strategy_handlers['init_worker']()
+        self.load_dataset()
+        self.train_loop()
+        self.release_dataset()
+        self.strategy_handlers['finish']()
+
+    def train_loop(self):
+        while self.epoch_idx < self.total_train_epoch:
+            logger.info("fl-ps > curr epoch idx: {}".format(self.epoch_idx))
+            self.strategy_handlers['train']()
+            self.strategy_handlers['save_model']()
+            self.barrier()
+            state_info = {
+                "client id": self.worker_index,
+                "auc": 0.9,
+                "epoch": self.epoch_idx
+            }
+            self.push_fl_client_info_sync(state_info)
+            strategy_dict = self.pull_fl_strategy()
+            logger.info("fl-ps > received fl strategy: {}".format(strategy_dict))
+            # ......... to implement ...... #
+            if strategy_dict['next_state'] == "JOIN":
+                self.strategy_handlers['infer']()
+            elif strategy_dict['next_state'] == "FINISH":
+                self.strategy_handlers['finish']()
+
+    def push_fl_client_info_sync(self, state_info):
+        str_msg = self.__build_fl_client_info_desc(state_info)
+        self._client_ptr.push_fl_client_info_sync(str_msg)
+        return
+
+    def pull_fl_strategy(self):
+        strategy_dict = {}
+        fl_strategy_str = self._client_ptr.pull_fl_strategy(
+        )  # block: wait for the coordinator's strategy to arrive
+        logger.info("fl-ps > fl client received fl_strategy(str):\n{}".format(
+            fl_strategy_str))
+        fl_strategy_desc = the_one_ps_pb2.FLStrategy()
+        text_format.Parse(bytes(fl_strategy_str, encoding="utf8"),
+                          fl_strategy_desc)
+        strategy_dict["next_state"] = fl_strategy_desc.next_state
+        return strategy_dict
+
+    def barrier(self):
+        fleet.barrier_worker()
+
+    def register_handlers(self, strategy_type, callback_func):
+        self.strategy_handlers[strategy_type] = callback_func
+
+    def register_default_handlers(self):
+        self.register_handlers('train', self.callback_train)
+        self.register_handlers('infer', self.callback_infer)
+        self.register_handlers('finish', self.callback_finish)
+        self.register_handlers('initialize_model_params',
+                               self.callback_initialize_model_params)
+        self.register_handlers('init_worker', self.callback_init_worker)
+        self.register_handlers('save_model', self.callback_save_model)
+
+    def callback_init_worker(self):
+        fleet.init_worker()
+
+    def callback_initialize_model_params(self):
+        if self.exe is None or self.main_program is None:
+            raise AssertionError("exe or main_program not set")
+        self.exe.run(self.startup_program)
+
+    def callback_train(self):
+        epoch_start_time = time.time()
+        self.set_dump_fields()
+        fetch_info = [
+            "Epoch {} Var {}".format(self.epoch_idx, var_name)
+            for var_name in self.metrics
+        ]
+        self.exe.train_from_dataset(program=self.main_program,
+                                    dataset=self.train_dataset,
+                                    fetch_list=self.fetch_vars,
+                                    fetch_info=fetch_info,
+                                    print_period=self.print_step,
+                                    debug=self.debug)
+        self.epoch_idx += 1
+        epoch_time = time.time() - epoch_start_time
+        epoch_speed = self.train_example_nums / epoch_time
+        self.train_statical_info["speed"].append(epoch_speed)
+        logger.info("fl-ps > callback_train finished")
+
+    def callback_infer(self):
+        fetch_info = [
+            "Epoch {} Var {}".format(self.epoch_idx, var_name)
+            for var_name in self.metrics
+        ]
+        self.exe.infer_from_dataset(program=self.main_program,
+                                    dataset=self.test_dataset,
+                                    fetch_list=self.fetch_vars,
+                                    fetch_info=fetch_info,
+                                    print_period=self.print_step,
+                                    debug=self.debug)
+
+    def callback_save_model(self):
+        model_dir = "{}/{}".format(self.save_model_path, self.epoch_idx)
+        if fleet.is_first_worker() and self.save_model_path:
+            if is_distributed_env():
+                fleet.save_persistables(self.exe, model_dir)  # save all params
+            else:
+                raise ValueError("it is not a distributed env")
+
+    def callback_finish(self):
+        fleet.stop_worker()
+
+    def print_program(self):
+        with open("./{}_worker_main_program.prototxt".format(self.worker_index),
+                  'w+') as f:
+            f.write(str(self.main_program))
+        with open(
+                "./{}_worker_startup_program.prototxt".format(
+                    self.worker_index), 'w+') as f:
+            f.write(str(self.startup_program))
+
+    def print_train_statical_info(self):
+        with open("./train_statical_info.txt", 'w+') as f:
+            f.write(str(self.train_statical_info))
+
+
+class Coordinator(object):
+
+    def __init__(self, ps_hosts):
+        self._communicator = FLCommunicator(ps_hosts)
+        self._client_selector = None
+
+    def start_coordinator(self, self_endpoint, trainer_endpoints):
+        self._communicator.start_coordinator(self_endpoint, trainer_endpoints)
+
+    def make_fl_strategy(self):
+        logger.info("fl-ps > running make_fl_strategy(loop) in coordinator\n")
+        while True:
+            # 1. get all fl clients' reported info
+            str_map = self._communicator.query_fl_clients_info(
+            )  # block: wait until all fl clients' info has been reported
+            # 2. generate fl strategy
+            self._client_selector = ClientSelector(str_map)
+            fl_strategy = self._client_selector.select()
+            # 3. save fl strategy from python to c++
+            self._communicator.save_fl_strategy(fl_strategy)
+            time.sleep(5)
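
ClientSelector.select() above is only a stub (every client is told to JOIN); real policies are meant to subclass ClientSelectorBase. A hedged sketch of such a policy — BandwidthSelector and its threshold are hypothetical, not part of this patch:

    from google.protobuf import text_format
    from paddle.distributed.fleet.proto import the_one_ps_pb2
    from paddle.distributed.ps.coordinator import (ClientInfoAttr,
                                                   ClientSelectorBase)


    class BandwidthSelector(ClientSelectorBase):

        def select(self):
            self.parse_from_string()
            fl_strategy = {}
            for client_id, info in self.clients_info.items():
                fl_strategy_desc = the_one_ps_pb2.FLStrategy()
                # ask well-connected clients to JOIN, park the rest in WAIT
                fl_strategy_desc.next_state = (
                    "JOIN" if info[ClientInfoAttr.BANDWIDTH] >= 50 else "WAIT")
                fl_strategy[client_id] = text_format.MessageToString(
                    fl_strategy_desc)
            return fl_strategy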
diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py
index 7d240983a1c..bee1ee169ef 100755
--- a/python/paddle/distributed/ps/the_one_ps.py
+++ b/python/paddle/distributed/ps/the_one_ps.py
@@ -29,6 +29,7 @@ from paddle.distributed.fleet.base.private_helper_function import wait_server_re
 from paddle.distributed.fleet.proto import the_one_ps_pb2
 from paddle.fluid.communicator import Communicator, HeterClient
 from google.protobuf import text_format
+from paddle.distributed.ps.coordinator import Coordinator

 __all__ = [
     'Table', 'SparseTable', 'GeoSparseTable', 'BarrierTable', 'TensorTable',
@@ -774,6 +775,7 @@ class PsDescBuilder(object):
         self.fs_client = self._get_fs_client()

         self.ps_desc = the_one_ps_pb2.PSParameter()
+        self.fl_desc = the_one_ps_pb2.FLParameter()

     def _get_tensor_tables(self):
         program_idx = 0
@@ -812,6 +814,9 @@ class PsDescBuilder(object):
     def _get_fs_client(self):
         return fsClient(self.context["user_defined_strategy"].fs_client_param)

+    def build_fl_client_desc(self, client_info):
+        pass
+
     def build_worker_desc(self):
         for table in self.tables:
             table_proto = self.ps_desc.worker_param.downpour_worker_param.downpour_table_param.add(
@@ -850,6 +855,7 @@ class TheOnePSRuntime(RuntimeBase):
         self._communicator = None
         self._server = None
         self._worker = fluid.core.DistFleetWrapper()
+        self._coordinator = None
         self._server_sub_program = []
         self._heter_client = None
         self._send_ctx = None
@@ -857,6 +863,8 @@ class TheOnePSRuntime(RuntimeBase):
     def _set_basic_info(self, context):
         self.context = context
         self.role_maker = context["role_maker"]
+        self.role_id = get_role_id(self.role_maker)
+        self.debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))

         self.origin_main_program = context["origin_main_program"]
         self.origin_main_programs = context.get("origin_main_programs",
@@ -878,6 +886,8 @@ class TheOnePSRuntime(RuntimeBase):
         self.context['tensor_table'] = {}
         build_var_distributed(self.context)

+        self.trainer_endpoints = get_trainer_endpoints(self.role_maker)
+
         self.endpoints = get_ps_endpoints(self.role_maker)
         self.string_hosts = []
         for idx, ep in enumerate(self.endpoints):
@@ -885,6 +895,16 @@ class TheOnePSRuntime(RuntimeBase):
             pshost = fluid.core.PSHost(host, int(port), idx)
             self.string_hosts.append(pshost.serialize_to_string())

+        self.with_coordinator = self.role_maker._with_coordinator
+        self.coordinator_hosts = []
+        if self.with_coordinator:
+            print("fl-ps > all ps addrs: {}".format(self.string_hosts))
+            coordinator_endpoints = self.role_maker._get_coordinator_endpoints()
+            for idx, ep in enumerate(coordinator_endpoints):
+                ip, port = ep.split(":")
+                pshost = fluid.core.PSHost(ip, int(port), idx)
+                self.coordinator_hosts.append(pshost.serialize_to_string())
+
         self.ps_desc_builder = PsDescBuilder(self.context)

     def _init_all_params(self, scopes, send_ctx, recv_map):
@@ -933,8 +953,6 @@ class TheOnePSRuntime(RuntimeBase):

     def _init_worker(self, scopes=None):
         worker_desc = self.ps_desc_builder.build_worker_desc()
-        #with open("test_fl_ps_worker_desc", "w") as f:
-        #    f.write(worker_desc)
         if self.context['use_ps_gpu']:
             main_program = self.context['loss'].block.program
             if not main_program._fleet_opt:
@@ -963,10 +981,8 @@ class TheOnePSRuntime(RuntimeBase):
         self._send_ctx = send_ctx
         trainer_config = self.context['trainer']

-        proto_txt = worker_desc
-        debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
-        if debug:
-            print("worker: \n{}".format(proto_txt))
\n{}".format(proto_txt)) + if self.debug: + print("worker_desc: \n{}".format(worker_desc)) print("communicator send_ctx:") for key in send_ctx: print("{}: {}".format(key, send_ctx[key])) @@ -986,15 +1002,21 @@ class TheOnePSRuntime(RuntimeBase): print("communicator config:", trainer_config.get_communicator_flags()) - role_id = get_role_id(self.role_maker) - self._worker.init_worker(proto_txt, self.string_hosts, role_id) + self._worker.init_worker(worker_desc, self.string_hosts, self.role_id) + self.trainer_endpoint = get_trainer_endpoint(self.role_maker) + print("fl-ps > trainer_endpoint: {}".format(self.trainer_endpoint)) + print("fl-ps > with_coordinator? {}".format(self.with_coordinator)) + print("fl-ps > coordinator addr: {}".format(self.coordinator_hosts)) + if self.with_coordinator: + self._worker.init_fl_worker(self.coordinator_hosts, self.role_id, + self.trainer_endpoint) if self.context[ 'ps_mode'] == DistributedMode.GEO or self.is_heter_ps_mode: self._communicator = Communicator( trainer_config.mode, kwargs, trainer_config.get_communicator_flags()) - self._communicator.init_with_ctx(send_ctx, dense_map, proto_txt, + self._communicator.init_with_ctx(send_ctx, dense_map, worker_desc, self.string_hosts, fluid.global_scope()) fleet.util.barrier() @@ -1002,7 +1024,8 @@ class TheOnePSRuntime(RuntimeBase): # info = self._communicator.get_client_info() info = self._worker.get_client_info() if isinstance(info, list) and len(info) > 0: - all_info = self.role_maker._all_gather(info[0]) + all_info = self.role_maker._all_gather( + info[0]) # 收集其他 client 的 service 地址 # for unittest if not isinstance(all_info, list): warnings.warn("gloo may not initialize correctly") @@ -1045,7 +1068,7 @@ class TheOnePSRuntime(RuntimeBase): self._communicator.init_params(init_params) else: if not self.context['use_ps_gpu']: - if role_id == 0: + if self.role_id == 0: print("entering self._init_all_params()") self._init_all_params(scopes, send_ctx, dense_map) @@ -1080,21 +1103,32 @@ class TheOnePSRuntime(RuntimeBase): next_trainers, previous_trainers, self.role_maker._role_id()) # --> HeterClient::GetInstance + def _init_coordinator(self, scopes=None): + if self._coordinator == None: + self._coordinator = Coordinator(self.string_hosts) + + print(">>> curr node ip: {}".format(self.coordinator_hosts[0])) + print(">>> all trainer endpoints: {}".format(self.trainer_endpoints)) + self._coordinator.start_coordinator(self.coordinator_hosts[0], + self.trainer_endpoints) + + def _make_fl_strategy(self): + if self._coordinator == None: + assert ("Coordinator py object is null!") + else: + self._coordinator.make_fl_strategy() + def _init_server(self, dirname=None, var_names=None, **kwargs): server_desc = self.ps_desc_builder.build_server_desc() - #with open("test_fl_ps_server_desc", "w") as f: - # f.write(server_desc) - role_id = get_role_id(self.role_maker) trainers = get_trainers(self.role_maker) if self.is_heter_ps_mode: trainers += len(self.role_maker._get_heter_worker_endpoints()) - debug = bool(int(os.getenv("PSERVER_DEBUG", "0"))) - if debug: - print("server: \n{}".format(server_desc)) + if self.debug: + print("server_desc: \n{}".format(server_desc)) self._server = fluid.core.DistFleetWrapper() - self._server.init_server(server_desc, self.string_hosts, role_id, + self._server.init_server(server_desc, self.string_hosts, self.role_id, trainers, self._server_sub_program) dist_varnames = get_sparse_tablenames(self.origin_main_programs, True) diff --git a/python/paddle/distributed/ps/utils/public.py 
diff --git a/python/paddle/distributed/ps/utils/public.py b/python/paddle/distributed/ps/utils/public.py
index a57b30a8c19..2fc3284f609 100755
--- a/python/paddle/distributed/ps/utils/public.py
+++ b/python/paddle/distributed/ps/utils/public.py
@@ -250,6 +250,10 @@ def get_trainer_endpoint(role_maker):
     return role_maker._get_trainer_endpoint()

+def get_trainer_endpoints(role_maker):
+    return role_maker._get_trainer_endpoints()
+
+
 def get_previous_stage_trainers(role_maker):
     try:
         return role_maker._get_previous_trainers()
@@ -1591,3 +1595,11 @@ def debug_program(file, program):
     os.makedirs(os.path.dirname(file), exist_ok=True)
     with open(file, 'w+') as f:
         f.write(str(program))
+
+
+def is_distributed_env():
+    node_role = os.getenv("TRAINING_ROLE")
+    if node_role is None:
+        return False
+    else:
+        return True
diff --git a/python/paddle/fluid/communicator.py b/python/paddle/fluid/communicator.py
old mode 100644
new mode 100755
index 291a6b58377..251247f795a
--- a/python/paddle/fluid/communicator.py
+++ b/python/paddle/fluid/communicator.py
@@ -34,7 +34,7 @@ It's a wrapper of a cpp class Communicator and should be used inside fleet API.
 from . import core
 from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode

-__all__ = ['Communicator', 'LargeScaleKV']
+__all__ = ['Communicator', 'FLCommunicator', 'LargeScaleKV']

 class Communicator(object):
@@ -208,6 +208,37 @@ class Communicator(object):
         self.communicator_.push_sparse_param(var_name, table_id, scope)

+class FLCommunicator(Communicator):
+
+    def __init__(self, ps_hosts, kwargs=None):
+        mode = None
+        super(FLCommunicator, self).__init__(mode, kwargs)
+        send_ctx = {}
+        dense_map = {}
+        prototxt = ""
+        self.mode = "WITH_COORDINATOR"
+        self.init_with_ctx(send_ctx, dense_map, prototxt, ps_hosts)
+
+    def start_coordinator(self, self_endpoint, trainer_endpoints):
+        if self.communicator_ is not None:
+            self.communicator_.start_coordinator(self_endpoint,
+                                                 trainer_endpoints)
+        return
+
+    def save_fl_strategy(self, mp):
+        if self.communicator_ is not None:
+            self.communicator_.save_fl_strategy(mp)
+        else:
+            raise ValueError("self.communicator_ is null")
+        return
+
+    def query_fl_clients_info(self):
+        info_mp = {}
+        if self.communicator_ is not None:
+            info_mp = self.communicator_.query_fl_clients_info()
+        return info_mp
+
+
 class LargeScaleKV(object):

     def __init__(self):
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index c7bfd19e5a9..8fd2283f0f9 100755
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1673,8 +1673,8 @@ class Executor(object):
         return res

     def _dump_debug_info(self, program=None, trainer=None):
-        with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
-            fout.write(str(trainer))
+        print("program_id: {}, trainer_desc:\n {}".format(
+            id(program), str(trainer)))
         if program._fleet_opt and "fleet_desc" in program._fleet_opt:
             with open("fleet_desc.prototxt", "w") as fout:
                 fout.write(str(program._fleet_opt["fleet_desc"]))
--
GitLab