diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc index b1aeaca353e65ba7206c65bcde9bc28ec4b06416..8db32c5cc4d08aa5949d19346a37f19523b3fdb3 100644 --- a/paddle/fluid/distributed/fleet.cc +++ b/paddle/fluid/distributed/fleet.cc @@ -55,14 +55,14 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path, void FleetWrapper::InitServer( const std::string& dist_desc, - const std::vector<std::string>& host_sign_list, int index, + const std::vector<std::string>& host_sign_list, int index, int trainers, const std::vector<framework::ProgramDesc>& server_sub_program) { if (!is_initialized_) { VLOG(3) << "Going to init server"; pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>( new paddle::distributed::PSCore()); pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(), - index, server_sub_program); + index, trainers, server_sub_program); is_initialized_ = true; } else { VLOG(3) << "Server can be initialized only once"; diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h index 5de278e067ecd307bd0e0a26a2ba7c0c4f72fb6e..03d915c500530ed3950e71ba687f2d3e92de89c1 100644 --- a/paddle/fluid/distributed/fleet.h +++ b/paddle/fluid/distributed/fleet.h @@ -156,7 +156,7 @@ class FleetWrapper { // const std::vector<uint64_t>& host_sign_list, int index); void InitServer( const std::string& dist_desc, - const std::vector<std::string>& host_sign_list, int index, + const std::vector<std::string>& host_sign_list, int index, int trainers, const std::vector<framework::ProgramDesc>& server_sub_program = {}); // init trainer void InitWorker(const std::string& dist_desc, diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc index a6837cd4525b771f359b58530fd02f818001b1ad..b9afff8c4390620f9033b057d5bc96466f99eeff 100644 --- a/paddle/fluid/distributed/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/service/brpc_ps_server.cc @@ -58,9 +58,11 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { std::string ip_port = ip + ":" + std::to_string(port); VLOG(3) 
<< "server of rank " << _rank << " starts at " << ip_port; - int num_threads = std::thread::hardware_concurrency(); brpc::ServerOptions options; - options.num_threads = num_threads; + + int num_threads = std::thread::hardware_concurrency(); + auto trainers = _environment->get_trainers(); + options.num_threads = trainers > num_threads ? trainers : num_threads; if (_server.Start(ip_port.c_str(), &options) != 0) { LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port; diff --git a/paddle/fluid/distributed/service/env.h b/paddle/fluid/distributed/service/env.h index 42f31717f7fba4203cdbd24d59cfa2d9973d5e8a..206ff2c5cc48eacb433c3ccd13cf2407b3164b58 100644 --- a/paddle/fluid/distributed/service/env.h +++ b/paddle/fluid/distributed/service/env.h @@ -161,6 +161,10 @@ class PSEnvironment { return {}; } + virtual void set_trainers(int trainers) { trainers_ = trainers; } + + virtual int get_trainers() { return trainers_; } + protected: //注册一个host virtual int32_t registe_ps_host(const std::string &ip, uint32_t port, @@ -178,17 +182,11 @@ class PSEnvironment { host_list.push_back(host); sign_set.insert(rank); } - // if (sign_set.count(host.serialize_to_uint64()) > 0) { - // LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port - // << ", rank:" << host.rank - // << " already register, ignore register"; - // } else { - // host_list.push_back(host); - // sign_set.insert(host.serialize_to_uint64()); - // } return 0; } + int trainers_ = 0; + std::vector<PSHost> _ps_client_list; std::unordered_set<uint64_t> _ps_client_sign_set; // for unique filter @@ -198,7 +196,7 @@ class PSEnvironment { class PaddlePSEnvironment : public PSEnvironment { public: - explicit PaddlePSEnvironment() {} + explicit PaddlePSEnvironment() { trainers_ = 0; } // NOLINT virtual ~PaddlePSEnvironment() {} virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { diff --git a/paddle/fluid/distributed/service/service.cc b/paddle/fluid/distributed/service/service.cc index 
47b840cffd0808dae0f2ddb67a16b792cee3d57c..1d360eb5669b5191c7755bbc8a7b02b48019e537 100644 --- a/paddle/fluid/distributed/service/service.cc +++ b/paddle/fluid/distributed/service/service.cc @@ -69,11 +69,13 @@ void PSCore::init_gflag(const std::string& gflags) { int PSCore::init_server( const std::string& dist_desc, const std::vector<std::string>* host_sign_list, int node_num, int index, + int trainers, const std::vector<framework::ProgramDesc>& server_sub_program) { google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); init_gflag(_ps_param.init_gflags()); _ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env.set_ps_servers(host_sign_list, node_num); + _ps_env.set_trainers(trainers); int ret = 0; _server_ptr = std::shared_ptr<paddle::distributed::PSServer>( paddle::distributed::PSServerFactory::create(_ps_param)); diff --git a/paddle/fluid/distributed/service/service.h b/paddle/fluid/distributed/service/service.h index b4ba691cced5feabc549238c3412203dee11f1c2..a8b86dafd8d7e5ce8217d12218c216b5a288942f 100644 --- a/paddle/fluid/distributed/service/service.h +++ b/paddle/fluid/distributed/service/service.h @@ -40,6 +40,7 @@ class PSCore { virtual int init_server( const std::string& dist_desc, const std::vector<std::string>* host_sign_list, int node_num, int index, + int trainers, const std::vector<framework::ProgramDesc>& server_sub_program = {}); virtual int init_worker( const std::string& dist_desc, diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index 20bf443689ef06d37287cbcb27392e9c0e137040..dc78e1ce485e0e7e662ac79f68a11de085e994f2 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -742,6 +742,7 @@ class TheOnePSRuntime(RuntimeBase): role_id = self.compiled_strategy.get_role_id() endpoints = self.compiled_strategy.get_ps_endpoints() is_sync = self.compiled_strategy.is_sync_mode() + trainers = self.compiled_strategy.get_trainers() server = self._get_fleet_proto(is_server=True, is_sync=is_sync) 
proto_txt = str(server) @@ -757,7 +758,7 @@ class TheOnePSRuntime(RuntimeBase): string_hosts.append(pshost.serialize_to_string()) self._server = fluid.core.DistFleetWrapper() - self._server.init_server(proto_txt, string_hosts, role_id, + self._server.init_server(proto_txt, string_hosts, role_id, trainers, self._server_sub_program) from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames