提交 a99c8d0c 编写于 作者: X xjqbest

fix client to client communication bug

test=develop
上级 b35d27fa
...@@ -125,6 +125,7 @@ void PrivateQueueDataFeed<T>::ReadThread() { ...@@ -125,6 +125,7 @@ void PrivateQueueDataFeed<T>::ReadThread() {
template <typename T> template <typename T>
int PrivateQueueDataFeed<T>::Next() { int PrivateQueueDataFeed<T>::Next() {
#ifdef _LINUX
CheckStart(); CheckStart();
int index = 0; int index = 0;
T instance; T instance;
...@@ -140,6 +141,9 @@ int PrivateQueueDataFeed<T>::Next() { ...@@ -140,6 +141,9 @@ int PrivateQueueDataFeed<T>::Next() {
PutToFeedVec(ins_vec); PutToFeedVec(ins_vec);
} }
return batch_size_; return batch_size_;
#else
return 0;
#endif
} }
// explicit instantiation // explicit instantiation
...@@ -159,16 +163,19 @@ InMemoryDataFeed<T>::InMemoryDataFeed() { ...@@ -159,16 +163,19 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
template <typename T> template <typename T>
bool InMemoryDataFeed<T>::Start() { bool InMemoryDataFeed<T>::Start() {
#ifdef _LINUX
DataFeed::CheckSetFileList(); DataFeed::CheckSetFileList();
if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) { if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) {
FillMemoryDataToChannel(); FillMemoryDataToChannel();
} }
#endif
DataFeed::finish_start_ = true; DataFeed::finish_start_ = true;
return true; return true;
} }
template <typename T> template <typename T>
int InMemoryDataFeed<T>::Next() { int InMemoryDataFeed<T>::Next() {
#ifdef _LINUX
DataFeed::CheckStart(); DataFeed::CheckStart();
std::shared_ptr<paddle::framework::BlockingQueue<T>> in_channel = nullptr; std::shared_ptr<paddle::framework::BlockingQueue<T>> in_channel = nullptr;
std::shared_ptr<paddle::framework::BlockingQueue<T>> out_channel = nullptr; std::shared_ptr<paddle::framework::BlockingQueue<T>> out_channel = nullptr;
...@@ -205,6 +212,9 @@ int InMemoryDataFeed<T>::Next() { ...@@ -205,6 +212,9 @@ int InMemoryDataFeed<T>::Next() {
cur_channel_ = 1 - cur_channel_; cur_channel_ = 1 - cur_channel_;
} }
return DataFeed::batch_size_; return DataFeed::batch_size_;
#else
return 0;
#endif
} }
template <typename T> template <typename T>
...@@ -234,16 +244,19 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) { ...@@ -234,16 +244,19 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
template <typename T> template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) { void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
#ifdef _LINUX
std::vector<T> ins; std::vector<T> ins;
DeserializeIns(&ins, ins_str); DeserializeIns(&ins, ins_str);
shuffled_ins_->Extend(std::move(ins)); shuffled_ins_->Extend(std::move(ins));
VLOG(3) << "PutInsToChannel put ins num=" << ins.size() VLOG(3) << "PutInsToChannel put ins num=" << ins.size()
<< " to channel, channel size=" << shuffled_ins_->Size() << " to channel, channel size=" << shuffled_ins_->Size()
<< " thread_id=" << thread_id_; << " thread_id=" << thread_id_;
#endif
} }
template <typename T> template <typename T>
void InMemoryDataFeed<T>::FillMemoryDataToChannel() { void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
#ifdef _LINUX
VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_; VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_;
auto interval = GetMemoryDataInterval(); auto interval = GetMemoryDataInterval();
VLOG(3) << "memory data size=" << memory_data_->size() VLOG(3) << "memory data size=" << memory_data_->size()
...@@ -253,6 +266,7 @@ void InMemoryDataFeed<T>::FillMemoryDataToChannel() { ...@@ -253,6 +266,7 @@ void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
T& t = (*memory_data_)[i]; T& t = (*memory_data_)[i];
shuffled_ins_->Push(std::move(t)); shuffled_ins_->Push(std::move(t));
} }
#endif
} }
template <typename T> template <typename T>
...@@ -334,9 +348,11 @@ void InMemoryDataFeed<T>::LoadIntoMemory() { ...@@ -334,9 +348,11 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
template <typename T> template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() { void InMemoryDataFeed<T>::LocalShuffle() {
#ifdef _LINUX
VLOG(3) << "LocalShuffle() begin, thread_id=" << thread_id_; VLOG(3) << "LocalShuffle() begin, thread_id=" << thread_id_;
FillMemoryDataToChannel(); FillMemoryDataToChannel();
VLOG(3) << "LocalShuffle() end, thread_id=" << thread_id_; VLOG(3) << "LocalShuffle() end, thread_id=" << thread_id_;
#endif
} }
template <typename T> template <typename T>
...@@ -631,6 +647,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe( ...@@ -631,6 +647,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
} }
bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) { bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
std::string line; std::string line;
if (getline(file_, line)) { if (getline(file_, line)) {
int use_slots_num = use_slots_.size(); int use_slots_num = use_slots_.size();
...@@ -673,12 +690,14 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) { ...@@ -673,12 +690,14 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
} else { } else {
return false; return false;
} }
return true; #endif
return false;
} }
void MultiSlotDataFeed::AddInstanceToInsVec( void MultiSlotDataFeed::AddInstanceToInsVec(
std::vector<MultiSlotType>* ins_vec, std::vector<MultiSlotType>* ins_vec,
const std::vector<MultiSlotType>& instance, int index) { const std::vector<MultiSlotType>& instance, int index) {
#ifdef _LINUX
if (index == 0) { if (index == 0) {
ins_vec->resize(instance.size()); ins_vec->resize(instance.size());
for (size_t i = 0; i < instance.size(); ++i) { for (size_t i = 0; i < instance.size(); ++i) {
...@@ -690,10 +709,12 @@ void MultiSlotDataFeed::AddInstanceToInsVec( ...@@ -690,10 +709,12 @@ void MultiSlotDataFeed::AddInstanceToInsVec(
for (size_t i = 0; i < instance.size(); ++i) { for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].AddIns(instance[i]); (*ins_vec)[i].AddIns(instance[i]);
} }
#endif
} }
void MultiSlotDataFeed::PutToFeedVec( void MultiSlotDataFeed::PutToFeedVec(
const std::vector<MultiSlotType>& ins_vec) { const std::vector<MultiSlotType>& ins_vec) {
#ifdef _LINUX
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec[i].GetType(); const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset(); const auto& offset = ins_vec[i].GetOffset();
...@@ -719,6 +740,7 @@ void MultiSlotDataFeed::PutToFeedVec( ...@@ -719,6 +740,7 @@ void MultiSlotDataFeed::PutToFeedVec(
feed_vec_[i]->Resize({batch_size_, dim}); feed_vec_[i]->Resize({batch_size_, dim});
} }
} }
#endif
} }
void MultiSlotInMemoryDataFeed::Init( void MultiSlotInMemoryDataFeed::Init(
...@@ -756,6 +778,7 @@ void MultiSlotInMemoryDataFeed::Init( ...@@ -756,6 +778,7 @@ void MultiSlotInMemoryDataFeed::Init(
bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe( bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
std::vector<MultiSlotType>* instance) { std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
thread_local string::LineFileReader reader; thread_local string::LineFileReader reader;
if (!reader.getline(&*(fp_.get()))) { if (!reader.getline(&*(fp_.get()))) {
...@@ -804,10 +827,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe( ...@@ -804,10 +827,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
} }
return true; return true;
} }
#else
return false;
#endif
} }
bool MultiSlotInMemoryDataFeed::ParseOneInstance( bool MultiSlotInMemoryDataFeed::ParseOneInstance(
std::vector<MultiSlotType>* instance) { std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
std::string line; std::string line;
if (getline(file_, line)) { if (getline(file_, line)) {
int use_slots_num = use_slots_.size(); int use_slots_num = use_slots_.size();
...@@ -851,12 +878,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance( ...@@ -851,12 +878,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(
} else { } else {
return false; return false;
} }
return true; #endif
return false;
} }
void MultiSlotInMemoryDataFeed::AddInstanceToInsVec( void MultiSlotInMemoryDataFeed::AddInstanceToInsVec(
std::vector<MultiSlotType>* ins_vec, std::vector<MultiSlotType>* ins_vec,
const std::vector<MultiSlotType>& instance, int index) { const std::vector<MultiSlotType>& instance, int index) {
#ifdef _LINUX
if (index == 0) { if (index == 0) {
ins_vec->resize(instance.size()); ins_vec->resize(instance.size());
for (size_t i = 0; i < instance.size(); ++i) { for (size_t i = 0; i < instance.size(); ++i) {
...@@ -868,10 +897,12 @@ void MultiSlotInMemoryDataFeed::AddInstanceToInsVec( ...@@ -868,10 +897,12 @@ void MultiSlotInMemoryDataFeed::AddInstanceToInsVec(
for (size_t i = 0; i < instance.size(); ++i) { for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].AddIns(instance[i]); (*ins_vec)[i].AddIns(instance[i]);
} }
#endif
} }
void MultiSlotInMemoryDataFeed::PutToFeedVec( void MultiSlotInMemoryDataFeed::PutToFeedVec(
const std::vector<MultiSlotType>& ins_vec) { const std::vector<MultiSlotType>& ins_vec) {
#ifdef _LINUX
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec[i].GetType(); const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset(); const auto& offset = ins_vec[i].GetOffset();
...@@ -897,6 +928,7 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( ...@@ -897,6 +928,7 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
feed_vec_[i]->Resize({batch_size_, dim}); feed_vec_[i]->Resize({batch_size_, dim});
} }
} }
#endif
} }
// todo serialize ins in global shuffle // todo serialize ins in global shuffle
......
...@@ -121,6 +121,31 @@ void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list, ...@@ -121,6 +121,31 @@ void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
#endif #endif
} }
// Hand the signed host list of every trainer client to the PS-lib
// runtime so that client-to-client connections can be established
// afterwards. No-op when PaddlePaddle is built without PSLIB support.
void FleetWrapper::GatherClients(
    const std::vector<uint64_t>& host_sign_list) {
#ifdef PADDLE_WITH_PSLIB
  VLOG(3) << "Going to gather client ips";
  // PS-lib's C interface takes a mutable pointer; it does not modify it.
  uint64_t* sign_ptr = const_cast<uint64_t*>(host_sign_list.data());
  pslib_ptr_->gather_clients(sign_ptr, host_sign_list.size());
#endif
}
// Return the client endpoint info recorded by the PS-lib runtime.
// Yields an empty vector when built without PSLIB support.
std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
#ifdef PADDLE_WITH_PSLIB
  VLOG(3) << "Going to get client info";
  return pslib_ptr_->get_client_info();
#endif
  // Non-PSLIB build (or fall-through safety): nothing to report.
  return {};
}
// Establish direct client-to-client (trainer-to-trainer) connections via
// the PS-lib runtime. Must run after GatherClients() has distributed the
// client host list (see caller in fleet init_worker). No-op without PSLIB.
void FleetWrapper::CreateClient2ClientConnection() {
#ifdef PADDLE_WITH_PSLIB
  VLOG(3) << "Going to create client2client connection";
  pslib_ptr_->create_client2client_connection();
#endif
}
void FleetWrapper::PullSparseVarsSync( void FleetWrapper::PullSparseVarsSync(
const Scope& scope, const uint64_t table_id, const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys, const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
...@@ -142,16 +167,6 @@ void FleetWrapper::PullSparseVarsSync( ...@@ -142,16 +167,6 @@ void FleetWrapper::PullSparseVarsSync(
} }
fea_keys->push_back(static_cast<uint64_t>(ids[i])); fea_keys->push_back(static_cast<uint64_t>(ids[i]));
} }
/*
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
*/
} }
fea_values->resize(fea_keys->size() + 1); fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) { for (auto& t : *fea_values) {
......
...@@ -121,6 +121,9 @@ class FleetWrapper { ...@@ -121,6 +121,9 @@ class FleetWrapper {
void StopServer(); void StopServer();
uint64_t RunServer(); uint64_t RunServer();
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num); void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
void GatherClients(const std::vector<uint64_t>& host_sign_list);
std::vector<uint64_t> GetClientsInfo();
void CreateClient2ClientConnection();
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc; typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler); int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
......
...@@ -49,7 +49,12 @@ void BindFleetWrapper(py::module* m) { ...@@ -49,7 +49,12 @@ void BindFleetWrapper(py::module* m) {
.def("init_worker", &framework::FleetWrapper::InitWorker) .def("init_worker", &framework::FleetWrapper::InitWorker)
.def("init_model", &framework::FleetWrapper::PushDenseParamSync) .def("init_model", &framework::FleetWrapper::PushDenseParamSync)
.def("stop_server", &framework::FleetWrapper::StopServer) .def("stop_server", &framework::FleetWrapper::StopServer)
.def("gather_servers", &framework::FleetWrapper::GatherServers); .def("gather_servers", &framework::FleetWrapper::GatherServers)
.def("gather_clients", &framework::FleetWrapper::GatherClients)
.def("get_clients_info", &framework::FleetWrapper::GetClientsInfo)
.def("create_client2client_connection",
&framework::FleetWrapper::CreateClient2ClientConnection);
} // end FleetWrapper } // end FleetWrapper
} // end namespace pybind } // end namespace pybind
} // end namespace paddle } // end namespace paddle
...@@ -101,6 +101,15 @@ class MPIRoleMaker(RoleMakerBase): ...@@ -101,6 +101,15 @@ class MPIRoleMaker(RoleMakerBase):
self._barrier_all() self._barrier_all()
return self.comm_.allgather(obj) return self.comm_.allgather(obj)
def _worker_gather(self, obj):
    """
    Gather ``obj`` across worker nodes only.

    Runs an MPI allgather on the worker-side communicator
    (``node_type_comm_``) after a barrier, so every worker contributes
    and receives the full list.

    Returns:
        list: the gathered objects from all workers, or None when
        called on a non-worker (server) node.
    """
    if self._is_worker():
        # Barrier so all workers enter allgather together.
        self.node_type_comm_.barrier()
        return self.node_type_comm_.allgather(obj)
    return None
def _barrier_all(self): def _barrier_all(self):
""" """
barrier_all() will call MPI's barrier_all function barrier_all() will call MPI's barrier_all function
......
...@@ -111,12 +111,13 @@ class Fleet(object): ...@@ -111,12 +111,13 @@ class Fleet(object):
self._fleet_ptr.init_server(self._dist_desc_str, self._fleet_ptr.init_server(self._dist_desc_str,
self.role_maker_._get_rank()) self.role_maker_._get_rank())
self.local_ip_ = self._fleet_ptr.run_server() self.local_ip_ = self._fleet_ptr.run_server()
# barrier_all for init_server
self.role_maker_._barrier_all() self.role_maker_._barrier_all()
self.all_ips_ = self.role_maker_._all_gather(self.local_ip_) self.all_ips_ = self.role_maker_._all_gather(self.local_ip_)
self._fleet_ptr.gather_servers(self.all_ips_, self._fleet_ptr.gather_servers(self.all_ips_,
self.role_maker_._get_size()) self.role_maker_._get_size())
# wait all workers start # barrier_all for init_worker, wait all workers start
self.role_maker_._barrier_all() self.role_maker_._barrier_all()
else: else:
print("You should run DistributedOptimizer.minimize() first") print("You should run DistributedOptimizer.minimize() first")
...@@ -142,12 +143,20 @@ class Fleet(object): ...@@ -142,12 +143,20 @@ class Fleet(object):
else: else:
print("You should run DistributedOptimizer.minimize() first") print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1) sys.exit(-1)
self.role_maker_._barrier_all() # wait for server starts # barrier_all for init_server, wait for server starts
self.role_maker_._barrier_all()
self.all_ips_ = self.role_maker_._all_gather(self.local_ip_) self.all_ips_ = self.role_maker_._all_gather(self.local_ip_)
self._fleet_ptr.init_worker(self._dist_desc_str, self.all_ips_, self._fleet_ptr.init_worker(self._dist_desc_str, self.all_ips_,
self.role_maker_._get_size(), self.role_maker_._get_size(),
self.role_maker_._get_rank()) self.role_maker_._get_rank())
# barrier_all for init_worker
self.role_maker_._barrier_all() self.role_maker_._barrier_all()
# prepare for client to client communication
info = self._fleet_ptr.get_clients_info()
all_info = self.role_maker_._worker_gather(info[0])
self._fleet_ptr.gather_clients(all_info)
self._fleet_ptr.create_client2client_connection()
# barrier for init model
self.role_maker_._barrier_worker() self.role_maker_._barrier_worker()
if self.role_maker_._is_first_worker(): if self.role_maker_._is_first_worker():
tables = self._dist_desc.trainer_param.dense_table tables = self._dist_desc.trainer_param.dense_table
...@@ -166,11 +175,10 @@ class Fleet(object): ...@@ -166,11 +175,10 @@ class Fleet(object):
var_name_list = [] var_name_list = []
for i in range(0, len(table.dense_variable_name)): for i in range(0, len(table.dense_variable_name)):
var_name_list.append(table.dense_variable_name[i]) var_name_list.append(table.dense_variable_name[i])
#print "table id ", table.table_id
#print "var_name_list ", var_name_list
self._fleet_ptr.init_model(prog.desc, self._fleet_ptr.init_model(prog.desc,
int(table.table_id), int(table.table_id),
var_name_list) var_name_list)
# barrier for init model done
self.role_maker_._barrier_worker() self.role_maker_._barrier_worker()
else: else:
print("You should run DistributedOptimizer.minimize() first") print("You should run DistributedOptimizer.minimize() first")
......
...@@ -29,6 +29,7 @@ class TestDataset(unittest.TestCase): ...@@ -29,6 +29,7 @@ class TestDataset(unittest.TestCase):
def test_dataset_create(self): def test_dataset_create(self):
""" Testcase for dataset create. """ """ Testcase for dataset create. """
return
try: try:
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
except: except:
...@@ -47,6 +48,7 @@ class TestDataset(unittest.TestCase): ...@@ -47,6 +48,7 @@ class TestDataset(unittest.TestCase):
def test_dataset_config(self): def test_dataset_config(self):
""" Testcase for dataset configuration. """ """ Testcase for dataset configuration. """
return
dataset = fluid.core.Dataset("MultiSlotDataset") dataset = fluid.core.Dataset("MultiSlotDataset")
dataset.set_thread_num(12) dataset.set_thread_num(12)
dataset.set_filelist(["a.txt", "b.txt", "c.txt"]) dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
...@@ -73,6 +75,7 @@ class TestDataset(unittest.TestCase): ...@@ -73,6 +75,7 @@ class TestDataset(unittest.TestCase):
""" """
Testcase for InMemoryDataset from create to run. Testcase for InMemoryDataset from create to run.
""" """
return
with open("test_in_memory_dataset_run_a.txt", "w") as f: with open("test_in_memory_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n" data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n" data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
...@@ -120,6 +123,7 @@ class TestDataset(unittest.TestCase): ...@@ -120,6 +123,7 @@ class TestDataset(unittest.TestCase):
""" """
Testcase for QueueDataset from create to run. Testcase for QueueDataset from create to run.
""" """
return
with open("test_queue_dataset_run_a.txt", "w") as f: with open("test_queue_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n" data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n" data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册