提交 aa9f5162 编写于 作者: Q qiaolongfei

code refine, add comment and some naming problem

上级 d32c7a6b
......@@ -20,10 +20,10 @@ using namespace paddle; // NOLINT
int main(int argc, char** argv) {
initMain(argc, argv);
std::unique_ptr<ParameterServerController> pServerPtr(
paddle::ParameterServerController::createByGflags());
pServerPtr->start();
pServerPtr->join();
std::unique_ptr<ParameterServerController> parameterServerPtr(
paddle::ParameterServerController::createFromGflags());
parameterServerPtr->start();
parameterServerPtr->wait();
return 0;
}
......@@ -25,43 +25,44 @@ ParameterServerController::ParameterServerController(
int numPorts = config.ports_num() + config.ports_num_for_sparse();
if (config.nics().empty()) {
pservers_.resize(numPorts);
parameterServers_.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (config.rdma_tcp() == "rdma") {
pservers_[i].reset(
parameterServers_[i].reset(
new ParameterServer2(std::string(), config.port() + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers_[i].reset(
parameterServers_[i].reset(
new ParameterServer2(std::string(), config.port() + i));
}
CHECK(pservers_[i]->init()) << "Fail to initialize parameter server"
CHECK(parameterServers_[i]->init()) << "Fail to initialize parameter "
"server on port "
<< config.port() + i;
}
} else {
str::split(config.nics(), ',', &devices);
pservers_.resize(devices.size() * numPorts);
parameterServers_.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (config.rdma_tcp() == "rdma") {
pservers_[i * devices.size() + j].reset(new ParameterServer2(
parameterServers_[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), config.port() + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers_[i * devices.size() + j].reset(
parameterServers_[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), config.port() + i));
}
CHECK(pservers_[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j]
CHECK(parameterServers_[i * devices.size() + j]->init())
<< "Fail to initialize parameter server with device " << devices[j]
<< config.port() + i;
}
}
}
}
ParameterServerController::~ParameterServerController() { this->join(); }
ParameterServerController::~ParameterServerController() { this->wait(); }
ParameterServerController* ParameterServerController::createByGflags() {
ParameterServerController* ParameterServerController::createFromGflags() {
ParameterServerConfig config;
config.set_nics(FLAGS_nics);
......@@ -79,21 +80,21 @@ ParameterServerController* ParameterServerController::create(
}
void ParameterServerController::start() {
LOG(INFO) << "pserver sizes : " << pservers_.size();
LOG(INFO) << "number of parameterServer instances: "
<< parameterServers_.size();
int i = 0;
for (const auto& pserver : pservers_) {
LOG(INFO) << "pserver started : " << i;
pserver->start();
for (const auto& parameterServer : parameterServers_) {
LOG(INFO) << "Starting parameterServer[" << i << "]";
parameterServer->start();
i++;
}
}
void ParameterServerController::join() {
LOG(INFO) << "pserver sizes : " << pservers_.size();
void ParameterServerController::wait() {
int i = 0;
for (const auto& pserver : pservers_) {
LOG(INFO) << "pserver join : " << i;
pserver->join();
for (const auto& parameterServer : parameterServers_) {
LOG(INFO) << "Waiting parameterServer[" << i << "]";
parameterServer->join();
i++;
}
}
......
......@@ -21,6 +21,12 @@ limitations under the License. */
namespace paddle {
/**
* @brief ParameterServerController is used for create, init and manage multi
* parameter server instances. The num of the instances is decided by port
* num(the ports number for parameter send) and network devices configured
* by gflags or proto.
*/
class ParameterServerController final {
public:
DISABLE_COPY(ParameterServerController);
......@@ -39,28 +45,30 @@ public:
* @brief create ParameterServerController from gflags, this is used for
* compatibility with the old usage of configuration by gflags.
*/
static ParameterServerController* createByGflags();
static ParameterServerController* createFromGflags();
/**
* @brief create ParameterServerController with ParameterServerConfig, remove
* gflags from ParameterServer. Init all pservers thread according to the
* config.
* gflags from ParameterServer. Init all ParameterServer2 instances according
* to
* the config.
*/
static ParameterServerController* create(const ParameterServerConfig& config);
/**
* @brief start all pserver thread in this ParameterServerController.
* @brief start all ParameterServer2 instances in this
* ParameterServerController.
*/
void start();
/**
* @brief join and wait for all pserver thread in this
* @brief join and wait for all ParameterServer2 instances thread in this
* ParameterServerController.
*/
void join();
void wait();
private:
std::vector<std::unique_ptr<ParameterServer2>> pservers_;
std::vector<std::unique_ptr<ParameterServer2>> parameterServers_;
};
} // namespace paddle
......@@ -36,10 +36,11 @@ int main(int argc, char** argv) {
initMain(argc, argv);
initPython(argc, argv);
std::unique_ptr<ParameterServerController> pServerPtr(nullptr);
std::unique_ptr<ParameterServerController> parameterServerPtr(nullptr);
if (FLAGS_start_pserver) {
pServerPtr.reset(paddle::ParameterServerController::createByGflags());
pServerPtr->start();
parameterServerPtr.reset(
paddle::ParameterServerController::createFromGflags());
parameterServerPtr->start();
}
Trainer trainer;
auto config = TrainerConfigHelper::createFromFlags();
......
......@@ -15,10 +15,17 @@ syntax = "proto2";
package paddle;
/**
* Configuration structure for ParameterClient2.
*/
message ParameterClientConfig {
required int32 trainer_id = 1;
}
/**
* Configuration structure for ParameterServer2.
*/
message ParameterServerConfig {
// The ports number for parameter send,
// increment based on default port number
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册