提交 1a032f40 编写于 作者: R rensilin

openmp

Change-Id: I51728c87923cc613acd8ef5afddbb4501f40cd1c
上级 3515e5fc
WORKROOT('../../../') WORKROOT('../../../')
COMPILER('gcc482') COMPILER('gcc482')
CPPFLAGS('-D_GNU_SOURCE -DNDEBUG') CPPFLAGS('-D_GNU_SOURCE -DNDEBUG')
GLOBAL_CFLAGS_STR = '-g -O3 -pipe ' GLOBAL_CFLAGS_STR = '-g -O3 -pipe -fopenmp '
CFLAGS(GLOBAL_CFLAGS_STR) CFLAGS(GLOBAL_CFLAGS_STR)
GLOBAL_CXXFLAGS_STR = GLOBAL_CFLAGS_STR + ' -std=c++11 ' GLOBAL_CXXFLAGS_STR = GLOBAL_CFLAGS_STR + ' -std=c++11 '
CXXFLAGS(GLOBAL_CXXFLAGS_STR) CXXFLAGS(GLOBAL_CXXFLAGS_STR)
......
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h" #include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include <cstdio> #include <cstdio>
#include <atomic>
#include <glog/logging.h> #include <glog/logging.h>
#include <omp.h> #include <omp.h>
...@@ -136,20 +137,17 @@ public: ...@@ -136,20 +137,17 @@ public:
auto file_list = data_file_list(data_dir); auto file_list = data_file_list(data_dir);
int file_list_size = file_list.size(); int file_list_size = file_list.size();
std::atomic<bool> is_failed(false);
VLOG(5) << "omg max_threads: " << omp_get_max_threads();
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < file_list_size; ++i) { for (int i = 0; i < file_list_size; ++i) {
VLOG(5) << "omg num_threads: " << omp_get_num_threads() << ", start read: " << i << std::endl;
}
for (int i = 0; i < file_list_size; ++i) {
//VLOG(5) << "omg num_threads: " << omp_get_num_threads() << ", start read: " << i;
const auto& filepath = file_list[i]; const auto& filepath = file_list[i];
{ if (!is_failed) {
std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd); std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
if (fin == nullptr) { if (fin == nullptr) {
VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd; VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
return -1; is_failed = true;
continue;
} }
char *buffer = nullptr; char *buffer = nullptr;
size_t buffer_size = 0; size_t buffer_size = 0;
...@@ -172,21 +170,23 @@ public: ...@@ -172,21 +170,23 @@ public:
} }
if (ferror(fin.get()) != 0) { if (ferror(fin.get()) != 0) {
VLOG(2) << "fail to read file: " << filepath; VLOG(2) << "fail to read file: " << filepath;
return -1; is_failed = true;
continue;
} }
} }
if (_file_system->err_no() != 0) { if (_file_system->err_no() != 0) {
_file_system->reset_err_no(); _file_system->reset_err_no();
return -1; is_failed = true;
continue;
} }
} }
writer->Flush(); writer->Flush();
if (!(*writer)) { if (!(*writer)) {
VLOG(2) << "fail when write to channel"; VLOG(2) << "fail when write to channel";
return -1; is_failed = true;
} }
data_channel->Close(); data_channel->Close();
return 0; return is_failed ? -1 : 0;
} }
virtual const DataParser* get_parser() { virtual const DataParser* get_parser() {
...@@ -194,7 +194,7 @@ public: ...@@ -194,7 +194,7 @@ public:
} }
private: private:
std::string _done_file_name; // without data_dirq std::string _done_file_name; // without data_dir
std::string _filename_prefix; std::string _filename_prefix;
std::unique_ptr<FileSystem> _file_system; std::unique_ptr<FileSystem> _file_system;
}; };
......
...@@ -104,6 +104,7 @@ public: ...@@ -104,6 +104,7 @@ public:
std::shared_ptr<TrainerContext> context_ptr; std::shared_ptr<TrainerContext> context_ptr;
std::unique_ptr<FileSystem> fs; std::unique_ptr<FileSystem> fs;
int thread_num = 1; int thread_num = 1;
const int n_run = 5;
}; };
std::vector<DataItem> DataReaderOmpTest::std_items; std::vector<DataItem> DataReaderOmpTest::std_items;
...@@ -128,7 +129,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { ...@@ -128,7 +129,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir, std_items[i].id.c_str()), data_file_list[i]); ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir, std_items[i].id.c_str()), data_file_list[i]);
} }
constexpr int n_run = 10;
int same_count = 0; int same_count = 0;
for (int i = 0; i < n_run; ++i) { for (int i = 0; i < n_run; ++i) {
auto channel = framework::MakeChannel<DataItem>(128); auto channel = framework::MakeChannel<DataItem>(128);
...@@ -172,7 +172,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) { ...@@ -172,7 +172,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
fout.close(); fout.close();
ASSERT_TRUE(data_reader->is_data_ready(test_data_dir)); ASSERT_TRUE(data_reader->is_data_ready(test_data_dir));
constexpr int n_run = 10;
int same_count = 0; int same_count = 0;
int sort_same_count = 0; int sort_same_count = 0;
for (int i = 0; i < n_run; ++i) { for (int i = 0; i < n_run; ++i) {
...@@ -200,7 +199,8 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) { ...@@ -200,7 +199,8 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
} }
// n_run次有不同的(证明是多线程) // n_run次有不同的(证明是多线程)
// ASSERT_GT(n_run, same_count); ASSERT_EQ(4, omp_get_max_threads());
ASSERT_GT(n_run, same_count);
// 但排序后都是相同的 // 但排序后都是相同的
ASSERT_EQ(n_run, sort_same_count); ASSERT_EQ(n_run, sort_same_count);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册