提交 aa9f5162 编写于 作者: Q qiaolongfei

code refine, add comment and some naming problem

上级 d32c7a6b
...@@ -20,10 +20,10 @@ using namespace paddle; // NOLINT ...@@ -20,10 +20,10 @@ using namespace paddle; // NOLINT
int main(int argc, char** argv) { int main(int argc, char** argv) {
initMain(argc, argv); initMain(argc, argv);
std::unique_ptr<ParameterServerController> pServerPtr( std::unique_ptr<ParameterServerController> parameterServerPtr(
paddle::ParameterServerController::createByGflags()); paddle::ParameterServerController::createFromGflags());
pServerPtr->start(); parameterServerPtr->start();
pServerPtr->join(); parameterServerPtr->wait();
return 0; return 0;
} }
...@@ -25,43 +25,44 @@ ParameterServerController::ParameterServerController( ...@@ -25,43 +25,44 @@ ParameterServerController::ParameterServerController(
int numPorts = config.ports_num() + config.ports_num_for_sparse(); int numPorts = config.ports_num() + config.ports_num_for_sparse();
if (config.nics().empty()) { if (config.nics().empty()) {
pservers_.resize(numPorts); parameterServers_.resize(numPorts);
for (int i = 0; i < numPorts; ++i) { for (int i = 0; i < numPorts; ++i) {
if (config.rdma_tcp() == "rdma") { if (config.rdma_tcp() == "rdma") {
pservers_[i].reset( parameterServers_[i].reset(
new ParameterServer2(std::string(), config.port() + i, rdmaCpu++)); new ParameterServer2(std::string(), config.port() + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus; rdmaCpu = rdmaCpu % onlineCpus;
} else { } else {
pservers_[i].reset( parameterServers_[i].reset(
new ParameterServer2(std::string(), config.port() + i)); new ParameterServer2(std::string(), config.port() + i));
} }
CHECK(pservers_[i]->init()) << "Fail to initialize parameter server" CHECK(parameterServers_[i]->init()) << "Fail to initialize parameter "
"server on port "
<< config.port() + i; << config.port() + i;
} }
} else { } else {
str::split(config.nics(), ',', &devices); str::split(config.nics(), ',', &devices);
pservers_.resize(devices.size() * numPorts); parameterServers_.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) { for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) { for (size_t j = 0; j < devices.size(); ++j) {
if (config.rdma_tcp() == "rdma") { if (config.rdma_tcp() == "rdma") {
pservers_[i * devices.size() + j].reset(new ParameterServer2( parameterServers_[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), config.port() + i, rdmaCpu++)); getIpAddr(devices[j]), config.port() + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus; rdmaCpu = rdmaCpu % onlineCpus;
} else { } else {
pservers_[i * devices.size() + j].reset( parameterServers_[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), config.port() + i)); new ParameterServer2(getIpAddr(devices[j]), config.port() + i));
} }
CHECK(pservers_[i * devices.size() + j]->init()) CHECK(parameterServers_[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j] << "Fail to initialize parameter server with device " << devices[j]
<< config.port() + i; << config.port() + i;
} }
} }
} }
} }
ParameterServerController::~ParameterServerController() { this->join(); } ParameterServerController::~ParameterServerController() { this->wait(); }
ParameterServerController* ParameterServerController::createByGflags() { ParameterServerController* ParameterServerController::createFromGflags() {
ParameterServerConfig config; ParameterServerConfig config;
config.set_nics(FLAGS_nics); config.set_nics(FLAGS_nics);
...@@ -79,21 +80,21 @@ ParameterServerController* ParameterServerController::create( ...@@ -79,21 +80,21 @@ ParameterServerController* ParameterServerController::create(
} }
void ParameterServerController::start() { void ParameterServerController::start() {
LOG(INFO) << "pserver sizes : " << pservers_.size(); LOG(INFO) << "number of parameterServer instances: "
<< parameterServers_.size();
int i = 0; int i = 0;
for (const auto& pserver : pservers_) { for (const auto& parameterServer : parameterServers_) {
LOG(INFO) << "pserver started : " << i; LOG(INFO) << "Starting parameterServer[" << i << "]";
pserver->start(); parameterServer->start();
i++; i++;
} }
} }
void ParameterServerController::join() { void ParameterServerController::wait() {
LOG(INFO) << "pserver sizes : " << pservers_.size();
int i = 0; int i = 0;
for (const auto& pserver : pservers_) { for (const auto& parameterServer : parameterServers_) {
LOG(INFO) << "pserver join : " << i; LOG(INFO) << "Waiting parameterServer[" << i << "]";
pserver->join(); parameterServer->join();
i++; i++;
} }
} }
......
...@@ -21,6 +21,12 @@ limitations under the License. */ ...@@ -21,6 +21,12 @@ limitations under the License. */
namespace paddle { namespace paddle {
/**
* @brief ParameterServerController is used for create, init and manage multi
* parameter server instances. The num of the instances is decided by port
* num(the ports number for parameter send) and network devices configured
* by gflags or proto.
*/
class ParameterServerController final { class ParameterServerController final {
public: public:
DISABLE_COPY(ParameterServerController); DISABLE_COPY(ParameterServerController);
...@@ -39,28 +45,30 @@ public: ...@@ -39,28 +45,30 @@ public:
* @brief create ParameterServerController from gflags, this is used for * @brief create ParameterServerController from gflags, this is used for
* compatibility with the old usage of configuration by gflags. * compatibility with the old usage of configuration by gflags.
*/ */
static ParameterServerController* createByGflags(); static ParameterServerController* createFromGflags();
/** /**
* @brief create ParameterServerController with ParameterServerConfig, remove * @brief create ParameterServerController with ParameterServerConfig, remove
* gflags from ParameterServer. Init all pservers thread according to the * gflags from ParameterServer. Init all ParameterServer2 instances according
* config. * to
* the config.
*/ */
static ParameterServerController* create(const ParameterServerConfig& config); static ParameterServerController* create(const ParameterServerConfig& config);
/** /**
* @brief start all pserver thread in this ParameterServerController. * @brief start all ParameterServer2 instances in this
* ParameterServerController.
*/ */
void start(); void start();
/** /**
* @brief join and wait for all pserver thread in this * @brief join and wait for all ParameterServer2 instances thread in this
* ParameterServerController. * ParameterServerController.
*/ */
void join(); void wait();
private: private:
std::vector<std::unique_ptr<ParameterServer2>> pservers_; std::vector<std::unique_ptr<ParameterServer2>> parameterServers_;
}; };
} // namespace paddle } // namespace paddle
...@@ -36,10 +36,11 @@ int main(int argc, char** argv) { ...@@ -36,10 +36,11 @@ int main(int argc, char** argv) {
initMain(argc, argv); initMain(argc, argv);
initPython(argc, argv); initPython(argc, argv);
std::unique_ptr<ParameterServerController> pServerPtr(nullptr); std::unique_ptr<ParameterServerController> parameterServerPtr(nullptr);
if (FLAGS_start_pserver) { if (FLAGS_start_pserver) {
pServerPtr.reset(paddle::ParameterServerController::createByGflags()); parameterServerPtr.reset(
pServerPtr->start(); paddle::ParameterServerController::createFromGflags());
parameterServerPtr->start();
} }
Trainer trainer; Trainer trainer;
auto config = TrainerConfigHelper::createFromFlags(); auto config = TrainerConfigHelper::createFromFlags();
......
...@@ -15,10 +15,17 @@ syntax = "proto2"; ...@@ -15,10 +15,17 @@ syntax = "proto2";
package paddle; package paddle;
/**
* Configuration structure for ParameterClient2.
*/
message ParameterClientConfig { message ParameterClientConfig {
required int32 trainer_id = 1; required int32 trainer_id = 1;
} }
/**
* Configuration structure for ParameterServer2.
*/
message ParameterServerConfig { message ParameterServerConfig {
// The ports number for parameter send, // The ports number for parameter send,
// increment based on default port number // increment based on default port number
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册