From 3bdf154414d77ebfeba68cfb95b096738b68b475 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 19 Jan 2021 16:06:08 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Cherry-Pick=E3=80=91add=20trainer=20nu?= =?UTF-8?q?mber=20for=20pserver=20(#30524)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add trainers for pserver Change-Id: I99c0ab1cc427318f1f9bf8f8f5faff2b8890645d * add trainers for pserver Change-Id: I1a75793ec81ce126d07f4c47cae09b95d530bbc8 --- paddle/fluid/distributed/fleet.cc | 4 ++-- paddle/fluid/distributed/fleet.h | 2 +- .../fluid/distributed/service/brpc_ps_server.cc | 6 ++++-- paddle/fluid/distributed/service/env.h | 16 +++++++--------- paddle/fluid/distributed/service/service.cc | 2 ++ paddle/fluid/distributed/service/service.h | 1 + .../distributed/fleet/runtime/the_one_ps.py | 3 ++- 7 files changed, 19 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc index b1aeaca353..8db32c5cc4 100644 --- a/paddle/fluid/distributed/fleet.cc +++ b/paddle/fluid/distributed/fleet.cc @@ -55,14 +55,14 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path, void FleetWrapper::InitServer( const std::string& dist_desc, - const std::vector& host_sign_list, int index, + const std::vector& host_sign_list, int index, int trainers, const std::vector& server_sub_program) { if (!is_initialized_) { VLOG(3) << "Going to init server"; pserver_ptr_ = std::shared_ptr( new paddle::distributed::PSCore()); pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(), - index, server_sub_program); + index, trainers, server_sub_program); is_initialized_ = true; } else { VLOG(3) << "Server can be initialized only once"; diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h index 5de278e067..03d915c500 100644 --- a/paddle/fluid/distributed/fleet.h +++ b/paddle/fluid/distributed/fleet.h @@ -156,7 +156,7 @@ class FleetWrapper { // const std::vector& host_sign_list, int index); void InitServer( const std::string& dist_desc, - const std::vector& host_sign_list, int index, + const std::vector& host_sign_list, int index, int trainers, const std::vector& server_sub_program = {}); // init trainer void InitWorker(const std::string& dist_desc, diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc index a6837cd452..b9afff8c43 100644 --- a/paddle/fluid/distributed/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/service/brpc_ps_server.cc @@ -58,9 +58,11 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { std::string ip_port = ip + ":" + std::to_string(port); VLOG(3) << "server of rank " << _rank << " starts at " << ip_port; - int num_threads = std::thread::hardware_concurrency(); brpc::ServerOptions options; - options.num_threads = num_threads; + + int num_threads = std::thread::hardware_concurrency(); + auto trainers = _environment->get_trainers(); + options.num_threads = trainers > num_threads ? trainers : num_threads; if (_server.Start(ip_port.c_str(), &options) != 0) { LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port; diff --git a/paddle/fluid/distributed/service/env.h b/paddle/fluid/distributed/service/env.h index 42f31717f7..206ff2c5cc 100644 --- a/paddle/fluid/distributed/service/env.h +++ b/paddle/fluid/distributed/service/env.h @@ -161,6 +161,10 @@ class PSEnvironment { return {}; } + virtual void set_trainers(int trainers) { trainers_ = trainers; } + + virtual int get_trainers() { return trainers_; } + protected: //注册一个host virtual int32_t registe_ps_host(const std::string &ip, uint32_t port, @@ -178,17 +182,11 @@ class PSEnvironment { host_list.push_back(host); sign_set.insert(rank); } - // if (sign_set.count(host.serialize_to_uint64()) > 0) { - // LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port - // << ", rank:" << host.rank - // << " already register, ignore register"; - // } else { - // host_list.push_back(host); - // sign_set.insert(host.serialize_to_uint64()); - // } return 0; } + int trainers_ = 0; + std::vector _ps_client_list; std::unordered_set _ps_client_sign_set; // for unique filter @@ -198,7 +196,7 @@ class PSEnvironment { class PaddlePSEnvironment : public PSEnvironment { public: - explicit PaddlePSEnvironment() {} + explicit PaddlePSEnvironment() { trainers_ = 0; } // NOLINT virtual ~PaddlePSEnvironment() {} virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { diff --git a/paddle/fluid/distributed/service/service.cc b/paddle/fluid/distributed/service/service.cc index 47b840cffd..1d360eb566 100644 --- a/paddle/fluid/distributed/service/service.cc +++ b/paddle/fluid/distributed/service/service.cc @@ -69,11 +69,13 @@ void PSCore::init_gflag(const std::string& gflags) { int PSCore::init_server( const std::string& dist_desc, const std::vector* host_sign_list, int node_num, int index, + int trainers, const std::vector& server_sub_program) { google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); init_gflag(_ps_param.init_gflags()); _ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env.set_ps_servers(host_sign_list, node_num); + _ps_env.set_trainers(trainers); int ret = 0; _server_ptr = std::shared_ptr( paddle::distributed::PSServerFactory::create(_ps_param)); diff --git a/paddle/fluid/distributed/service/service.h b/paddle/fluid/distributed/service/service.h index b4ba691cce..a8b86dafd8 100644 --- a/paddle/fluid/distributed/service/service.h +++ b/paddle/fluid/distributed/service/service.h @@ -40,6 +40,7 @@ class PSCore { virtual int init_server( const std::string& dist_desc, const std::vector* host_sign_list, int node_num, int index, + int trainers, const std::vector& server_sub_program = {}); virtual int init_worker( const std::string& dist_desc, diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index 20bf443689..dc78e1ce48 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -742,6 +742,7 @@ class TheOnePSRuntime(RuntimeBase): role_id = self.compiled_strategy.get_role_id() endpoints = self.compiled_strategy.get_ps_endpoints() is_sync = self.compiled_strategy.is_sync_mode() + trainers = self.compiled_strategy.get_trainers() server = self._get_fleet_proto(is_server=True, is_sync=is_sync) proto_txt = str(server) @@ -757,7 +758,7 @@ class TheOnePSRuntime(RuntimeBase): string_hosts.append(pshost.serialize_to_string()) self._server = fluid.core.DistFleetWrapper() - self._server.init_server(proto_txt, string_hosts, role_id, + self._server.init_server(proto_txt, string_hosts, role_id, trainers, self._server_sub_program) from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames -- GitLab