提交 6a57e807 编写于 作者: X xjqbest

remove trainer_id in datafeed and dataset

test=develop
上级 7a759d76
...@@ -237,11 +237,6 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) { ...@@ -237,11 +237,6 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
thread_num_ = thread_num; thread_num_ = thread_num;
} }
template <typename T>
void InMemoryDataFeed<T>::SetTrainerId(int trainer_id) {
trainer_id_ = trainer_id;
}
template <typename T> template <typename T>
void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) { void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num; trainer_num_ = trainer_num;
...@@ -372,12 +367,10 @@ void InMemoryDataFeed<T>::GlobalShuffle() { ...@@ -372,12 +367,10 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<std::vector<T*>> send_vec(trainer_num_); std::vector<std::vector<T*>> send_vec(trainer_num_);
std::vector<int> send_index(trainer_num_); std::vector<int> send_index(trainer_num_);
std::vector<T> local_send_vec;
uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_; uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_;
for (auto& vec : send_vec) { for (auto& vec : send_vec) {
vec.reserve(reserve_len); vec.reserve(reserve_len);
} }
local_send_vec.reserve(reserve_len);
for (int i = 0; i < trainer_num_; ++i) { for (int i = 0; i < trainer_num_; ++i) {
send_index[i] = i; send_index[i] = i;
} }
...@@ -390,23 +383,12 @@ void InMemoryDataFeed<T>::GlobalShuffle() { ...@@ -390,23 +383,12 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
// std::string ins_id = memory_data_[i].ins_id; // std::string ins_id = memory_data_[i].ins_id;
int64_t random_num = rand_r(&rand_seed); int64_t random_num = rand_r(&rand_seed);
int64_t node_id = random_num % trainer_num_; int64_t node_id = random_num % trainer_num_;
if (node_id == trainer_id_) { send_vec[node_id].push_back(&((*memory_data_)[i]));
local_send_vec.push_back((*memory_data_)[i]);
} else {
send_vec[node_id].push_back(&((*memory_data_)[i]));
}
if (i % fleet_send_batch_size_ == 0 && i != 0) { if (i % fleet_send_batch_size_ == 0 && i != 0) {
// shuffle the sequence of sending to avoid network timeout error // shuffle the sequence of sending to avoid network timeout error
std::random_shuffle(send_index.begin(), send_index.end()); std::random_shuffle(send_index.begin(), send_index.end());
for (int index = 0; index < send_index.size(); ++index) { for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index]; int j = send_index[index];
if (j == trainer_id_) {
VLOG(3) << "send to local, ins num=" << local_send_vec.size()
<< ", node_id=" << j << ", thread_id=" << thread_id_;
shuffled_ins_->Extend(std::move(local_send_vec));
local_send_vec.clear();
continue;
}
std::string send_str; std::string send_str;
SerializeIns(send_vec[j], &send_str); SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length() VLOG(3) << "send str_length=" << send_str.length()
...@@ -423,10 +405,7 @@ void InMemoryDataFeed<T>::GlobalShuffle() { ...@@ -423,10 +405,7 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
std::random_shuffle(send_index.begin(), send_index.end()); std::random_shuffle(send_index.begin(), send_index.end());
for (int index = 0; index < send_index.size(); ++index) { for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index]; int j = send_index[index];
if (j == trainer_id_ && local_send_vec.size() != 0) { if (send_vec[j].size() != 0) {
shuffled_ins_->Extend(std::move(local_send_vec));
std::vector<T>().swap(local_send_vec);
} else if (send_vec[j].size() != 0) {
std::string send_str; std::string send_str;
SerializeIns(send_vec[j], &send_str); SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j
......
...@@ -91,8 +91,6 @@ class DataFeed { ...@@ -91,8 +91,6 @@ class DataFeed {
// This function will do nothing at default // This function will do nothing at default
virtual void SetThreadId(int thread_id) {} virtual void SetThreadId(int thread_id) {}
// This function will do nothing at default // This function will do nothing at default
virtual void SetTrainerId(int trainer_id) {}
// This function will do nothing at default
virtual void SetThreadNum(int thread_num) {} virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default // This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) {} virtual void SetTrainerNum(int trainer_num) {}
...@@ -215,7 +213,6 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> { ...@@ -215,7 +213,6 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual void SetMemoryDataMutex(std::mutex* mutex); virtual void SetMemoryDataMutex(std::mutex* mutex);
virtual void SetThreadId(int thread_id); virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num); virtual void SetThreadNum(int thread_num);
virtual void SetTrainerId(int trainer_id);
virtual void SetTrainerNum(int trainer_num); virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size); virtual void SetFleetSendBatchSize(int64_t size);
virtual void PutInsToChannel(const std::string& ins_str); virtual void PutInsToChannel(const std::string& ins_str);
...@@ -237,7 +234,6 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> { ...@@ -237,7 +234,6 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
int thread_id_; int thread_id_;
int thread_num_; int thread_num_;
int trainer_id_;
int trainer_num_; int trainer_num_;
uint32_t rand_seed; uint32_t rand_seed;
std::vector<T>* memory_data_; std::vector<T>* memory_data_;
......
...@@ -52,17 +52,6 @@ void DatasetImpl<T>::SetThreadNum(int thread_num) { ...@@ -52,17 +52,6 @@ void DatasetImpl<T>::SetThreadNum(int thread_num) {
thread_num_ = thread_num; thread_num_ = thread_num;
} }
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerId
template <typename T>
void DatasetImpl<T>::SetTrainerId(int trainer_id) {
trainer_id_ = trainer_id;
for (auto reader : readers_) {
reader->SetTrainerId(trainer_id);
}
}
// if you run distributed, and want to do global shuffle, // if you run distributed, and want to do global shuffle,
// set this before global shuffle. // set this before global shuffle.
// be sure you call CreateReaders before SetTrainerNum // be sure you call CreateReaders before SetTrainerNum
......
...@@ -45,8 +45,6 @@ class Dataset { ...@@ -45,8 +45,6 @@ class Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist) = 0; virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
// set readers' num // set readers' num
virtual void SetThreadNum(int thread_num) = 0; virtual void SetThreadNum(int thread_num) = 0;
// set worker rank
virtual void SetTrainerId(int trainer_id) = 0;
// set workers' num // set workers' num
virtual void SetTrainerNum(int trainer_num) = 0; virtual void SetTrainerNum(int trainer_num) = 0;
// set fleet send batch size // set fleet send batch size
...@@ -61,8 +59,6 @@ class Dataset { ...@@ -61,8 +59,6 @@ class Dataset {
virtual const std::vector<std::string>& GetFileList() = 0; virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num // get thread num
virtual int GetThreadNum() = 0; virtual int GetThreadNum() = 0;
// get worker rank
virtual int GetTrainerId() = 0;
// get worker num // get worker num
virtual int GetTrainerNum() = 0; virtual int GetTrainerNum() = 0;
// get fleet send batch size // get fleet send batch size
...@@ -105,7 +101,6 @@ class DatasetImpl : public Dataset { ...@@ -105,7 +101,6 @@ class DatasetImpl : public Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist); virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num); virtual void SetThreadNum(int thread_num);
virtual void SetTrainerId(int trainer_id);
virtual void SetTrainerNum(int trainer_num); virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size); virtual void SetFleetSendBatchSize(int64_t size);
virtual void SetHdfsConfig(const std::string& fs_name, virtual void SetHdfsConfig(const std::string& fs_name,
...@@ -114,7 +109,6 @@ class DatasetImpl : public Dataset { ...@@ -114,7 +109,6 @@ class DatasetImpl : public Dataset {
virtual const std::vector<std::string>& GetFileList() { return filelist_; } virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; } virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerId() { return trainer_id_; }
virtual int GetTrainerNum() { return trainer_num_; } virtual int GetTrainerNum() { return trainer_num_; }
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; } virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() { virtual std::pair<std::string, std::string> GetHdfsConfig() {
...@@ -142,7 +136,6 @@ class DatasetImpl : public Dataset { ...@@ -142,7 +136,6 @@ class DatasetImpl : public Dataset {
std::mutex mutex_for_update_memory_data_; std::mutex mutex_for_update_memory_data_;
int thread_num_; int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_; paddle::framework::DataFeedDesc data_feed_desc_;
int trainer_id_;
int trainer_num_; int trainer_num_;
std::vector<std::string> filelist_; std::vector<std::string> filelist_;
size_t file_idx_; size_t file_idx_;
......
...@@ -49,7 +49,6 @@ void BindDataset(py::module* m) { ...@@ -49,7 +49,6 @@ void BindDataset(py::module* m) {
})) }))
.def("set_filelist", &framework::Dataset::SetFileList) .def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum) .def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_id", &framework::Dataset::SetTrainerId)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum) .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_fleet_send_batch_size", .def("set_fleet_send_batch_size",
&framework::Dataset::SetFleetSendBatchSize) &framework::Dataset::SetFleetSendBatchSize)
...@@ -57,7 +56,6 @@ void BindDataset(py::module* m) { ...@@ -57,7 +56,6 @@ void BindDataset(py::module* m) {
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("get_filelist", &framework::Dataset::GetFileList) .def("get_filelist", &framework::Dataset::GetFileList)
.def("get_thread_num", &framework::Dataset::GetThreadNum) .def("get_thread_num", &framework::Dataset::GetThreadNum)
.def("get_trainer_id", &framework::Dataset::GetTrainerId)
.def("get_trainer_num", &framework::Dataset::GetTrainerNum) .def("get_trainer_num", &framework::Dataset::GetTrainerNum)
.def("get_fleet_send_batch_size", .def("get_fleet_send_batch_size",
&framework::Dataset::GetFleetSendBatchSize) &framework::Dataset::GetFleetSendBatchSize)
......
...@@ -240,15 +240,12 @@ class InMemoryDataset(DatasetBase): ...@@ -240,15 +240,12 @@ class InMemoryDataset(DatasetBase):
Args: Args:
fleet: fleet singleton. Default None. fleet: fleet singleton. Default None.
""" """
trainer_id = 0
trainer_num = 1 trainer_num = 1
fleet_send_batch_size = 80000 fleet_send_batch_size = 80000
if fleet is not None: if fleet is not None:
fleet.fleet_instance.role_maker_._barrier_worker() fleet.fleet_instance.role_maker_._barrier_worker()
trainer_id = fleet.worker_index()
trainer_num = fleet.worker_num() trainer_num = fleet.worker_num()
self.dataset.register_client2client_msg_handler() self.dataset.register_client2client_msg_handler()
self.dataset.set_trainer_id(trainer_id)
self.dataset.set_trainer_num(trainer_num) self.dataset.set_trainer_num(trainer_num)
self.dataset.set_fleet_send_batch_size(fleet_send_batch_size) self.dataset.set_fleet_send_batch_size(fleet_send_batch_size)
if fleet is not None: if fleet is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册