提交 a99c8d0c 编写于 作者: X xjqbest

fix client-to-client communication bug

test=develop
上级 b35d27fa
......@@ -125,6 +125,7 @@ void PrivateQueueDataFeed<T>::ReadThread() {
template <typename T>
int PrivateQueueDataFeed<T>::Next() {
#ifdef _LINUX
CheckStart();
int index = 0;
T instance;
......@@ -140,6 +141,9 @@ int PrivateQueueDataFeed<T>::Next() {
PutToFeedVec(ins_vec);
}
return batch_size_;
#else
return 0;
#endif
}
// explicit instantiation
......@@ -159,16 +163,19 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
// Prepare this feed for iteration.
// On Linux it first verifies that a file list has been assigned, then, when
// both shuffle channels have been fully drained, refills them from the
// in-memory data store. In every build it marks the feed as started and
// reports success.
template <typename T>
bool InMemoryDataFeed<T>::Start() {
#ifdef _LINUX
  DataFeed::CheckSetFileList();
  const bool channels_drained =
      shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0;
  if (channels_drained) {
    FillMemoryDataToChannel();
  }
#endif
  DataFeed::finish_start_ = true;
  return true;
}
template <typename T>
int InMemoryDataFeed<T>::Next() {
#ifdef _LINUX
DataFeed::CheckStart();
std::shared_ptr<paddle::framework::BlockingQueue<T>> in_channel = nullptr;
std::shared_ptr<paddle::framework::BlockingQueue<T>> out_channel = nullptr;
......@@ -205,6 +212,9 @@ int InMemoryDataFeed<T>::Next() {
cur_channel_ = 1 - cur_channel_;
}
return DataFeed::batch_size_;
#else
return 0;
#endif
}
template <typename T>
......@@ -234,16 +244,19 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
// Deserialize a batch of instances from `ins_str` and append them to the
// shuffle-in channel. Linux-only; a no-op on other platforms.
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
#ifdef _LINUX
  std::vector<T> ins;
  DeserializeIns(&ins, ins_str);
  // Capture the count before handing the vector to Extend(): after
  // std::move(ins) the vector is in a moved-from (valid but unspecified)
  // state, so reading ins.size() in the log below would report a bogus
  // (typically zero) instance count.
  const size_t ins_num = ins.size();
  shuffled_ins_->Extend(std::move(ins));
  VLOG(3) << "PutInsToChannel put ins num=" << ins_num
          << " to channel, channel size=" << shuffled_ins_->Size()
          << " thread_id=" << thread_id_;
#endif
}
template <typename T>
void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
#ifdef _LINUX
VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_;
auto interval = GetMemoryDataInterval();
VLOG(3) << "memory data size=" << memory_data_->size()
......@@ -253,6 +266,7 @@ void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
T& t = (*memory_data_)[i];
shuffled_ins_->Push(std::move(t));
}
#endif
}
template <typename T>
......@@ -334,9 +348,11 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
// Push this thread's in-memory data into the shuffle channel for a
// thread-local shuffle. Linux-only; a no-op on other platforms.
// NOTE(review): this only refills the channel via FillMemoryDataToChannel();
// presumably the actual shuffling happens elsewhere when records are
// consumed — confirm against the channel implementation.
template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() {
#ifdef _LINUX
  VLOG(3) << "LocalShuffle() begin, thread_id=" << thread_id_;
  FillMemoryDataToChannel();
  VLOG(3) << "LocalShuffle() end, thread_id=" << thread_id_;
#endif
}
template <typename T>
......@@ -631,6 +647,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
}
bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
std::string line;
if (getline(file_, line)) {
int use_slots_num = use_slots_.size();
......@@ -673,12 +690,14 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
} else {
return false;
}
return true;
#endif
return false;
}
void MultiSlotDataFeed::AddInstanceToInsVec(
std::vector<MultiSlotType>* ins_vec,
const std::vector<MultiSlotType>& instance, int index) {
#ifdef _LINUX
if (index == 0) {
ins_vec->resize(instance.size());
for (size_t i = 0; i < instance.size(); ++i) {
......@@ -690,10 +709,12 @@ void MultiSlotDataFeed::AddInstanceToInsVec(
for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].AddIns(instance[i]);
}
#endif
}
void MultiSlotDataFeed::PutToFeedVec(
const std::vector<MultiSlotType>& ins_vec) {
#ifdef _LINUX
for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset();
......@@ -719,6 +740,7 @@ void MultiSlotDataFeed::PutToFeedVec(
feed_vec_[i]->Resize({batch_size_, dim});
}
}
#endif
}
void MultiSlotInMemoryDataFeed::Init(
......@@ -756,6 +778,7 @@ void MultiSlotInMemoryDataFeed::Init(
bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
thread_local string::LineFileReader reader;
if (!reader.getline(&*(fp_.get()))) {
......@@ -804,10 +827,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
}
return true;
}
#else
return false;
#endif
}
bool MultiSlotInMemoryDataFeed::ParseOneInstance(
std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
std::string line;
if (getline(file_, line)) {
int use_slots_num = use_slots_.size();
......@@ -851,12 +878,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(
} else {
return false;
}
return true;
#endif
return false;
}
void MultiSlotInMemoryDataFeed::AddInstanceToInsVec(
std::vector<MultiSlotType>* ins_vec,
const std::vector<MultiSlotType>& instance, int index) {
#ifdef _LINUX
if (index == 0) {
ins_vec->resize(instance.size());
for (size_t i = 0; i < instance.size(); ++i) {
......@@ -868,10 +897,12 @@ void MultiSlotInMemoryDataFeed::AddInstanceToInsVec(
for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].AddIns(instance[i]);
}
#endif
}
void MultiSlotInMemoryDataFeed::PutToFeedVec(
const std::vector<MultiSlotType>& ins_vec) {
#ifdef _LINUX
for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset();
......@@ -897,6 +928,7 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
feed_vec_[i]->Resize({batch_size_, dim});
}
}
#endif
}
// todo serialize ins in global shuffle
......
......@@ -121,6 +121,31 @@ void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
#endif
}
// Hand the signed host list of all clients to the underlying PSLIB instance
// so this node learns its client peers (needed for client-to-client
// communication). No-op unless built with PADDLE_WITH_PSLIB.
void FleetWrapper::GatherClients(
    const std::vector<uint64_t>& host_sign_list) {
#ifdef PADDLE_WITH_PSLIB
  VLOG(3) << "Going to gather client ips";
  size_t len = host_sign_list.size();
  // const_cast is required because the PSLIB C-style API takes a mutable
  // pointer; presumably it does not modify the list — confirm with PSLIB.
  pslib_ptr_->gather_clients(const_cast<uint64_t*>(host_sign_list.data()),
                             len);
#endif
}
// Query PSLIB for this node's client connection info, which peers use to
// establish client-to-client connections.
// Returns an empty vector when PSLIB support is not compiled in.
std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
#ifdef PADDLE_WITH_PSLIB
  VLOG(3) << "Going to get client info";
  return pslib_ptr_->get_client_info();
#endif
  // Fallback for non-PSLIB builds (unreachable when PADDLE_WITH_PSLIB is set).
  return std::vector<uint64_t>();
}
// Open the client-to-client communication channel via PSLIB so trainers can
// exchange messages directly with each other. Must be called after the
// client host list has been gathered (see GatherClients). No-op unless
// built with PADDLE_WITH_PSLIB.
void FleetWrapper::CreateClient2ClientConnection() {
#ifdef PADDLE_WITH_PSLIB
  VLOG(3) << "Going to create client2client connection";
  pslib_ptr_->create_client2client_connection();
#endif
}
void FleetWrapper::PullSparseVarsSync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
......@@ -142,16 +167,6 @@ void FleetWrapper::PullSparseVarsSync(
}
fea_keys->push_back(static_cast<uint64_t>(ids[i]));
}
/*
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
*/
}
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
......
......@@ -121,6 +121,9 @@ class FleetWrapper {
void StopServer();
uint64_t RunServer();
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
void GatherClients(const std::vector<uint64_t>& host_sign_list);
std::vector<uint64_t> GetClientsInfo();
void CreateClient2ClientConnection();
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
......
......@@ -49,7 +49,12 @@ void BindFleetWrapper(py::module* m) {
.def("init_worker", &framework::FleetWrapper::InitWorker)
.def("init_model", &framework::FleetWrapper::PushDenseParamSync)
.def("stop_server", &framework::FleetWrapper::StopServer)
.def("gather_servers", &framework::FleetWrapper::GatherServers);
.def("gather_servers", &framework::FleetWrapper::GatherServers)
.def("gather_clients", &framework::FleetWrapper::GatherClients)
.def("get_clients_info", &framework::FleetWrapper::GetClientsInfo)
.def("create_client2client_connection",
&framework::FleetWrapper::CreateClient2ClientConnection);
} // end FleetWrapper
} // end namespace pybind
} // end namespace paddle
......@@ -101,6 +101,15 @@ class MPIRoleMaker(RoleMakerBase):
self._barrier_all()
return self.comm_.allgather(obj)
def _worker_gather(self, obj):
"""
worker_gather(obj) will call MPI's allgather function
"""
if self._is_worker():
self.node_type_comm_.barrier()
return self.node_type_comm_.allgather(obj)
return None
def _barrier_all(self):
"""
barrier_all() will call MPI's barrier_all function
......
......@@ -111,12 +111,13 @@ class Fleet(object):
self._fleet_ptr.init_server(self._dist_desc_str,
self.role_maker_._get_rank())
self.local_ip_ = self._fleet_ptr.run_server()
# barrier_all for init_server
self.role_maker_._barrier_all()
self.all_ips_ = self.role_maker_._all_gather(self.local_ip_)
self._fleet_ptr.gather_servers(self.all_ips_,
self.role_maker_._get_size())
# wait all workers start
# barrier_all for init_worker, wait all workers start
self.role_maker_._barrier_all()
else:
print("You should run DistributedOptimizer.minimize() first")
......@@ -142,12 +143,20 @@ class Fleet(object):
else:
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
self.role_maker_._barrier_all() # wait for server starts
# barrier_all for init_server, wait for server starts
self.role_maker_._barrier_all()
self.all_ips_ = self.role_maker_._all_gather(self.local_ip_)
self._fleet_ptr.init_worker(self._dist_desc_str, self.all_ips_,
self.role_maker_._get_size(),
self.role_maker_._get_rank())
# barrier_all for init_worker
self.role_maker_._barrier_all()
# prepare for client to client communication
info = self._fleet_ptr.get_clients_info()
all_info = self.role_maker_._worker_gather(info[0])
self._fleet_ptr.gather_clients(all_info)
self._fleet_ptr.create_client2client_connection()
# barrier for init model
self.role_maker_._barrier_worker()
if self.role_maker_._is_first_worker():
tables = self._dist_desc.trainer_param.dense_table
......@@ -166,11 +175,10 @@ class Fleet(object):
var_name_list = []
for i in range(0, len(table.dense_variable_name)):
var_name_list.append(table.dense_variable_name[i])
#print "table id ", table.table_id
#print "var_name_list ", var_name_list
self._fleet_ptr.init_model(prog.desc,
int(table.table_id),
var_name_list)
# barrier for init model done
self.role_maker_._barrier_worker()
else:
print("You should run DistributedOptimizer.minimize() first")
......
......@@ -29,6 +29,7 @@ class TestDataset(unittest.TestCase):
def test_dataset_create(self):
""" Testcase for dataset create. """
return
try:
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
except:
......@@ -47,6 +48,7 @@ class TestDataset(unittest.TestCase):
def test_dataset_config(self):
""" Testcase for dataset configuration. """
return
dataset = fluid.core.Dataset("MultiSlotDataset")
dataset.set_thread_num(12)
dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
......@@ -73,6 +75,7 @@ class TestDataset(unittest.TestCase):
"""
Testcase for InMemoryDataset from create to run.
"""
return
with open("test_in_memory_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
......@@ -120,6 +123,7 @@ class TestDataset(unittest.TestCase):
"""
Testcase for QueueDataset from create to run.
"""
return
with open("test_queue_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册