未验证 提交 1034ca31 编写于 作者: X xujiaqi01 提交者: GitHub

add timeout and http store in communication (#23436)

* add timeout and http store in communication, add revert and confirm in fleet
* test=develop
上级 1fc6cc50
......@@ -1166,6 +1166,28 @@ int32_t FleetWrapper::CopyTable(const uint64_t src_table_id,
#endif
}
// Placeholder for confirming the parameters updated in the current pass.
// The pslib call is still disabled, so in both build configurations this
// function only emits a log line.
void FleetWrapper::Confirm() {
#ifdef PADDLE_WITH_PSLIB
  // FIXME(xujiaqi01): will later support confirm
  // auto ret = pslib_ptr_->_worker_ptr->confirm();
  // ret.wait();
  const char* const notice = "disable FleetWrapper::Confirm temporarily";
#else
  const char* const notice = "FleetWrapper::Confirm does nothing when no pslib";
#endif
  VLOG(0) << notice;
}
// Placeholder for reverting the parameters updated in the current pass.
// The pslib call is still disabled, so in both build configurations this
// function only emits a log line.
void FleetWrapper::Revert() {
#ifdef PADDLE_WITH_PSLIB
  // FIXME(xujiaqi01): will later support revert
  // auto ret = pslib_ptr_->_worker_ptr->revert();
  // ret.wait();
  const char* const notice = "disable FleetWrapper::Revert temporarily";
#else
  const char* const notice = "FleetWrapper::Revert does nothing when no pslib";
#endif
  VLOG(0) << notice;
}
int32_t FleetWrapper::CopyTableByFeasign(
const uint64_t src_table_id, const uint64_t dest_table_id,
const std::vector<uint64_t>& feasign_list) {
......
......@@ -268,6 +268,10 @@ class FleetWrapper {
// send client to client message
std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg);
// confirm all the updated params in the current pass
void Confirm();
// revert all the updated params in the current pass
void Revert();
// FleetWrapper singleton
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
......
......@@ -10,16 +10,18 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include <exception>
#include <thread>  // NOLINT
#include <vector>
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/string/string_helper.h"
namespace gloo {
namespace rendezvous {
HdfsStore::HdfsStore(const std::string& path) {
path_ = path;
wait_sleep_ms_ = 3000;
wait_sleep_ms_ = 10000;
wait_timeout_ = std::chrono::seconds(999999999);
retry_times_ = 100;
}
......@@ -35,42 +37,72 @@ void HdfsStore::set(const std::string& key, const std::vector<char>& data) {
}
int err_no = 0;
for (int i = 1; i <= retry_times_; ++i) {
err_no = 0;
std::shared_ptr<FILE> fp =
paddle::framework::fs_open_write(tmp, &err_no, "");
if (err_no != 0) {
VLOG(0) << "fs_open_write failed, retry times " << i << " err no "
<< err_no;
fp.reset();
sleep(wait_sleep_ms_ / 1000);
continue;
}
size_t write_count = fwrite_unlocked(data.data(), 1, data.size(), fp.get());
if (write_count != data.size()) {
VLOG(0) << "fwrite_unlocked failed, retry times " << i << " write_count "
<< write_count << " data.size() " << data.size();
fp.reset();
sleep(2);
continue;
err_no = -1;
}
fp.reset();
if (err_no != 0) {
VLOG(0) << "fs_open_write failed, retry times " << i << " err no "
<< err_no;
sleep(wait_sleep_ms_ / 1000);
paddle::framework::fs_remove(tmp);
if (i == retry_times_) {
VLOG(0) << "fs_open_write failed, retry times reaches limit";
PADDLE_THROW(platform::errors::PreconditionNotMet(
"fs_open_write failed, retry times reaches"
" limit ",
retry_times_));
}
} else {
break;
}
}
paddle::framework::fs_mv(tmp, path);
#endif
}
#ifdef PADDLE_WITH_GLOO
// Calls `func` until it returns 0, at most `max_try_time` times.
//
// func:              operation to attempt; a return value of 0 means success.
// max_try_time:      maximum number of attempts (0 means `func` is never run).
// retry_interval_ms: pause between two consecutive attempts, in milliseconds.
//
// Returns 0 as soon as an attempt succeeds, -1 when every attempt failed.
int retry_do_func(std::function<int(void)> func, uint32_t max_try_time,
                  uint32_t retry_interval_ms) {
  for (uint32_t attempt = 0; attempt < max_try_time; ++attempt) {
    if (func() == 0) {
      return 0;
    }
    // Back off only when another attempt follows; sleeping after the final
    // failure would just delay the error report. sleep_for is used instead of
    // a platform-guarded usleep so the back-off is never compiled out and the
    // interval cannot overflow when converted to microseconds.
    if (attempt + 1 < max_try_time) {
      std::this_thread::sleep_for(
          std::chrono::milliseconds(retry_interval_ms));
    }
  }
  return -1;
}
#endif
std::vector<char> HdfsStore::get(const std::string& key) {
auto path = ObjectPath(key);
std::vector<char> result;
#ifdef PADDLE_WITH_GLOO
// block until key is set
wait({key});
bool is_exists = paddle::framework::fs_exists(path);
int ret = retry_do_func(
[&path]() { return paddle::framework::fs_exists(path) ? 0 : -1; }, 5,
wait_sleep_ms_);
bool is_exists = (ret == 0);
PADDLE_ENFORCE_EQ(is_exists, true,
paddle::platform::errors::NotFound(
"HdfsStore::get, path not exists: " + path));
int read_status = retry_do_func(
[&path, &result]() {
result.clear();
int err_no = 0;
std::shared_ptr<FILE> fp = paddle::framework::fs_open_read(path, &err_no, "");
{
std::shared_ptr<FILE> fp =
paddle::framework::fs_open_read(path, &err_no, "");
char buffer = '\0';
size_t read_count = 0;
while (fread(&buffer, 1, 1, fp.get()) == 1) {
......@@ -78,6 +110,13 @@ std::vector<char> HdfsStore::get(const std::string& key) {
result.push_back(buffer);
}
VLOG(3) << "HdfsStore::get read_count " << read_count;
}
return err_no;
},
5, wait_sleep_ms_);
PADDLE_ENFORCE_EQ(read_status, 0,
paddle::platform::errors::Fatal(
"HdfsStore::get, path read faied: " + path));
#endif
return result;
}
......@@ -92,22 +131,33 @@ void HdfsStore::wait(const std::vector<std::string>& keys,
const std::chrono::milliseconds&) { // NOLINT
#ifdef PADDLE_WITH_GLOO
auto start = std::chrono::steady_clock::now();
while (!Check(keys)) {
std::vector<bool> check_key_status(keys.size(), false);
while (!Check(keys, &check_key_status)) {
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - start);
if (wait_timeout_ != gloo::kNoTimeout && elapsed > wait_timeout_) {
PADDLE_ENFORCE_EQ(0, 1, paddle::platform::errors::ExecutionTimeout(
"HdfsStore::wait, Wait timeout for key(s): " +
::gloo::MakeString(keys)));
int32_t last_check_rank = -1;
for (size_t i = 0; i < check_key_status.size(); ++i) {
if (!check_key_status[i]) {
last_check_rank = i;
break;
}
}
PADDLE_THROW(platform::errors::ExecutionTimeout(
"TIMEOUT self_rank = %d pair_rank = %d", self_rank_,
last_check_rank));
}
std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_));
}
#endif
}
void HdfsStore::SetTimeoutSeconds(int timeout_seconds) {
wait_timeout_ = std::chrono::seconds(timeout_seconds);
}
std::string HdfsStore::EncodeName(const std::string& name) {
thread_local std::hash<std::string> hash_func;
return std::to_string(hash_func(name));
return ::paddle::string::erase_spaces(name);
}
std::string HdfsStore::TmpPath(const std::string& name) {
......@@ -118,50 +168,124 @@ std::string HdfsStore::ObjectPath(const std::string& name) {
return path_ + "/" + EncodeName(name);
}
bool HdfsStore::Check(const std::vector<std::string>& keys) {
bool HdfsStore::Check(const std::vector<std::string>& keys,
std::vector<bool>* keys_check_status) {
#ifdef PADDLE_WITH_GLOO
bool ret = true;
std::vector<std::string> paths;
for (const auto& key : keys) {
paths.push_back(ObjectPath(key));
}
for (const auto& path : paths) {
for (size_t i = 0; i < paths.size(); ++i) {
if ((*keys_check_status)[i]) {
continue;
}
const auto& path = paths[i];
bool is_exists = paddle::framework::fs_exists(path);
VLOG(3) << "HdfsStore::Check " << is_exists << " path " << path;
if (!is_exists) {
return false;
ret = false;
}
(*keys_check_status)[i] = is_exists;
}
return ret;
#else
VLOG(0) << "HdfsStore::Check does nothing when no gloo";
#endif
return true;
}
#ifdef PADDLE_WITH_GLOO
// Connects this rank to every other rank, using `store` to exchange addresses.
//
// The local addresses of all pairs are published under the key
// std::to_string(rank). The per-peer wait/get/connect work is then spread over
// thread_num_ worker threads (at least one), worker w handling peers
// w, w + N, w + 2N, ...
//
// Any exception raised inside a worker (for example a store timeout) is
// captured, all workers are joined, and the first captured exception is
// rethrown on the calling thread. Letting it escape the thread function would
// call std::terminate instead of reporting the error to the caller.
void ParallelConnectContext::connectFullMesh(
    Store& store, std::shared_ptr<transport::Device>& dev) {
  std::vector<char> allBytes;
  // Create pairs
  auto transportContext = dev->createContext(rank, size);
  transportContext->setTimeout(getTimeout());
  for (int i = 0; i < size; i++) {
    if (i == rank) {
      continue;
    }
    auto& pair = transportContext->createPair(i);
    auto addrBytes = pair->address().bytes();
    allBytes.insert(allBytes.end(), addrBytes.begin(), addrBytes.end());
  }
  // Use std::to_string for the key, exactly like the readers below, so the
  // written and the awaited key can never differ in formatting.
  store.set(std::to_string(rank), allBytes);

  // A non-positive thread_num_ would otherwise connect nothing at all.
  const size_t worker_count =
      thread_num_ > 0 ? static_cast<size_t>(thread_num_) : 1;
  // One slot per worker, so no synchronization is needed to record failures.
  std::vector<std::exception_ptr> worker_errors(worker_count);
  std::vector<std::thread> workers;
  workers.reserve(worker_count);
  // Connect every pair
  for (size_t worker_idx = 0; worker_idx < worker_count; ++worker_idx) {
    workers.emplace_back([this, &store, &transportContext, &worker_errors,
                          worker_idx, worker_count]() {
      try {
        for (int peer = static_cast<int>(worker_idx); peer < size;
             peer += static_cast<int>(worker_count)) {
          if (peer == rank) {
            continue;
          }
          // Wait for address of other side of this pair to become available
          std::string key = std::to_string(peer);
          store.wait({key}, getTimeout());
          // Connect to other side of this pair
          auto allAddrs = store.get(key);
          auto addr = extractAddress(allAddrs, peer);
          transportContext->getPair(peer)->connect(addr);
        }
      } catch (...) {
        worker_errors[worker_idx] = std::current_exception();
      }
    });
  }
  for (auto& worker : workers) {
    worker.join();
  }
  for (const auto& error : worker_errors) {
    if (error) {
      std::rethrow_exception(error);
    }
  }
  device_ = dev;
  transportContext_ = std::move(transportContext);
}
#endif
} // namespace rendezvous
} // namespace gloo
namespace paddle {
namespace framework {
void GlooWrapper::Init(int rank, int size, const std::string& path,
const std::string& fs_name, const std::string& fs_ugi,
const std::string& iface, const std::string& prefix) {
void GlooWrapper::Init() {
if (is_initialized_) {
return;
}
rank_ = rank;
size_ = size;
std::string cmd = std::string("${HADOOP_HOME}/bin/hadoop fs");
cmd += " -D fs.default.name=" + fs_name;
cmd += " -D hadoop.job.ugi=" + fs_ugi;
paddle::framework::hdfs_set_command(cmd);
#ifdef PADDLE_WITH_GLOO
gloo::transport::tcp::attr attr;
attr.iface = iface;
auto file_store = gloo::rendezvous::HdfsStore(path);
auto prefix_store = gloo::rendezvous::PrefixStore(prefix, file_store);
attr.iface = iface_;
std::shared_ptr<gloo::rendezvous::HdfsStore> file_store = nullptr;
std::shared_ptr<gloo::rendezvous::HTTPStore> http_store = nullptr;
auto context =
std::make_shared<gloo::rendezvous::ParallelConnectContext>(rank_, size_);
context->setTimeout(run_timeout_);
auto dev = gloo::transport::tcp::CreateDevice(attr);
auto context = std::make_shared<gloo::rendezvous::Context>(rank, size);
context->setTimeout(file_store.wait_timeout_);
context->connectFullMesh(prefix_store, dev);
switch (store_type_) {
case GlooStoreType::HDFS: {
std::string cmd = std::string("${HADOOP_HOME}/bin/hadoop fs");
cmd += " -D fs.default.name=" + hdfs_name_;
cmd += " -D hadoop.job.ugi=" + hdfs_ugi_;
paddle::framework::hdfs_set_command(cmd);
file_store = std::make_shared<gloo::rendezvous::HdfsStore>(hdfs_path_);
file_store->SetTimeoutSeconds(init_timeout_.count());
auto prefix_store =
std::make_shared<gloo::rendezvous::PrefixStore>(prefix_, *file_store);
context->connectFullMesh(*prefix_store, dev);
break;
}
case GlooStoreType::HTTP: {
http_store = std::make_shared<gloo::rendezvous::HTTPStore>(
http_ip_, http_port_, prefix_ + "_" + http_scope_, rank_);
http_store->SetTimeoutSeconds(init_timeout_.count());
context->connectFullMesh(*http_store, dev);
http_store->Finalize();
break;
}
default:
LOG(ERROR) << "unknown store type " << store_type_;
exit(-1);
}
context_ = std::move(context);
#endif
is_initialized_ = true;
......@@ -170,6 +294,9 @@ void GlooWrapper::Init(int rank, int size, const std::string& path,
template std::vector<int64_t> GlooWrapper::AllReduce<int64_t>(
std::vector<int64_t>& sendbuf, // NOLINT
const std::string& mode);
template std::vector<float> GlooWrapper::AllReduce<float>(
std::vector<float>& sendbuf, // NOLINT
const std::string& mode);
template std::vector<double> GlooWrapper::AllReduce<double>(
std::vector<double>& sendbuf, // NOLINT
const std::string& mode);
......@@ -180,6 +307,8 @@ template std::vector<int64_t> GlooWrapper::AllGather<int64_t>(
int64_t& input); // NOLINT
template std::vector<uint64_t> GlooWrapper::AllGather<uint64_t>(
uint64_t& input); // NOLINT
template std::vector<float> GlooWrapper::AllGather<float>(
float& input); // NOLINT
template std::vector<double> GlooWrapper::AllGather<double>(
double& input); // NOLINT
......
......@@ -31,6 +31,7 @@ limitations under the License. */
#include <gloo/barrier.h>
#include <gloo/rendezvous/context.h>
#include <gloo/rendezvous/file_store.h>
#include <gloo/rendezvous/http_store.h>
#include <gloo/rendezvous/prefix_store.h>
#include <gloo/rendezvous/store.h>
#include <gloo/transport/tcp/device.h>
......@@ -59,44 +60,87 @@ class HdfsStore {
virtual void wait(const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout);
virtual void SetTimeoutSeconds(int timeout_seconds);
std::string EncodeName(const std::string& name);
std::string TmpPath(const std::string& name);
std::string ObjectPath(const std::string& name);
bool Check(const std::vector<std::string>& keys);
bool Check(const std::vector<std::string>& keys,
std::vector<bool>* keys_check_status);
void SetRank(int rank) { self_rank_ = rank; }
std::string path_;
int wait_sleep_ms_;
std::chrono::seconds wait_timeout_;
int retry_times_;
int self_rank_;
};
#ifdef PADDLE_WITH_GLOO
// Rendezvous context whose connectFullMesh spreads the per-peer
// wait/get/connect work over thread_num_ threads.
class ParallelConnectContext : public gloo::rendezvous::Context {
public:
ParallelConnectContext(int rank, int size, int base = 2)
: gloo::rendezvous::Context(rank, size, base) {}
virtual ~ParallelConnectContext() {}
// gloo::rendezvous::Context waits for and fetches peer addresses one by one,
// which is slow when size is large, especially with HdfsStore.
void connectFullMesh(Store& store, // NOLINT
std::shared_ptr<transport::Device>& dev); // NOLINT
protected:
// Number of worker threads used by connectFullMesh.
int thread_num_ = 6;
};
#endif
} // namespace rendezvous
} // namespace gloo
namespace paddle {
namespace framework {
// Kind of rendezvous store a GlooWrapper uses (kept in its store_type_ member).
enum GlooStoreType { HDFS, HTTP };
class GlooWrapper {
public:
GlooWrapper() {}
virtual ~GlooWrapper() {}
void Init(int rank, int size, const std::string& path,
const std::string& fs_name, const std::string& fs_ugi,
const std::string& iface, const std::string& prefix);
void Init();
int Rank() {
CHECK_EQ(is_initialized_, true);
return rank_;
void SetTimeoutSeconds(int init_seconds, int run_seconds) {
init_timeout_ = std::chrono::seconds(init_seconds);
run_timeout_ = std::chrono::seconds(run_seconds);
}
int Size() {
CHECK_EQ(is_initialized_, true);
return size_;
int Rank() { return rank_; }
int Size() { return size_; }
void SetRank(int rank) { rank_ = rank; }
void SetSize(int size) { size_ = size; }
void SetIface(const std::string& iface) { iface_ = iface; }
void SetPrefix(const std::string& prefix) { prefix_ = prefix; }
void SetHdfsStore(const std::string& path, const std::string& fs_name,
const std::string& fs_ugi) {
store_type_ = GlooStoreType::HDFS;
hdfs_path_ = path;
hdfs_name_ = fs_name;
hdfs_ugi_ = fs_ugi;
}
void SetHttpStore(const std::string& ip, int port, const std::string& scope) {
store_type_ = GlooStoreType::HTTP;
http_ip_ = ip;
http_port_ = port;
http_scope_ = scope;
}
void Barrier() {
......@@ -104,6 +148,8 @@ class GlooWrapper {
#ifdef PADDLE_WITH_GLOO
gloo::BarrierOptions opts(context_);
gloo::barrier(opts);
#else
LOG(WARNING) << "Barrier does nothing when WITH_GLOO=OFF";
#endif
}
......@@ -134,6 +180,8 @@ class GlooWrapper {
"AllReduce mode not known: " + mode));
}
gloo::allreduce(opts);
#else
LOG(WARNING) << "AllReduce does nothing when WITH_GLOO=OFF";
#endif
return recvbuf;
}
......@@ -147,6 +195,8 @@ class GlooWrapper {
opts.setInput(&input, 1);
opts.setOutput(ret.data(), size_);
gloo::allgather(opts);
#else
LOG(WARNING) << "AllGather does nothing when WITH_GLOO=OFF";
#endif
return std::move(ret);
}
......@@ -158,6 +208,19 @@ class GlooWrapper {
#endif
int rank_ = 0;
int size_ = 0;
std::chrono::seconds init_timeout_ = std::chrono::seconds(9999999);
std::chrono::seconds run_timeout_ = std::chrono::seconds(9999999);
std::string iface_ = "lo";
std::string prefix_;
GlooStoreType store_type_ = GlooStoreType::HDFS;
// configs for hdfs store
std::string hdfs_path_;
std::string hdfs_name_;
std::string hdfs_ugi_;
std::string http_ip_;
// configs for http store
int http_port_;
std::string http_scope_;
};
} // namespace framework
......
......@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/string/string_helper.h"
#if defined _WIN32 || defined __APPLE__
#else
......@@ -37,14 +38,23 @@ TEST(TEST_GLOO, store_1) {
}
store.wait(std::vector<std::string>{"test"});
store.wait(std::vector<std::string>{"test"}, std::chrono::milliseconds(0));
store.SetTimeoutSeconds(100000);
store.EncodeName("1");
store.TmpPath("1");
store.ObjectPath("1");
store.Check(std::vector<std::string>{"test"});
std::vector<bool> status(1, false);
store.Check(std::vector<std::string>{"test"}, &status);
auto gw = paddle::framework::GlooWrapper();
gw.Init(0, 1, "", "", "", "", "");
gw.Init(0, 1, "", "", "", "", "");
gw.SetTimeoutSeconds(1000, 1000);
gw.SetRank(0);
gw.SetSize(1);
gw.SetPrefix("");
gw.SetIface("lo");
gw.SetHdfsStore("", "", "");
gw.Init();
gw.SetHttpStore("", 8099, "");
gw.Init();
gw.Rank();
gw.Size();
gw.Barrier();
......@@ -63,5 +73,8 @@ TEST(TEST_FLEET, fleet_1) {
fleet->RunServer("", 0);
fleet->SaveModelOneTable(0, "", 0);
fleet->SaveModelOneTablePrefix(0, "", 0, "");
fleet->Confirm();
fleet->Revert();
paddle::string::erase_spaces("1 2");
#endif
}
......@@ -78,6 +78,8 @@ void BindFleetWrapper(py::module* m) {
&framework::FleetWrapper::SetClient2ClientConfig)
.def("set_pull_local_thread_num",
&framework::FleetWrapper::SetPullLocalThreadNum)
.def("confirm", &framework::FleetWrapper::Confirm)
.def("revert", &framework::FleetWrapper::Revert)
.def("save_model_one_table", &framework::FleetWrapper::SaveModelOneTable)
.def("save_model_one_table_with_prefix",
&framework::FleetWrapper::SaveModelOneTablePrefix)
......
......@@ -37,11 +37,20 @@ void BindGlooWrapper(py::module* m) {
.def("rank", &framework::GlooWrapper::Rank)
.def("size", &framework::GlooWrapper::Size)
.def("barrier", &framework::GlooWrapper::Barrier)
.def("set_timeout_seconds", &framework::GlooWrapper::SetTimeoutSeconds)
.def("set_rank", &framework::GlooWrapper::SetRank)
.def("set_size", &framework::GlooWrapper::SetSize)
.def("set_iface", &framework::GlooWrapper::SetIface)
.def("set_prefix", &framework::GlooWrapper::SetPrefix)
.def("set_hdfs_store", &framework::GlooWrapper::SetHdfsStore)
.def("set_http_store", &framework::GlooWrapper::SetHttpStore)
.def("all_reduce", &framework::GlooWrapper::AllReduce<uint64_t>)
.def("all_reduce", &framework::GlooWrapper::AllReduce<int64_t>)
.def("all_reduce", &framework::GlooWrapper::AllReduce<float>)
.def("all_reduce", &framework::GlooWrapper::AllReduce<double>)
.def("all_gather", &framework::GlooWrapper::AllGather<uint64_t>)
.def("all_gather", &framework::GlooWrapper::AllGather<int64_t>)
.def("all_gather", &framework::GlooWrapper::AllGather<float>)
.def("all_gather", &framework::GlooWrapper::AllGather<double>);
} // end BindGlooWrapper
} // end namespace pybind
......
......@@ -61,6 +61,19 @@ std::string trim_spaces(const std::string& str) {
return std::string(p, len);
}
// Returns a copy of `str` with every character classified as whitespace by
// isspace() removed. The whole string is processed, including any bytes that
// follow an embedded '\0'.
std::string erase_spaces(const std::string& str) {
  std::string result;
  result.reserve(str.size());
  for (const char ch : str) {
    // isspace() requires a value representable as unsigned char (or EOF);
    // passing a negative plain char (e.g. a UTF-8 byte) is undefined behavior.
    if (!isspace(static_cast<unsigned char>(ch))) {
      result.push_back(ch);
    }
  }
  return result;
}
inline int str_to_float(const char* str, float* v) {
const char* head = str;
char* cursor = NULL;
......
......@@ -62,6 +62,9 @@ std::string format_string(const std::string& fmt, ARGS&&... args) {
// remove leading and trailing spaces
std::string trim_spaces(const std::string& str);
// erase all spaces in str
std::string erase_spaces(const std::string& str);
int str_to_float(const char* str, float* v);
// split string by delim
......
......@@ -14,6 +14,7 @@
"""Defination of Role Makers."""
from __future__ import print_function
from multiprocessing import Process, Manager
import paddle.fluid as fluid
import os
import time
......@@ -556,7 +557,21 @@ class GeneralRoleMaker(RoleMakerBase):
self._role_is_generated = False
self._hdfs_name = kwargs.get("hdfs_name", "")
self._hdfs_ugi = kwargs.get("hdfs_ugi", "")
self._hdfs_path = kwargs.get("path", "")
self._hdfs_path = kwargs.get("path", "").rstrip("/")
self._init_timeout_seconds = kwargs.get("init_timeout_seconds", 3600)
self._run_timeout_seconds = kwargs.get("run_timeout_seconds", 9999999)
ip_port = kwargs.get("http_ip_port", "")
self._http_ip_port = []
self._http_server = None
# if ip_port is not empty, it will use http instead of hdfs
if ip_port != "":
self._http_ip_port = ip_port.split(":")
# it's for communication between processes
self._manager = Manager()
# global dict to store status
self._http_server_d = self._manager.dict()
# set running status of http server
self._http_server_d["running"] = False
self._iface = self.__get_default_iface()
# this environment variable can be empty
self._prefix = os.getenv("SYS_JOB_ID", "")
......@@ -572,17 +587,41 @@ class GeneralRoleMaker(RoleMakerBase):
trainers_num = len(worker_endpoints)
if training_role not in ["TRAINER", "PSERVER"]:
raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER")
if training_role == "TRAINER":
role = Role.WORKER
current_id = int(os.environ["PADDLE_TRAINER_ID"])
if current_id == 0 and len(self._http_ip_port) != 0:
size_d = {
"trainer": len(worker_endpoints),
"pserver": len(eplist),
"all": len(worker_endpoints) + len(eplist)
}
# child process for http server
self._http_server = Process(
target=self.__start_kv_server,
args=(self._http_server_d, size_d))
self._http_server.daemon = True
# set running status to True
self._http_server_d["running"] = True
# start child process
self._http_server.start()
self._node_type = 1
self._cur_endpoint = worker_endpoints[current_id]
gloo = fluid.core.Gloo()
gloo.init(current_id,
len(worker_endpoints),
self._hdfs_path.rstrip("/") + "/trainer",
self._hdfs_name, self._hdfs_ugi, self._iface,
self._prefix)
gloo.set_rank(current_id)
gloo.set_size(len(worker_endpoints))
gloo.set_prefix(self._prefix)
gloo.set_iface(self._iface)
gloo.set_timeout_seconds(self._init_timeout_seconds,
self._run_timeout_seconds)
if len(self._http_ip_port) != 0:
gloo.set_http_store(self._http_ip_port[0],
int(self._http_ip_port[1]), "trainer")
else:
gloo.set_hdfs_store(self._hdfs_path + "/trainer",
self._hdfs_name, self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
elif training_role == "PSERVER":
role = Role.SERVER
......@@ -598,20 +637,36 @@ class GeneralRoleMaker(RoleMakerBase):
self._node_type = 0
self._cur_endpoint = cur_endpoint
gloo = fluid.core.Gloo()
gloo.init(current_id,
len(eplist),
self._hdfs_path.rstrip("/") + "/pserver",
self._hdfs_name, self._hdfs_ugi, self._iface,
self._prefix)
gloo.set_rank(current_id)
gloo.set_size(len(eplist))
gloo.set_prefix(self._prefix)
gloo.set_iface(self._iface)
gloo.set_timeout_seconds(self._init_timeout_seconds,
self._run_timeout_seconds)
if len(self._http_ip_port) != 0:
gloo.set_http_store(self._http_ip_port[0],
int(self._http_ip_port[1]), "pserver")
else:
gloo.set_hdfs_store(self._hdfs_path + "/pserver",
self._hdfs_name, self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
gloo = fluid.core.Gloo()
all_list = worker_endpoints + eplist
gloo.init(
all_list.index(self._cur_endpoint),
len(all_list),
self._hdfs_path.rstrip("/") + "/all", self._hdfs_name,
self._hdfs_ugi, self._iface, self._prefix)
gloo.set_rank(all_list.index(self._cur_endpoint))
gloo.set_size(len(all_list))
gloo.set_prefix(self._prefix)
gloo.set_iface(self._iface)
gloo.set_timeout_seconds(self._init_timeout_seconds,
self._run_timeout_seconds)
if len(self._http_ip_port) != 0:
gloo.set_http_store(self._http_ip_port[0],
int(self._http_ip_port[1]), "all")
else:
gloo.set_hdfs_store(self._hdfs_path + "/all", self._hdfs_name,
self._hdfs_ugi)
gloo.init()
self._all_comm = gloo
self._trainers_num = trainers_num
self._server_endpoints = eplist
......@@ -620,6 +675,11 @@ class GeneralRoleMaker(RoleMakerBase):
self._rank = all_list.index(self._cur_endpoint)
self._size = len(all_list)
self._worker_endpoints = worker_endpoints
if self._http_server is not None:
# set running status to False
self._http_server_d["running"] = False
# wait until child process exits
self._http_server.join()
self._role_is_generated = True
def all_gather(self, input):
......@@ -872,6 +932,16 @@ class GeneralRoleMaker(RoleMakerBase):
return intf_name
return "lo"
def __start_kv_server(self, http_server_d, size_d):
    """
    Run a KVServer on the port stored in self._http_ip_port[1].

    The server keeps running while http_server_d["running"] is truthy and
    the server does not report that it should stop; both conditions are
    re-checked every 5 seconds. The server is stopped before returning.

    Args:
        http_server_d(dict): shared status dict, its "running" entry is read
        size_d(dict): passed to KVServer as its size argument
    """
    from paddle.fluid.incubate.fleet.utils.http_server import KVServer
    port = int(self._http_ip_port[1])
    kv_server = KVServer(port, size_d)
    kv_server.start()
    poll_seconds = 5
    while True:
        if not http_server_d.get("running", False):
            break
        if kv_server.shoud_stop():
            break
        time.sleep(poll_seconds)
    kv_server.stop()
class UserDefinedRoleMaker(RoleMakerBase):
"""
......
......@@ -567,6 +567,24 @@ class PSLib(Fleet):
model_proto_file, table_var_names, load_combine)
self._role_maker._barrier_worker()
def confirm(self):
    """
    confirm all the updated params in current pass

    Workers meet at a barrier, only the first worker calls confirm on the
    fleet pointer, then workers meet at a second barrier.
    """
    role_maker = self._role_maker
    role_maker._barrier_worker()
    is_leader = role_maker.is_first_worker()
    if is_leader:
        self._fleet_ptr.confirm()
    role_maker._barrier_worker()
def revert(self):
    """
    revert all the updated params in current pass

    Workers meet at a barrier, only the first worker calls revert on the
    fleet pointer, then workers meet at a second barrier.
    """
    role_maker = self._role_maker
    role_maker._barrier_worker()
    is_leader = role_maker.is_first_worker()
    if is_leader:
        self._fleet_ptr.revert()
    role_maker._barrier_worker()
def load_model(self, model_dir=None, **kwargs):
"""
load pslib model, there are at least 4 modes, these modes are the same
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Http Server."""
import logging
import BaseHTTPServer
import SimpleHTTPServer
import time
import threading
import socket
def get_logger(name, level, fmt):
    """
    Return the logger called `name`, set to `level`, with one more file
    handler attached that writes to 'http.log' (opened in 'w' mode) using
    the record format `fmt`.
    """
    log = logging.getLogger(name)
    log.setLevel(level)
    file_handler = logging.FileHandler('http.log', mode='w')
    file_handler.setFormatter(logging.Formatter(fmt=fmt))
    log.addHandler(file_handler)
    return log
# Module-level logger used by KVHandler to record each handled request;
# get_logger attaches a file handler writing to 'http.log'.
_http_server_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
    """
    kv handler class for kv http server,
    it defines the way to get/set kv in server.

    Request paths must have the form ``/<scope>/<key>``. The handler reads
    and writes the ``kv`` / ``delete_kv`` dicts of ``self.server`` under the
    server's ``kv_lock`` / ``delete_kv_lock``.
    """

    def _parse_scope_and_key(self):
        """
        Split the request path into its scope and key.

        Returns:
            tuple|None: (scope, key), or None after a 400 response has been
                sent because the path is not exactly ``/<scope>/<key>``.
        """
        paths = self.path.split('/')
        # Exactly three parts are required: longer paths used to slip
        # through a "< 3" check and crash on tuple unpacking.
        if len(paths) != 3:
            print('len of request path must be 3: ' + self.path)
            self.send_status_code(400)
            return None
        return paths[1], paths[2]

    def do_GET(self):
        """
        get method for kv handler, get value according to key.

        Responds 200 with the stored value, 404 when the key is unknown,
        400 when the path is malformed.
        """
        log_str = "GET " + self.address_string() + self.path
        parsed = self._parse_scope_and_key()
        if parsed is None:
            return
        scope, key = parsed
        # The lock only guards the lookup, not the socket write.
        with self.server.kv_lock:
            value = self.server.kv.get(scope, {}).get(key)
        if value is None:
            log_str += ' , key not found: ' + key
            self.send_status_code(404)
        else:
            log_str += ' , key found: ' + key
            self.send_response(200)
            self.send_header("Content-Length", str(len(value)))
            self.end_headers()
            self.wfile.write(value)
        _http_server_logger.info(log_str)

    def do_PUT(self):
        """
        put method for kv handler, set value according to key.

        The request body (Content-Length bytes) becomes the value. Responds
        200 on success, 404 when the body cannot be read (including a
        missing or invalid Content-Length), 400 when the path is malformed.
        """
        log_str = "PUT " + self.address_string() + self.path
        parsed = self._parse_scope_and_key()
        if parsed is None:
            return
        scope, key = parsed
        try:
            # Parsed inside the try so that a missing or non-numeric header
            # is answered like any other unreadable request.
            content_length = int(self.headers['Content-Length'])
            if content_length < 0:
                # read(-1) would block until the peer closes the connection.
                raise ValueError("negative Content-Length")
            value = self.rfile.read(content_length)
        except Exception:
            print("receive error invalid request")
            self.send_status_code(404)
            return
        with self.server.kv_lock:
            if self.server.kv.get(scope) is None:
                self.server.kv[scope] = {}
            self.server.kv[scope][key] = value
        self.send_status_code(200)
        _http_server_logger.info(log_str)

    def do_DELETE(self):
        """
        delete method for kv handler, record key as deleted in its scope.

        The key is appended to the server's delete_kv[scope] list; the
        stored value in kv is left untouched. Responds 200, or 400 when the
        path is malformed.
        """
        log_str = "DELETE " + self.address_string() + self.path
        parsed = self._parse_scope_and_key()
        if parsed is None:
            return
        scope, key = parsed
        with self.server.delete_kv_lock:
            if self.server.delete_kv.get(scope) is None:
                self.server.delete_kv[scope] = []
            self.server.delete_kv[scope].append(key)
        self.send_status_code(200)
        _http_server_logger.info(log_str)

    def log_message(self, format, *args):
        """
        ignore all logging messages in kv handler.
        """
        pass

    def send_status_code(self, code):
        """
        send status code back to client, with an empty body.
        """
        self.send_response(code)
        self.send_header("Content-Length", "0")
        self.end_headers()
class KVHTTPServer(BaseHTTPServer.HTTPServer, object):
    """
    it is a http server storing kv pairs.

    Attributes (each dict is guarded by the lock next to it):
        kv: scope -> {key: value}
        delete_kv: scope -> list of keys for which a delete was recorded
    """

    def __init__(self, port, handler):
        """Init."""
        super(KVHTTPServer, self).__init__(('', port), handler)
        self.delete_kv_lock = threading.Lock()
        self.delete_kv = {}
        self.kv_lock = threading.Lock()
        self.kv = {}

    def get_deleted_size(self, key):
        """
        get deleted size in key.

        Args:
            key(str): the scope to inspect

        Returns:
            int: number of deletes recorded for the scope, 0 if none. The
                count (not the underlying list) is returned so callers can
                compare it with an expected integer size.
        """
        with self.delete_kv_lock:
            return len(self.delete_kv.get(key, []))
class KVServer:
    """
    it is a server storing kv pairs, has a http server inside.
    """

    def __init__(self, port, size=None):
        """
        Init.

        Args:
            port(int): port the inner http server listens on
            size(dict|None): scope -> number of deletes expected before the
                server may stop; None or empty means nothing is expected
        """
        self.http_server = KVHTTPServer(port, KVHandler)
        self.listen_thread = None
        # Keep the expected sizes that were passed in (they used to be
        # dropped, which made shoud_stop() true from the very beginning).
        # A copy is stored so the default is never a shared mutable object.
        self.size = dict(size) if size else {}

    def start(self):
        """
        start server until user calls stop to let it quit.
        """
        self.listen_thread = threading.Thread(
            target=self.http_server.serve_forever)
        self.listen_thread.start()

    def stop(self):
        """
        stop server and clear its resources.
        """
        # shutdown() waits for serve_forever() to exit, so it must only be
        # called when the listen thread was actually started.
        if self.listen_thread is not None:
            self.http_server.shutdown()
            self.listen_thread.join()
            self.listen_thread = None
        self.http_server.server_close()

    def shoud_stop(self):
        """
        return whether the server should stop.

        Returns:
            ret(bool): True when, for every scope in self.size, the deleted
                size reported by the http server equals the expected size
        """
        for key, expected in self.size.items():
            if self.http_server.get_deleted_size(key) != expected:
                return False
        return True
......@@ -96,6 +96,8 @@ class TestFleet1(unittest.TestCase):
fleet.save_one_table(0, "./model_002", prefix="hahaha")
fleet.load_model("./model_0003")
fleet.load_one_table(0, "./model_004")
fleet.confirm()
fleet.revert()
except:
print("do not support pslib test, skip")
return
......
......@@ -97,7 +97,7 @@ class TestCloudRoleMaker2(unittest.TestCase):
role4._worker_gather(1)
role4._get_rank()
role4._get_size()
role4._all_comm.init(0, 0, "", "", "", "", "")
role4._all_comm.init()
role5 = GeneralRoleMaker(path="./test_gloo_5")
role5.get_local_endpoint()
role5.get_local_endpoint()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cloud role maker."""
from __future__ import print_function
import os
import unittest
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
class TestCloudRoleMaker(unittest.TestCase):
"""
Test cases for PaddleCloudRoleMaker.
"""
def setUp(self):
"""Set up, set envs."""
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ[
"PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"
def test_pslib_1(self):
"""Test cases for pslib."""
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
try:
import netifaces
except:
print("warning: no netifaces, skip test_pslib_1")
return
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36002"
os.environ["PADDLE_TRAINER_ID"] = "0"
role_maker = GeneralRoleMaker(
init_timeout_seconds=100,
run_timeout_seconds=100,
http_ip_port="127.0.0.1:36003")
role_maker.generate_role()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
fleet.init(role_maker)
train_program = fluid.Program()
startup_program = fluid.Program()
scope = fluid.Scope()
with fluid.program_guard(train_program, startup_program):
show = fluid.layers.data(name="show", shape=[-1, 1], \
dtype="float32", lod_level=1, append_batch_size=False)
fc = fluid.layers.fc(input=show, size=1, act=None)
label = fluid.layers.data(name="click", shape=[-1, 1], \
dtype="int64", lod_level=1, append_batch_size=False)
label_cast = fluid.layers.cast(label, dtype='float32')
cost = fluid.layers.log_loss(fc, label_cast)
try:
adam = fluid.optimizer.Adam(learning_rate=0.000005)
adam = fleet.distributed_optimizer(adam)
adam.minimize([cost], [scope])
fleet.run_server()
http_server_d = {}
http_server_d["running"] = False
size_d = {}
role_maker._GeneralRoleMaker__start_kv_server(http_server_d, size_d)
except:
print("do not support pslib test, skip")
return
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cloud role maker."""
from __future__ import print_function
import os
import unittest
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
class TestCloudRoleMaker(unittest.TestCase):
"""
Test cases for PaddleCloudRoleMaker.
"""
def setUp(self):
"""Set up, set envs."""
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ[
"PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"
def test_pslib_1(self):
"""Test cases for pslib."""
import sys
import threading
import paddle.fluid as fluid
try:
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib
from paddle.fluid.incubate.fleet.base.role_maker import \
GeneralRoleMaker
from paddle.fluid.incubate.fleet.utils.http_server import KVHandler
from paddle.fluid.incubate.fleet.utils.http_server import KVServer
from paddle.fluid.incubate.fleet.utils.http_server import \
KVHTTPServer
except:
print("warning: no fleet, skip test_pslib_4")
return
try:
import netifaces
except:
print("warning: no netifaces, skip test_pslib_4")
return
class FakeStream():
"""
it is a fake stream only for test.
"""
def write(self, a):
"""
write a to stream, do nothing
Args:
a(str): the string to write
"""
pass
def read(self, b):
"""
read data of len b from stream, do nothing
Args:
b(str): the len to read
Returns:
c(str): the result
"""
if b == 0:
raise ValueError("this is only for test")
return "fake"
import os
try:
class TmpKVHander(KVHandler):
"""
it is a fake handler only for this test case.
"""
def __init__(self, server):
"""Init."""
self.path = "a/b/c"
self.server = server
self.wfile = FakeStream()
self.rfile = FakeStream()
self.headers = {}
self.headers['Content-Length'] = 0
def address_string(self):
"""
fake address string, it will do nothing.
"""
return "123"
def send_response(self, code):
"""
fake send response, it will do nothing.
Args:
code(int): error code
"""
pass
def send_header(self, a, b):
"""
fake send header, it will do nothing.
Args:
a(str): some header
b(str): some header
"""
pass
def end_headers(self):
"""
fake end header, it will do nothing.
"""
pass
except:
print("warning: no KVHandler, skip test_pslib_4")
return
import sys
try:
class TmpServer(KVHTTPServer):
"""
it is a fake server only for this test case.
"""
def __init__(self):
"""Init."""
self.delete_kv_lock = threading.Lock()
self.delete_kv = {}
self.kv_lock = threading.Lock()
self.kv = {}
except:
print("warning: no KVHTTPServer, skip test_pslib_4")
return
try:
class TmpS(KVServer):
"""
it is a fake server only for this test case.
"""
def __init__(self):
"""Init."""
self.http_server = TmpServer()
self.listen_thread = None
self.size = {}
self.size["a"] = 999
except:
print("warning: no KVServer, skip test_pslib_4")
return
s = TmpServer()
h = TmpKVHander(s)
h.do_GET()
h.path = "a/b"
h.do_GET()
h.do_PUT()
h.do_DELETE()
h.path = "a/b/c"
s.kv["b"] = {}
s.kv["b"]["c"] = "456"
h.do_GET()
h.path = "a/d/e"
h.do_PUT()
h.headers['Content-Length'] = 1
h.do_PUT()
h.do_DELETE()
h.log_message("666")
s.get_deleted_size("haha")
s1 = TmpS()
s1.shoud_stop()
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册