From db4a4db072d05ecc382b826dba6d64c5e468adf8 Mon Sep 17 00:00:00 2001 From: rensilin Date: Fri, 16 Aug 2019 12:15:08 +0800 Subject: [PATCH] fs_bug Change-Id: I7e92af98dc56e18b79640f070b13a26f9c94ab52 --- .../feed/dataset/data_reader.cc | 97 +++++++++++-------- .../feed/unit_test/test_datareader_omp.cc | 52 +++++----- 2 files changed, 84 insertions(+), 65 deletions(-) diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc index 31ed1870..48b9e1fe 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc @@ -131,53 +131,66 @@ public: return read_all(file_list, data_channel); } virtual int read_all(const std::vector& file_list, ::paddle::framework::Channel data_channel) { - auto deleter = [](framework::ChannelWriter *writer) { - if (writer) { - writer->Flush(); - VLOG(3) << "writer auto flush"; - } - delete writer; - }; - std::unique_ptr, decltype(deleter)> writer(new framework::ChannelWriter(data_channel.get()), deleter); DataItem data_item; - - int file_list_size = file_list.size(); + const int file_list_size = file_list.size(); std::atomic is_failed(false); + const int max_threads = omp_get_max_threads(); + std::vector> writers; // writer is not thread safe + writers.reserve(max_threads); + for (int i = 0; i < max_threads; ++i) { + writers.emplace_back(data_channel.get()); + } + VLOG(5) << "file_list: " << string::join_strings(file_list, ' '); #pragma omp parallel for for (int i = 0; i < file_list_size; ++i) { + if (is_failed) { + continue; + } + const int thread_num = omp_get_thread_num(); + framework::ChannelWriter *writer = nullptr; + if (thread_num < max_threads) { + writer = &writers[thread_num]; + } const auto& filepath = file_list[i]; - if (!is_failed) { - std::shared_ptr fin = _file_system->open_read(filepath, _pipeline_cmd); - if (fin == nullptr) { - VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd; - is_failed = true; + std::shared_ptr fin = _file_system->open_read(filepath, _pipeline_cmd); + if (fin == nullptr) { + VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd; + is_failed = true; + continue; + } + char *buffer = nullptr; + size_t buffer_size = 0; + ssize_t line_len = 0; + while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) { + // 去掉行位回车 + if (line_len > 0 && buffer[line_len - 1] == '\n') { + buffer[--line_len] = '\0'; + } + // 忽略空行 + if (line_len <= 0) { continue; } - char *buffer = nullptr; - size_t buffer_size = 0; - ssize_t line_len = 0; - while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) { - if (line_len > 0 && buffer[line_len - 1] == '\n') { - buffer[--line_len] = '\0'; - } - if (line_len <= 0) { - continue; - } - if (_parser->parse(buffer, line_len, data_item) == 0) { + if (_parser->parse(buffer, line_len, data_item) == 0) { + VLOG(5) << "parse data: " << data_item.id << " " << data_item.data << ", filename: " << filepath << ", thread_num: " << thread_num << ", max_threads: " << max_threads; + if (writer == nullptr) { + if (!data_channel->Put(std::move(data_item))) { + VLOG(2) << "fail to put data, thread_num: " << thread_num; + } + } else { (*writer) << std::move(data_item); } } - if (buffer != nullptr) { - free(buffer); - buffer = nullptr; - buffer_size = 0; - } - if (ferror(fin.get()) != 0) { - VLOG(2) << "fail to read file: " << filepath; - is_failed = true; - continue; - } + } + if (buffer != nullptr) { + free(buffer); + buffer = nullptr; + buffer_size = 0; + } + if (ferror(fin.get()) != 0) { + VLOG(2) << "fail to read file: " << filepath; + is_failed = true; + continue; } if (_file_system->err_no() != 0) { _file_system->reset_err_no(); @@ -185,10 +198,14 @@ public: continue; } } - writer->Flush(); - if (!(*writer)) { - VLOG(2) << "fail when write to channel"; - is_failed = true; + // omp end + + for (int i = 0; i < max_threads; ++i) { + writers[i].Flush(); + if (!writers[i]) { + VLOG(2) << "writer " << i << " is failed"; + is_failed = true; + } } data_channel->Close(); return is_failed ? -1 : 0; diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc index 8ac7874e..353c6741 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc @@ -38,6 +38,9 @@ class DataReaderOmpTest : public testing::Test { public: static void SetUpTestCase() { std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + if (fs->exists(test_data_dir)) { + fs->remove(test_data_dir); + } fs->mkdir(test_data_dir); shell_set_verbose(true); std_items.clear(); @@ -92,11 +95,12 @@ public: } static void read_all(framework::Channel& channel, std::vector& items) { - framework::ChannelReader reader(channel.get()); - DataItem data_item; - while (reader >> data_item) { - items.push_back(std::move(data_item)); - } + channel->ReadAll(items); + // framework::ChannelReader reader(channel.get()); + // DataItem data_item; + // while (reader >> data_item) { + // items.push_back(std::move(data_item)); + // } } static bool is_same_with_std_items(const std::vector& items) { @@ -107,6 +111,14 @@ public: return is_same(items, sorted_std_items); } + static std::string to_string(const std::vector& items) { + std::string items_str = ""; + for (const auto& item : items) { + items_str.append(item.id); + } + return items_str; + } + static std::vector std_items; static std::vector sorted_std_items; std::shared_ptr context_ptr; @@ -137,7 +149,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir, std_items[i].id.c_str()), data_file_list[i]); } - int same_count = 0; for (int i = 0; i < n_run; ++i) { auto channel = framework::MakeChannel(128); ASSERT_NE(nullptr, channel); @@ -146,13 +157,8 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { std::vector items; read_all(channel, items); - if (is_same_with_std_items(items)) { - ++same_count; - } + ASSERT_TRUE(is_same_with_std_items(items)); } - - // n_run 次都相同 - ASSERT_EQ(n_run, same_count); } TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) { @@ -188,36 +194,32 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) { omp_set_num_threads(4); + channel->SetBlockSize(1); ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel)); std::vector items; read_all(channel, items); + ASSERT_EQ(std_items_size, items.size()); + if (is_same_with_std_items(items)) { ++same_count; } - + VLOG(5) << "before sort items: " << to_string(items); std::sort(items.begin(), items.end(), [] (const DataItem& a, const DataItem& b) { return a.id < b.id; }); - - if (is_same_with_sorted_std_items(items)) { - ++sort_same_count; - } else { - std::string items_str = ""; - for (const auto& item: items) { - items_str.append(item.id); - } - VLOG(2) << "items: " << items_str; + bool is_same_with_std = is_same_with_sorted_std_items(items); + if (!is_same_with_std) { + VLOG(5) << "after sort items: " << to_string(items); } - + // 排序后都是相同的 + ASSERT_TRUE(is_same_with_std); } // n_run次有不同的(证明是多线程) ASSERT_EQ(4, omp_get_max_threads()); ASSERT_GT(n_run, same_count); - // 但排序后都是相同的 - ASSERT_EQ(n_run, sort_same_count); } } // namespace feed -- GitLab