From 271b7147cc60ebff22ace30c1325307a1e7ceaca Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Wed, 3 Apr 2019 11:02:11 +0800 Subject: [PATCH] fix dataset bug test=develop --- paddle/fluid/framework/data_feed.cc | 47 +++++++++++++++++-- paddle/fluid/framework/data_feed.h | 7 +++ paddle/fluid/framework/data_set.cc | 22 +++++++++ paddle/fluid/framework/data_set.h | 14 ++++++ paddle/fluid/pybind/data_set_py.cc | 6 +++ python/paddle/fluid/executor.py | 2 +- .../fluid/tests/unittests/test_dataset.py | 14 ++---- 7 files changed, 95 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index e4e9861e3..365117bf8 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -237,11 +237,21 @@ void InMemoryDataFeed::SetThreadNum(int thread_num) { thread_num_ = thread_num; } +template +void InMemoryDataFeed::SetTrainerId(int trainer_id) { + trainer_id_ = trainer_id; +} + template void InMemoryDataFeed::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; } +template +void InMemoryDataFeed::SetFleetSendBatchSize(int64_t size) { + fleet_send_batch_size_ = size; +} + template void InMemoryDataFeed::PutInsToChannel(const std::string& ins_str) { #ifdef _LINUX @@ -361,8 +371,15 @@ void InMemoryDataFeed::GlobalShuffle() { VLOG(3) << "GlobalShuffle() begin, thread_id=" << thread_id_; auto fleet_ptr = FleetWrapper::GetInstance(); std::vector> send_vec(trainer_num_); + std::vector send_index(trainer_num_); + std::vector local_send_vec; + uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_; for (auto& vec : send_vec) { - vec.reserve(fleet_send_batch_size_); + vec.reserve(reserve_len); + } + local_send_vec.reserve(reserve_len); + for (int i = 0; i < trainer_num_; ++i) { + send_index[i] = i; } std::vector> total_status; auto interval = GetMemoryDataInterval(); @@ -373,9 +390,23 @@ void InMemoryDataFeed::GlobalShuffle() { // std::string ins_id = memory_data_[i].ins_id; int64_t random_num = rand_r(&rand_seed); int64_t node_id = random_num % trainer_num_; - send_vec[node_id].push_back(&((*memory_data_)[i])); + if (node_id == trainer_id_) { + local_send_vec.push_back((*memory_data_)[i]); + } else { + send_vec[node_id].push_back(&((*memory_data_)[i])); + } if (i % fleet_send_batch_size_ == 0 && i != 0) { - for (int j = 0; j < send_vec.size(); ++j) { + // shuffle the sequence of sending to avoid network timeout error + std::random_shuffle(send_index.begin(), send_index.end()); + for (int index = 0; index < send_index.size(); ++index) { + int j = send_index[index]; + if (j == trainer_id_) { + VLOG(3) << "send to local, ins num=" << local_send_vec.size() + << ", node_id=" << j << ", thread_id=" << thread_id_; + shuffled_ins_->Extend(std::move(local_send_vec)); + local_send_vec.clear(); + continue; + } std::string send_str; SerializeIns(send_vec[j], &send_str); VLOG(3) << "send str_length=" << send_str.length() @@ -388,8 +419,14 @@ void InMemoryDataFeed::GlobalShuffle() { } } } - for (int j = 0; j < send_vec.size(); ++j) { - if (send_vec[j].size() != 0) { + // shuffle the sequence of sending to avoid network timeout error + std::random_shuffle(send_index.begin(), send_index.end()); + for (int index = 0; index < send_index.size(); ++index) { + int j = send_index[index]; + if (j == trainer_id_ && local_send_vec.size() != 0) { + shuffled_ins_->Extend(std::move(local_send_vec)); + std::vector().swap(local_send_vec); + } else if (send_vec[j].size() != 0) { std::string send_str; SerializeIns(send_vec[j], &send_str); VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 8ea09b65d..e657e1d63 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -91,9 +91,13 @@ class DataFeed { // This function will do nothing at default virtual void SetThreadId(int thread_id) {} // This function will do nothing at default + virtual void SetTrainerId(int trainer_id) {} + // This function will do nothing at default virtual void SetThreadNum(int thread_num) {} // This function will do nothing at default virtual void SetTrainerNum(int trainer_num) {} + // This function will do nothing at default + virtual void SetFleetSendBatchSize(int64_t size) {} virtual void SetFileListMutex(std::mutex* mutex) { mutex_for_pick_file_ = mutex; } @@ -211,7 +215,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed { virtual void SetMemoryDataMutex(std::mutex* mutex); virtual void SetThreadId(int thread_id); virtual void SetThreadNum(int thread_num); + virtual void SetTrainerId(int trainer_id); virtual void SetTrainerNum(int trainer_num); + virtual void SetFleetSendBatchSize(int64_t size); virtual void PutInsToChannel(const std::string& ins_str); virtual void FillMemoryDataToChannel(); virtual void FillChannelToMemoryData(); @@ -231,6 +237,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed { int thread_id_; int thread_num_; + int trainer_id_; int trainer_num_; uint32_t rand_seed; std::vector* memory_data_; diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 600fc7471..4df7d6af0 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -52,6 +52,17 @@ void DatasetImpl::SetThreadNum(int thread_num) { thread_num_ = thread_num; } +// if you run distributed, and want to do global shuffle, +// set this before global shuffle. +// be sure you call CreateReaders before SetTrainerId +template +void DatasetImpl::SetTrainerId(int trainer_id) { + trainer_id_ = trainer_id; + for (auto reader : readers_) { + reader->SetTrainerId(trainer_id); + } +} + // if you run distributed, and want to do global shuffle, // set this before global shuffle. // be sure you call CreateReaders before SetTrainerNum @@ -64,6 +75,17 @@ void DatasetImpl::SetTrainerNum(int trainer_num) { } } +// if you run distributed, and want to do global shuffle, +// set this before global shuffle. +// be sure you call CreateReaders before SetFleetSendBatchSize +template +void DatasetImpl::SetFleetSendBatchSize(int64_t size) { + fleet_send_batch_size_ = size; + for (auto reader : readers_) { + reader->SetFleetSendBatchSize(size); + } +} + template void DatasetImpl::SetHdfsConfig(const std::string& fs_name, const std::string& fs_ugi) { diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 6fd3fcad2..42073934d 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -45,8 +45,12 @@ class Dataset { virtual void SetFileList(const std::vector& filelist) = 0; // set readers' num virtual void SetThreadNum(int thread_num) = 0; + // set worker rank + virtual void SetTrainerId(int trainer_id) = 0; // set workers' num virtual void SetTrainerNum(int trainer_num) = 0; + // set fleet send batch size + virtual void SetFleetSendBatchSize(int64_t size) = 0; // set fs name and ugi virtual void SetHdfsConfig(const std::string& fs_name, const std::string& fs_ugi) = 0; @@ -57,8 +61,12 @@ class Dataset { virtual const std::vector& GetFileList() = 0; // get thread num virtual int GetThreadNum() = 0; + // get worker rank + virtual int GetTrainerId() = 0; // get worker num virtual int GetTrainerNum() = 0; + // get fleet send batch size + virtual int64_t GetFleetSendBatchSize() = 0; // get hdfs config virtual std::pair GetHdfsConfig() = 0; // get data fedd desc @@ -97,14 +105,18 @@ class DatasetImpl : public Dataset { virtual void SetFileList(const std::vector& filelist); virtual void SetThreadNum(int thread_num); + virtual void SetTrainerId(int trainer_id); virtual void SetTrainerNum(int trainer_num); + virtual void SetFleetSendBatchSize(int64_t size); virtual void SetHdfsConfig(const std::string& fs_name, const std::string& fs_ugi); virtual void SetDataFeedDesc(const std::string& data_feed_desc_str); virtual const std::vector& GetFileList() { return filelist_; } virtual int GetThreadNum() { return thread_num_; } + virtual int GetTrainerId() { return trainer_id_; } virtual int GetTrainerNum() { return trainer_num_; } + virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; } virtual std::pair GetHdfsConfig() { return std::make_pair(fs_name_, fs_ugi_); } @@ -130,6 +142,7 @@ class DatasetImpl : public Dataset { std::mutex mutex_for_update_memory_data_; int thread_num_; paddle::framework::DataFeedDesc data_feed_desc_; + int trainer_id_; int trainer_num_; std::vector filelist_; size_t file_idx_; @@ -137,6 +150,7 @@ class DatasetImpl : public Dataset { std::string fs_name_; std::string fs_ugi_; unsigned int rand_seed; + int64_t fleet_send_batch_size_; }; // use std::vector as data type diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index b773fd03c..0c7bd4752 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -49,12 +49,18 @@ void BindDataset(py::module* m) { })) .def("set_filelist", &framework::Dataset::SetFileList) .def("set_thread_num", &framework::Dataset::SetThreadNum) + .def("set_trainer_id", &framework::Dataset::SetTrainerId) .def("set_trainer_num", &framework::Dataset::SetTrainerNum) + .def("set_fleet_send_batch_size", + &framework::Dataset::SetFleetSendBatchSize) .def("set_hdfs_config", &framework::Dataset::SetHdfsConfig) .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) .def("get_filelist", &framework::Dataset::GetFileList) .def("get_thread_num", &framework::Dataset::GetThreadNum) + .def("get_trainer_id", &framework::Dataset::GetTrainerId) .def("get_trainer_num", &framework::Dataset::GetTrainerNum) + .def("get_fleet_send_batch_size", + &framework::Dataset::GetFleetSendBatchSize) .def("get_hdfs_config", &framework::Dataset::GetHdfsConfig) .def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc) .def("register_client2client_msg_handler", diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index e4666deb7..e53633950 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -796,7 +796,7 @@ class Executor(object): if dataset == None: raise RuntimeError("dataset is need and should be initialized") - if self.place == paddle.fluid.CUDAPlace(): + if not isinstance(self.place, core.CPUPlace): raise RuntimeError("train_from_dataset is verified on CPUPlace" "We will open CUDAPlace in the future") diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py index 8c705a095..39094323f 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataset.py @@ -29,7 +29,6 @@ class TestDataset(unittest.TestCase): def test_dataset_create(self): """ Testcase for dataset create. """ - return try: dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset") except: @@ -48,7 +47,6 @@ class TestDataset(unittest.TestCase): def test_dataset_config(self): """ Testcase for dataset configuration. """ - return dataset = fluid.core.Dataset("MultiSlotDataset") dataset.set_thread_num(12) dataset.set_filelist(["a.txt", "b.txt", "c.txt"]) @@ -75,7 +73,6 @@ class TestDataset(unittest.TestCase): """ Testcase for InMemoryDataset from create to run. """ - return with open("test_in_memory_dataset_run_a.txt", "w") as f: data = "1 1 2 3 3 4 5 5 5 5 1 1\n" data += "1 2 2 3 4 4 6 6 6 6 1 2\n" @@ -113,8 +110,7 @@ class TestDataset(unittest.TestCase): try: exe.train_from_dataset(fluid.default_main_program(), dataset) except: - #self.assertTrue(False) - pass + self.assertTrue(False) os.remove("./test_in_memory_dataset_run_a.txt") os.remove("./test_in_memory_dataset_run_b.txt") @@ -123,7 +119,6 @@ class TestDataset(unittest.TestCase): """ Testcase for QueueDataset from create to run. """ - return with open("test_queue_dataset_run_a.txt", "w") as f: data = "1 1 2 3 3 4 5 5 5 5 1 1\n" data += "1 2 2 3 4 4 6 6 6 6 1 2\n" @@ -157,14 +152,11 @@ class TestDataset(unittest.TestCase): try: exe.train_from_dataset(fluid.default_main_program(), dataset) except: - #self.assertTrue(False) - pass + self.assertTrue(False) os.remove("./test_queue_dataset_run_a.txt") os.remove("./test_queue_dataset_run_b.txt") if __name__ == '__main__': - #unittest.main() - import sys - sys.exit(0) + unittest.main() -- GitLab