# Patch summary (commentary only; text before the first "diff --git" header
# is not part of any hunk).
#
# Files touched:
#   demo/quick_start/cluster/pserver.sh
#       --ports_num changes from 2 to 1.
#   paddle/pserver/PServerUtil.cpp
#       initConfig() and create(const ParameterServerConfig&) move below the
#       destructor; the zero-argument create() becomes createWithGflags();
#       a stray ";" statement in the constructor is removed.
#   paddle/pserver/PServerUtil.h
#       Doxygen comments are added to each member; the zero-argument create()
#       declaration is replaced by createWithGflags(); the initConfig()
#       declaration moves from the public section to below "private:".
#   paddle/pserver/ParameterServer2Main.cpp
#   paddle/trainer/TrainerMain.cpp
#       Call sites switch from PServerUtil::create() to
#       PServerUtil::createWithGflags().
#
# NOTE(review): createWithGflags() binds a reference to *initConfig(), and
#   initConfig() allocates the ParameterServerConfig with "new". No matching
#   delete appears anywhere in this patch, so the config object looks leaked.
#   Confirm whether PServerUtil retains a reference to the config before
#   changing ownership.
# NOTE(review): TrainerMain.cpp stores the createWithGflags() result in a raw
#   PServerUtil* and no delete is visible in the hunk; confirm this is an
#   intentional process-lifetime object.
# NOTE(review): this copy of the patch has lost its line breaks, and text
#   inside angle brackets appears to have been stripped (see
#   "std::vector devices", "std::unique_ptr pServerPtr(" and
#   "std::vector> pservers_"). Recover the original commit before applying.
#
diff --git a/demo/quick_start/cluster/pserver.sh b/demo/quick_start/cluster/pserver.sh index 4e1ffe5139e27b4f1209e6b22b42e17d0bbc1b0c..b187c1d9b9108a607ed310253d54ecc096f0e792 100755 --- a/demo/quick_start/cluster/pserver.sh +++ b/demo/quick_start/cluster/pserver.sh @@ -19,7 +19,7 @@ source "$bin_dir/env.sh" paddle pserver \ --nics=`get_nics` \ --port=7164 \ - --ports_num=2 \ + --ports_num=1 \ --ports_num_for_sparse=1 \ --num_gradient_servers=1 \ --comment="paddle_pserver" \ diff --git a/paddle/pserver/PServerUtil.cpp b/paddle/pserver/PServerUtil.cpp index e64569793613cf4c9ed38152d081bd450086dcdd..68a91743306263ffb323ca93ffe8273402945f7a 100644 --- a/paddle/pserver/PServerUtil.cpp +++ b/paddle/pserver/PServerUtil.cpp @@ -16,30 +16,11 @@ limitations under the License. */ namespace paddle { -ParameterServerConfig* PServerUtil::initConfig() { - ParameterServerConfig* config = new ParameterServerConfig(); - config->set_nics(FLAGS_nics); - config->set_port(FLAGS_port); - config->set_ports_num(FLAGS_ports_num); - config->set_rdma_tcp(FLAGS_rdma_tcp); - return config; -} - -PServerUtil* PServerUtil::create() { - auto& pServerConfig = *paddle::PServerUtil::initConfig(); - return PServerUtil::create(pServerConfig); -} - -PServerUtil* PServerUtil::create(const ParameterServerConfig& config) { - return new PServerUtil(config); -} - PServerUtil::PServerUtil(const ParameterServerConfig& config) { // round robin to load balance RDMA server ENGINE std::vector devices; int rdmaCpu = 0; int onlineCpus = rdma::numCpus(); - ; int numPorts = config.ports_num() + config.ports_num_for_sparse(); if (FLAGS_nics.empty()) { @@ -78,6 +59,24 @@ PServerUtil::PServerUtil(const ParameterServerConfig& config) { PServerUtil::~PServerUtil() { this->join(); } +ParameterServerConfig* PServerUtil::initConfig() { + ParameterServerConfig* config = new ParameterServerConfig(); + config->set_nics(FLAGS_nics); + config->set_port(FLAGS_port); + config->set_ports_num(FLAGS_ports_num); + 
config->set_rdma_tcp(FLAGS_rdma_tcp); + return config; +} + +PServerUtil* PServerUtil::createWithGflags() { + auto& pServerConfig = *paddle::PServerUtil::initConfig(); + return create(pServerConfig); +} + +PServerUtil* PServerUtil::create(const ParameterServerConfig& config) { + return new PServerUtil(config); +} + void PServerUtil::start() { LOG(INFO) << "pserver sizes : " << pservers_.size(); int i = 0; diff --git a/paddle/pserver/PServerUtil.h b/paddle/pserver/PServerUtil.h index dd8d32a4e9bc6a957ec4af0e173099cb4d1c3603..117dde37e3f8846505b1d75bec5033f3f4141c02 100644 --- a/paddle/pserver/PServerUtil.h +++ b/paddle/pserver/PServerUtil.h @@ -24,16 +24,47 @@ namespace paddle { class PServerUtil { public: DISABLE_COPY(PServerUtil); - static PServerUtil* create(); - static PServerUtil* create(const ParameterServerConfig& config); + + /** + * @brief Ctor, Create a PServerUtil from ParameterServerConfig. + */ explicit PServerUtil(const ParameterServerConfig& config); + + /** + * @brief Dtor. + */ ~PServerUtil(); - static ParameterServerConfig* initConfig(); + + /** + * @brief create PServerUtil from gflags, this is used for + * compatibility with the old usage of configuration by gflags. + */ + static PServerUtil* createWithGflags(); + + /** + * @brief create PServerUtil with ParameterServerConfig, remove gflags + * from ParameterServer. Init all pservers thread according to the config. + */ + static PServerUtil* create(const ParameterServerConfig& config); + + /** + * @brief start all pserver thread in this PServerUtil. + */ void start(); + + /** + * @brief join and wait for all pserver thread in this PServerUtil. + */ void join(); private: std::vector> pservers_; + + /** + * @brief create ParameterServerConfig from gflags, this is used for + * compatibility with the old usage of configuration by gflags. 
+ */ + static ParameterServerConfig* initConfig(); }; } // namespace paddle diff --git a/paddle/pserver/ParameterServer2Main.cpp b/paddle/pserver/ParameterServer2Main.cpp index afba7293eb8c99ff80378e853593806c37489c00..8c1baea0cef76849276c3dbbad6a2e0ef3a5689f 100644 --- a/paddle/pserver/ParameterServer2Main.cpp +++ b/paddle/pserver/ParameterServer2Main.cpp @@ -21,7 +21,8 @@ using namespace paddle; // NOLINT int main(int argc, char** argv) { initMain(argc, argv); - std::unique_ptr pServerPtr(paddle::PServerUtil::create()); + std::unique_ptr pServerPtr( + paddle::PServerUtil::createWithGflags()); pServerPtr->start(); pServerPtr->join(); diff --git a/paddle/trainer/TrainerMain.cpp b/paddle/trainer/TrainerMain.cpp index 0d3e4514d6b007506fcd7cd8f1532ee918ab2253..a690268c2c7ab819e4e18d8b321b292bf54f9b97 100644 --- a/paddle/trainer/TrainerMain.cpp +++ b/paddle/trainer/TrainerMain.cpp @@ -38,7 +38,7 @@ int main(int argc, char** argv) { initPython(argc, argv); if (FLAGS_start_pserver) { - PServerUtil* pServerUtil = paddle::PServerUtil::createWithGflags(); + PServerUtil* pServerUtil = paddle::PServerUtil::createWithGflags(); pServerUtil->start(); } Trainer trainer;