提交 1a032f40 编写于 作者: R rensilin

openmp

Change-Id: I51728c87923cc613acd8ef5afddbb4501f40cd1c
上级 3515e5fc
WORKROOT('../../../')
COMPILER('gcc482')
CPPFLAGS('-D_GNU_SOURCE -DNDEBUG')
GLOBAL_CFLAGS_STR = '-g -O3 -pipe '
GLOBAL_CFLAGS_STR = '-g -O3 -pipe -fopenmp '
CFLAGS(GLOBAL_CFLAGS_STR)
GLOBAL_CXXFLAGS_STR = GLOBAL_CFLAGS_STR + ' -std=c++11 '
CXXFLAGS(GLOBAL_CXXFLAGS_STR)
......
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include <cstdio>
#include <atomic>
#include <glog/logging.h>
#include <omp.h>
......@@ -136,20 +137,17 @@ public:
auto file_list = data_file_list(data_dir);
int file_list_size = file_list.size();
std::atomic<bool> is_failed(false);
VLOG(5) << "omg max_threads: " << omp_get_max_threads();
#pragma omp parallel for
for (int i = 0; i < file_list_size; ++i) {
VLOG(5) << "omg num_threads: " << omp_get_num_threads() << ", start read: " << i << std::endl;
}
for (int i = 0; i < file_list_size; ++i) {
//VLOG(5) << "omg num_threads: " << omp_get_num_threads() << ", start read: " << i;
const auto& filepath = file_list[i];
{
if (!is_failed) {
std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
if (fin == nullptr) {
VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
return -1;
is_failed = true;
continue;
}
char *buffer = nullptr;
size_t buffer_size = 0;
......@@ -172,21 +170,23 @@ public:
}
if (ferror(fin.get()) != 0) {
VLOG(2) << "fail to read file: " << filepath;
return -1;
is_failed = true;
continue;
}
}
if (_file_system->err_no() != 0) {
_file_system->reset_err_no();
return -1;
is_failed = true;
continue;
}
}
writer->Flush();
if (!(*writer)) {
VLOG(2) << "fail when write to channel";
return -1;
is_failed = true;
}
data_channel->Close();
return 0;
return is_failed ? -1 : 0;
}
virtual const DataParser* get_parser() {
......@@ -194,7 +194,7 @@ public:
}
private:
std::string _done_file_name; // without data_dirq
std::string _done_file_name; // without data_dir
std::string _filename_prefix;
std::unique_ptr<FileSystem> _file_system;
};
......
......@@ -104,6 +104,7 @@ public:
std::shared_ptr<TrainerContext> context_ptr;
std::unique_ptr<FileSystem> fs;
int thread_num = 1;
const int n_run = 5;
};
std::vector<DataItem> DataReaderOmpTest::std_items;
......@@ -128,7 +129,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir, std_items[i].id.c_str()), data_file_list[i]);
}
constexpr int n_run = 10;
int same_count = 0;
for (int i = 0; i < n_run; ++i) {
auto channel = framework::MakeChannel<DataItem>(128);
......@@ -172,7 +172,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
fout.close();
ASSERT_TRUE(data_reader->is_data_ready(test_data_dir));
constexpr int n_run = 10;
int same_count = 0;
int sort_same_count = 0;
for (int i = 0; i < n_run; ++i) {
......@@ -200,7 +199,8 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
}
// n_run次有不同的(证明是多线程)
// ASSERT_GT(n_run, same_count);
ASSERT_EQ(4, omp_get_max_threads());
ASSERT_GT(n_run, same_count);
// 但排序后都是相同的
ASSERT_EQ(n_run, sort_same_count);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册