提交 f8a529cf 编写于 作者: J jacquesqiao 提交者: GitHub

Merge pull request #1051 from jacquesqiao/add-pserver-util

Add ParameterServerController for parameter server python api
...@@ -25,6 +25,7 @@ log_file="$bin_dir/train.log" ...@@ -25,6 +25,7 @@ log_file="$bin_dir/train.log"
pushd "$home_dir" pushd "$home_dir"
cfg=trainer_config.lr.py cfg=trainer_config.lr.py
paddle train \ paddle train \
--start_pserver=false \
--config=$cfg \ --config=$cfg \
--save_dir=${model_dir} \ --save_dir=${model_dir} \
--trainer_count=4 \ --trainer_count=4 \
......
...@@ -24,13 +24,15 @@ set(PSERVER_SOURCES ...@@ -24,13 +24,15 @@ set(PSERVER_SOURCES
BaseClient.cpp BaseClient.cpp
ParameterClient2.cpp ParameterClient2.cpp
ParameterServer2.cpp ParameterServer2.cpp
SparseParameterDistribution.cpp) SparseParameterDistribution.cpp
ParameterServerController.cpp)
set(PSERVER_HEADERS set(PSERVER_HEADERS
BaseClient.h BaseClient.h
ParameterClient2.h ParameterClient2.h
ParameterServer2.h ParameterServer2.h
SparseParameterDistribution.h) SparseParameterDistribution.h
ParameterServerController.h)
add_library(paddle_pserver STATIC add_library(paddle_pserver STATIC
${PSERVER_SOURCES}) ${PSERVER_SOURCES})
......
...@@ -13,66 +13,17 @@ See the License for the specific language governing permissions and ...@@ -13,66 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <fstream> #include <fstream>
#include "paddle/utils/StringUtil.h" #include "ParameterServerController.h"
#include "paddle/utils/Util.h"
#include "ParameterServer2.h"
#include "RDMANetwork.h"
#include "paddle/utils/Flags.h"
using namespace paddle; // NOLINT using namespace paddle; // NOLINT
int main(int argc, char** argv) { int main(int argc, char** argv) {
initMain(argc, argv); initMain(argc, argv);
std::vector<std::string> devices; std::unique_ptr<ParameterServerController> parameterServerPtr(
std::vector<std::shared_ptr<ParameterServer2>> pservers; paddle::ParameterServerController::createFromGflags());
parameterServerPtr->start();
// round robin to loadbalance RDMA server ENGINE parameterServerPtr->wait();
int rdmaCpu = 0;
int onlineCpus = rdma::numCpus();
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
if (FLAGS_nics.empty()) {
pservers.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i].reset(new ParameterServer2(std::string(), FLAGS_port + i));
}
CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << FLAGS_port + i;
pservers[i]->start();
}
} else {
str::split(FLAGS_nics, ',', &devices);
pservers.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
}
CHECK(pservers[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j]
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << devices[j] << ":"
<< FLAGS_port + i;
pservers[i * devices.size() + j]->start();
}
}
}
for (auto& pserver : pservers) {
pserver->join();
}
return 0; return 0;
} }
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ParameterServerController.h"
namespace paddle {
ParameterServerController::ParameterServerController(
const ParameterServerConfig& config) {
// round robin to load balance RDMA server ENGINE
std::vector<std::string> devices;
int rdmaCpu = 0;
int onlineCpus = rdma::numCpus();
int numPorts = config.ports_num() + config.ports_num_for_sparse();
if (config.nics().empty()) {
parameterServers_.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (config.rdma_tcp() == "rdma") {
parameterServers_[i].reset(
new ParameterServer2(std::string(), config.port() + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
parameterServers_[i].reset(
new ParameterServer2(std::string(), config.port() + i));
}
CHECK(parameterServers_[i]->init()) << "Fail to initialize parameter "
"server on port "
<< config.port() + i;
}
} else {
str::split(config.nics(), ',', &devices);
parameterServers_.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (config.rdma_tcp() == "rdma") {
parameterServers_[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), config.port() + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
parameterServers_[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), config.port() + i));
}
CHECK(parameterServers_[i * devices.size() + j]->init())
<< "Fail to initialize parameter server with device " << devices[j]
<< config.port() + i;
}
}
}
}
ParameterServerController::~ParameterServerController() { this->wait(); }
ParameterServerController* ParameterServerController::createFromGflags() {
ParameterServerConfig config;
config.set_nics(FLAGS_nics);
config.set_rdma_tcp(FLAGS_rdma_tcp);
config.set_port(FLAGS_port);
config.set_ports_num(FLAGS_ports_num);
config.set_ports_num_for_sparse(FLAGS_ports_num_for_sparse);
return create(config);
}
ParameterServerController* ParameterServerController::create(
const ParameterServerConfig& config) {
return new ParameterServerController(config);
}
void ParameterServerController::start() {
LOG(INFO) << "number of parameterServer instances: "
<< parameterServers_.size();
int i = 0;
for (const auto& parameterServer : parameterServers_) {
LOG(INFO) << "Starting parameterServer[" << i << "]";
parameterServer->start();
i++;
}
}
void ParameterServerController::wait() {
int i = 0;
for (const auto& parameterServer : parameterServers_) {
LOG(INFO) << "Waiting parameterServer[" << i << "]";
parameterServer->join();
i++;
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "ParameterServer2.h"
#include "ParameterServerConfig.pb.h"
#include "RDMANetwork.h"
#include "paddle/utils/StringUtil.h"
namespace paddle {
/**
* @brief ParameterServerController is used for create, init and manage multi
* parameter server instances. The num of the instances is decided by port
* num(the ports number for parameter send) and network devices configured
* by gflags or proto.
*/
class ParameterServerController final {
public:
DISABLE_COPY(ParameterServerController);
/**
* @brief Ctor, Create a ParameterServerController from ParameterServerConfig.
*/
explicit ParameterServerController(const ParameterServerConfig& config);
/**
* @brief Dtor.
*/
~ParameterServerController();
/**
* @brief create ParameterServerController from gflags, this is used for
* compatibility with the old usage of configuration by gflags.
*/
static ParameterServerController* createFromGflags();
/**
* @brief create ParameterServerController with ParameterServerConfig, remove
* gflags from ParameterServer. Init all ParameterServer2 instances according
* to
* the config.
*/
static ParameterServerController* create(const ParameterServerConfig& config);
/**
* @brief start all ParameterServer2 instances in this
* ParameterServerController.
*/
void start();
/**
* @brief join and wait for all ParameterServer2 instances thread in this
* ParameterServerController.
*/
void wait();
private:
std::vector<std::unique_ptr<ParameterServer2>> parameterServers_;
};
} // namespace paddle
...@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/pserver/ParameterServer2.h" #include <fenv.h>
#include "paddle/utils/Common.h" #include "paddle/pserver/ParameterServerController.h"
#include "paddle/utils/PythonUtil.h" #include "paddle/utils/PythonUtil.h"
#include "paddle/utils/StringUtil.h"
#include "ParamUtil.h" #include "ParamUtil.h"
#include "Trainer.h" #include "Trainer.h"
#include "paddle/pserver/RDMANetwork.h"
DEFINE_bool(start_pserver, false, "Whether to start pserver"); DEFINE_bool(start_pserver, false, "Whether to start pserver");
DECLARE_int32(gpu_id); DECLARE_int32(gpu_id);
...@@ -38,54 +36,11 @@ int main(int argc, char** argv) { ...@@ -38,54 +36,11 @@ int main(int argc, char** argv) {
initMain(argc, argv); initMain(argc, argv);
initPython(argc, argv); initPython(argc, argv);
std::vector<std::unique_ptr<ParameterServer2>> pservers; std::unique_ptr<ParameterServerController> parameterServerPtr(nullptr);
std::vector<std::string> devices;
if (FLAGS_start_pserver) { if (FLAGS_start_pserver) {
// round robin to loadbalance RDMA server ENGINE parameterServerPtr.reset(
int rdmaCpu = 0; paddle::ParameterServerController::createFromGflags());
int onlineCpus = rdma::numCpus(); parameterServerPtr->start();
int numPorts = FLAGS_ports_num + FLAGS_ports_num_for_sparse;
if (FLAGS_nics.empty()) {
pservers.resize(numPorts);
for (int i = 0; i < numPorts; ++i) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i].reset(
new ParameterServer2(std::string(), FLAGS_port + i));
}
CHECK(pservers[i]->init()) << "Fail to initialize parameter server"
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << FLAGS_port + i;
pservers[i]->start();
}
} else {
str::split(FLAGS_nics, ',', &devices);
pservers.resize(devices.size() * numPorts);
for (int i = 0; i < numPorts; ++i) {
for (size_t j = 0; j < devices.size(); ++j) {
if (FLAGS_rdma_tcp == "rdma") {
pservers[i * devices.size() + j].reset(new ParameterServer2(
getIpAddr(devices[j]), FLAGS_port + i, rdmaCpu++));
rdmaCpu = rdmaCpu % onlineCpus;
} else {
pservers[i * devices.size() + j].reset(
new ParameterServer2(getIpAddr(devices[j]), FLAGS_port + i));
}
CHECK(pservers[i * devices.size() + j]->init())
<< "Fail to initialize parameter server" << devices[j]
<< FLAGS_port + i;
LOG(INFO) << "pserver started : " << devices[j] << ":"
<< FLAGS_port + i;
pservers[i * devices.size() + j]->start();
}
}
}
} }
Trainer trainer; Trainer trainer;
auto config = TrainerConfigHelper::createFromFlags(); auto config = TrainerConfigHelper::createFromFlags();
......
...@@ -4,7 +4,8 @@ set(proto_filenames ...@@ -4,7 +4,8 @@ set(proto_filenames
ModelConfig.proto ModelConfig.proto
ParameterConfig.proto ParameterConfig.proto
ParameterService.proto ParameterService.proto
TrainerConfig.proto) TrainerConfig.proto
ParameterServerConfig.proto)
set(PROTO_GEN) set(PROTO_GEN)
set(PROTO_GEN_PY) set(PROTO_GEN_PY)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle;
/**
* Configuration structure for ParameterClient2.
*/
message ParameterClientConfig {
required int32 trainer_id = 1;
}
/**
* Configuration structure for ParameterServer2.
*/
message ParameterServerConfig {
// The ports number for parameter send,
// increment based on default port number
required int32 ports_num = 1 [default = 1];
// The ports number for parameter send,
// increment based on default (port + ports_num
required int32 ports_num_for_sparse = 2 [default = 0];
// network device name for pservers
required string nics = 3 [default = "xgbe0,xgbe1"];
required string rdma_tcp = 4 [default = "tcp"];
// Listening port for pserver
required int32 port = 5 [default = 20134];
// number of gradient servers
required int32 num_gradient_servers = 6 [default = 1];
// number of threads for sync op exec
required int32 pserver_num_threads = 7 [default = 1];
// control config_.async_lagged_grad_discard_ratio() min value
required double async_lagged_ratio_min = 8 [default = 1.0];
// if async_lagged_grad_discard_ratio is not set in trainer_config.conf
// use it as defalut value
required double async_lagged_ratio_default = 9 [default = 1.5];
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册