Unverified commit c9e78a22, authored by: T tangwei12, committed by: GitHub

add trainers for pserver (#30523)

* add trainers for pserver

Change-Id: I1a75793ec81ce126d07f4c47cae09b95d530bbc8
Parent commit: 5067e3a8
......@@ -55,14 +55,14 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path,
void FleetWrapper::InitServer(
const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, int index,
const std::vector<std::string>& host_sign_list, int index, int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program) {
if (!is_initialized_) {
VLOG(3) << "Going to init server";
pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>(
new paddle::distributed::PSCore());
pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(),
index, server_sub_program);
index, trainers, server_sub_program);
is_initialized_ = true;
} else {
VLOG(3) << "Server can be initialized only once";
......
......@@ -156,7 +156,7 @@ class FleetWrapper {
// const std::vector<uint64_t>& host_sign_list, int index);
void InitServer(
const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, int index,
const std::vector<std::string>& host_sign_list, int index, int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program = {});
// init trainer
void InitWorker(const std::string& dist_desc,
......
......@@ -58,9 +58,11 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
std::string ip_port = ip + ":" + std::to_string(port);
VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
int num_threads = std::thread::hardware_concurrency();
brpc::ServerOptions options;
options.num_threads = num_threads;
int num_threads = std::thread::hardware_concurrency();
auto trainers = _environment->get_trainers();
options.num_threads = trainers > num_threads ? trainers : num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port;
......
......@@ -161,6 +161,10 @@ class PSEnvironment {
return {};
}
virtual void set_trainers(int trainers) { trainers_ = trainers; }
virtual int get_trainers() { return trainers_; }
protected:
//注册一个host // NOLINT
virtual int32_t registe_ps_host(
......@@ -179,17 +183,11 @@ class PSEnvironment {
host_list.push_back(host);
sign_set.insert(rank);
}
// if (sign_set.count(host.serialize_to_uint64()) > 0) {
// LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port
// << ", rank:" << host.rank
// << " already register, ignore register";
// } else {
// host_list.push_back(host);
// sign_set.insert(host.serialize_to_uint64());
// }
return 0;
}
int trainers_ = 0;
std::vector<PSHost> _ps_client_list;
std::unordered_set<uint64_t> _ps_client_sign_set; // for unique filter
......
......@@ -69,11 +69,13 @@ void PSCore::init_gflag(const std::string& gflags) {
int PSCore::init_server(
const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
init_gflag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(host_sign_list, node_num);
_ps_env.set_trainers(trainers);
int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::create(_ps_param));
......
......@@ -40,6 +40,7 @@ class PSCore {
virtual int init_server(
const std::string& dist_desc,
const std::vector<std::string>* host_sign_list, int node_num, int index,
int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program = {});
virtual int init_worker(
const std::string& dist_desc,
......
......@@ -742,6 +742,7 @@ class TheOnePSRuntime(RuntimeBase):
role_id = self.compiled_strategy.get_role_id()
endpoints = self.compiled_strategy.get_ps_endpoints()
is_sync = self.compiled_strategy.is_sync_mode()
trainers = self.compiled_strategy.get_trainers()
server = self._get_fleet_proto(is_server=True, is_sync=is_sync)
proto_txt = str(server)
......@@ -757,7 +758,7 @@ class TheOnePSRuntime(RuntimeBase):
string_hosts.append(pshost.serialize_to_string())
self._server = fluid.core.DistFleetWrapper()
self._server.init_server(proto_txt, string_hosts, role_id,
self._server.init_server(proto_txt, string_hosts, role_id, trainers,
self._server_sub_program)
from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register.