未验证 提交 3bdf1544 编写于 作者: T tangwei12 提交者: GitHub

【Cherry-Pick】add trainer number for pserver (#30524)

* add trainers for pserver

Change-Id: I99c0ab1cc427318f1f9bf8f8f5faff2b8890645d

* add trainers for pserver

Change-Id: I1a75793ec81ce126d07f4c47cae09b95d530bbc8
上级 42f07437
...@@ -55,14 +55,14 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path, ...@@ -55,14 +55,14 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path,
void FleetWrapper::InitServer( void FleetWrapper::InitServer(
const std::string& dist_desc, const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, int index, const std::vector<std::string>& host_sign_list, int index, int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program) { const std::vector<framework::ProgramDesc>& server_sub_program) {
if (!is_initialized_) { if (!is_initialized_) {
VLOG(3) << "Going to init server"; VLOG(3) << "Going to init server";
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>( pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>(
new paddle::distributed::PSCore()); new paddle::distributed::PSCore());
pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(), pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(),
index, server_sub_program); index, trainers, server_sub_program);
is_initialized_ = true; is_initialized_ = true;
} else { } else {
VLOG(3) << "Server can be initialized only once"; VLOG(3) << "Server can be initialized only once";
......
...@@ -156,7 +156,7 @@ class FleetWrapper { ...@@ -156,7 +156,7 @@ class FleetWrapper {
// const std::vector<uint64_t>& host_sign_list, int index); // const std::vector<uint64_t>& host_sign_list, int index);
void InitServer( void InitServer(
const std::string& dist_desc, const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, int index, const std::vector<std::string>& host_sign_list, int index, int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program = {}); const std::vector<framework::ProgramDesc>& server_sub_program = {});
// init trainer // init trainer
void InitWorker(const std::string& dist_desc, void InitWorker(const std::string& dist_desc,
......
...@@ -58,9 +58,11 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { ...@@ -58,9 +58,11 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
std::string ip_port = ip + ":" + std::to_string(port); std::string ip_port = ip + ":" + std::to_string(port);
VLOG(3) << "server of rank " << _rank << " starts at " << ip_port; VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
int num_threads = std::thread::hardware_concurrency();
brpc::ServerOptions options; brpc::ServerOptions options;
options.num_threads = num_threads;
int num_threads = std::thread::hardware_concurrency();
auto trainers = _environment->get_trainers();
options.num_threads = trainers > num_threads ? trainers : num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) { if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port; LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port;
......
...@@ -161,6 +161,10 @@ class PSEnvironment { ...@@ -161,6 +161,10 @@ class PSEnvironment {
return {}; return {};
} }
// Stores the trainer count for this environment in trainers_.
// PSCore::init_server forwards its `trainers` argument here.
virtual void set_trainers(int trainers) { trainers_ = trainers; }
// Returns the trainer count held in trainers_ (0 until set_trainers is
// called). BrpcPsServer::start reads this and uses the larger of it and
// std::thread::hardware_concurrency() as the brpc server thread count.
virtual int get_trainers() { return trainers_; }
protected: protected:
//注册一个host //注册一个host
virtual int32_t registe_ps_host(const std::string &ip, uint32_t port, virtual int32_t registe_ps_host(const std::string &ip, uint32_t port,
...@@ -178,17 +182,11 @@ class PSEnvironment { ...@@ -178,17 +182,11 @@ class PSEnvironment {
host_list.push_back(host); host_list.push_back(host);
sign_set.insert(rank); sign_set.insert(rank);
} }
// if (sign_set.count(host.serialize_to_uint64()) > 0) {
// LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port
// << ", rank:" << host.rank
// << " already register, ignore register";
// } else {
// host_list.push_back(host);
// sign_set.insert(host.serialize_to_uint64());
// }
return 0; return 0;
} }
int trainers_ = 0;
std::vector<PSHost> _ps_client_list; std::vector<PSHost> _ps_client_list;
std::unordered_set<uint64_t> _ps_client_sign_set; // for unique filter std::unordered_set<uint64_t> _ps_client_sign_set; // for unique filter
...@@ -198,7 +196,7 @@ class PSEnvironment { ...@@ -198,7 +196,7 @@ class PSEnvironment {
class PaddlePSEnvironment : public PSEnvironment { class PaddlePSEnvironment : public PSEnvironment {
public: public:
explicit PaddlePSEnvironment() {} explicit PaddlePSEnvironment() { trainers_ = 0; } // NOLINT
virtual ~PaddlePSEnvironment() {} virtual ~PaddlePSEnvironment() {}
virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) {
......
...@@ -69,11 +69,13 @@ void PSCore::init_gflag(const std::string& gflags) { ...@@ -69,11 +69,13 @@ void PSCore::init_gflag(const std::string& gflags) {
int PSCore::init_server( int PSCore::init_server(
const std::string& dist_desc, const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index, const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program) { const std::vector<framework::ProgramDesc>& server_sub_program) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags()); init_gflag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num); _ps_env.set_ps_servers(host_sign_list, node_num);
_ps_env.set_trainers(trainers);
int ret = 0; int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>( _server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(_ps_param)); paddle::distributed::PSServerFactory::create(_ps_param));
......
...@@ -40,6 +40,7 @@ class PSCore { ...@@ -40,6 +40,7 @@ class PSCore {
virtual int init_server( virtual int init_server(
const std::string& dist_desc, const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index, const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program = {}); const std::vector<framework::ProgramDesc>& server_sub_program = {});
virtual int init_worker( virtual int init_worker(
const std::string& dist_desc, const std::string& dist_desc,
......
...@@ -742,6 +742,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -742,6 +742,7 @@ class TheOnePSRuntime(RuntimeBase):
role_id = self.compiled_strategy.get_role_id() role_id = self.compiled_strategy.get_role_id()
endpoints = self.compiled_strategy.get_ps_endpoints() endpoints = self.compiled_strategy.get_ps_endpoints()
is_sync = self.compiled_strategy.is_sync_mode() is_sync = self.compiled_strategy.is_sync_mode()
trainers = self.compiled_strategy.get_trainers()
server = self._get_fleet_proto(is_server=True, is_sync=is_sync) server = self._get_fleet_proto(is_server=True, is_sync=is_sync)
proto_txt = str(server) proto_txt = str(server)
...@@ -757,7 +758,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -757,7 +758,7 @@ class TheOnePSRuntime(RuntimeBase):
string_hosts.append(pshost.serialize_to_string()) string_hosts.append(pshost.serialize_to_string())
self._server = fluid.core.DistFleetWrapper() self._server = fluid.core.DistFleetWrapper()
self._server.init_server(proto_txt, string_hosts, role_id, self._server.init_server(proto_txt, string_hosts, role_id, trainers,
self._server_sub_program) self._server_sub_program)
from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册