diff --git a/paddle/pserver/ParameterServer2Main.cpp b/paddle/pserver/ParameterServer2Main.cpp index 114505252211d9a45418e0e298c188785638f1c2..845a2c27e242cfbe31679fea6eae13d2b400ec81 100644 --- a/paddle/pserver/ParameterServer2Main.cpp +++ b/paddle/pserver/ParameterServer2Main.cpp @@ -20,10 +20,10 @@ using namespace paddle; // NOLINT int main(int argc, char** argv) { initMain(argc, argv); - std::unique_ptr pServerPtr( - paddle::ParameterServerController::createByGflags()); - pServerPtr->start(); - pServerPtr->join(); + std::unique_ptr parameterServerPtr( + paddle::ParameterServerController::createFromGflags()); + parameterServerPtr->start(); + parameterServerPtr->wait(); return 0; } diff --git a/paddle/pserver/ParameterServerController.cpp b/paddle/pserver/ParameterServerController.cpp index ec24bc7e573dce2cca484d0deabf4710fd19a25c..1d11a2e1acbc0f091901f3854ca99490d89ebe36 100644 --- a/paddle/pserver/ParameterServerController.cpp +++ b/paddle/pserver/ParameterServerController.cpp @@ -25,43 +25,44 @@ ParameterServerController::ParameterServerController( int numPorts = config.ports_num() + config.ports_num_for_sparse(); if (config.nics().empty()) { - pservers_.resize(numPorts); + parameterServers_.resize(numPorts); for (int i = 0; i < numPorts; ++i) { if (config.rdma_tcp() == "rdma") { - pservers_[i].reset( + parameterServers_[i].reset( new ParameterServer2(std::string(), config.port() + i, rdmaCpu++)); rdmaCpu = rdmaCpu % onlineCpus; } else { - pservers_[i].reset( + parameterServers_[i].reset( new ParameterServer2(std::string(), config.port() + i)); } - CHECK(pservers_[i]->init()) << "Fail to initialize parameter server" - << config.port() + i; + CHECK(parameterServers_[i]->init()) << "Fail to initialize parameter " + "server on port " + << config.port() + i; } } else { str::split(config.nics(), ',', &devices); - pservers_.resize(devices.size() * numPorts); + parameterServers_.resize(devices.size() * numPorts); for (int i = 0; i < numPorts; ++i) { for (size_t j = 0; j < devices.size(); ++j) { if (config.rdma_tcp() == "rdma") { - pservers_[i * devices.size() + j].reset(new ParameterServer2( + parameterServers_[i * devices.size() + j].reset(new ParameterServer2( getIpAddr(devices[j]), config.port() + i, rdmaCpu++)); rdmaCpu = rdmaCpu % onlineCpus; } else { - pservers_[i * devices.size() + j].reset( + parameterServers_[i * devices.size() + j].reset( new ParameterServer2(getIpAddr(devices[j]), config.port() + i)); } - CHECK(pservers_[i * devices.size() + j]->init()) - << "Fail to initialize parameter server" << devices[j] + CHECK(parameterServers_[i * devices.size() + j]->init()) + << "Fail to initialize parameter server with device " << devices[j] << config.port() + i; } } } } -ParameterServerController::~ParameterServerController() { this->join(); } +ParameterServerController::~ParameterServerController() { this->wait(); } -ParameterServerController* ParameterServerController::createByGflags() { +ParameterServerController* ParameterServerController::createFromGflags() { ParameterServerConfig config; config.set_nics(FLAGS_nics); @@ -79,21 +80,21 @@ ParameterServerController* ParameterServerController::create( } void ParameterServerController::start() { - LOG(INFO) << "pserver sizes : " << pservers_.size(); + LOG(INFO) << "number of parameterServer instances: " + << parameterServers_.size(); int i = 0; - for (const auto& pserver : pservers_) { - LOG(INFO) << "pserver started : " << i; - pserver->start(); + for (const auto& parameterServer : parameterServers_) { + LOG(INFO) << "Starting parameterServer[" << i << "]"; + parameterServer->start(); i++; } } -void ParameterServerController::join() { - LOG(INFO) << "pserver sizes : " << pservers_.size(); +void ParameterServerController::wait() { int i = 0; - for (const auto& pserver : pservers_) { - LOG(INFO) << "pserver join : " << i; - pserver->join(); + for (const auto& parameterServer : parameterServers_) { + LOG(INFO) << "Waiting parameterServer[" << i << "]"; + parameterServer->join(); i++; } } diff --git a/paddle/pserver/ParameterServerController.h b/paddle/pserver/ParameterServerController.h index ee249de9d802de141e5848be472b6def513a77d3..fe9bb0b4d02339d0d31d5bc2942932e1f876098b 100644 --- a/paddle/pserver/ParameterServerController.h +++ b/paddle/pserver/ParameterServerController.h @@ -21,6 +21,12 @@ limitations under the License. */ namespace paddle { +/** + * @brief ParameterServerController is used for create, init and manage multi + * parameter server instances. The num of the instances is decided by port + * num(the ports number for parameter send) and network devices configured + * by gflags or proto. + */ class ParameterServerController final { public: DISABLE_COPY(ParameterServerController); @@ -39,28 +45,30 @@ public: * @brief create ParameterServerController from gflags, this is used for * compatibility with the old usage of configuration by gflags. */ - static ParameterServerController* createByGflags(); + static ParameterServerController* createFromGflags(); /** * @brief create ParameterServerController with ParameterServerConfig, remove - * gflags from ParameterServer. Init all pservers thread according to the - * config. + * gflags from ParameterServer. Init all ParameterServer2 instances according + * to + * the config. */ static ParameterServerController* create(const ParameterServerConfig& config); /** - * @brief start all pserver thread in this ParameterServerController. + * @brief start all ParameterServer2 instances in this + * ParameterServerController. */ void start(); /** - * @brief join and wait for all pserver thread in this + * @brief join and wait for all ParameterServer2 instances thread in this * ParameterServerController. */ - void join(); + void wait(); private: - std::vector> pservers_; + std::vector> parameterServers_; }; } // namespace paddle diff --git a/paddle/trainer/TrainerMain.cpp b/paddle/trainer/TrainerMain.cpp index 61de728f2a2a2830c875bc7adb23031fb33b840a..c5c1d484e5f85c774fd4b8f1d4a8d46abfa2f547 100644 --- a/paddle/trainer/TrainerMain.cpp +++ b/paddle/trainer/TrainerMain.cpp @@ -36,10 +36,11 @@ int main(int argc, char** argv) { initMain(argc, argv); initPython(argc, argv); - std::unique_ptr pServerPtr(nullptr); + std::unique_ptr parameterServerPtr(nullptr); if (FLAGS_start_pserver) { - pServerPtr.reset(paddle::ParameterServerController::createByGflags()); - pServerPtr->start(); + parameterServerPtr.reset( + paddle::ParameterServerController::createFromGflags()); + parameterServerPtr->start(); } Trainer trainer; auto config = TrainerConfigHelper::createFromFlags(); diff --git a/proto/ParameterServerConfig.proto b/proto/ParameterServerConfig.proto index b4fbf901c20cce58a5f1819c05d3518902c4c165..3068bba8b10d89b432b41076dc6eb3ebc40b3883 100644 --- a/proto/ParameterServerConfig.proto +++ b/proto/ParameterServerConfig.proto @@ -15,10 +15,17 @@ syntax = "proto2"; package paddle; + +/** + * Configuration structure for ParameterClient2. + */ message ParameterClientConfig { required int32 trainer_id = 1; } +/** + * Configuration structure for ParameterServer2. + */ message ParameterServerConfig { // The ports number for parameter send, // increment based on default port number