提交 db4a4db0 编写于 作者: R rensilin

fs_bug

Change-Id: I7e92af98dc56e18b79640f070b13a26f9c94ab52
上级 6a15698f
...@@ -131,53 +131,66 @@ public: ...@@ -131,53 +131,66 @@ public:
return read_all(file_list, data_channel); return read_all(file_list, data_channel);
} }
virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) { virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) {
auto deleter = [](framework::ChannelWriter<DataItem> *writer) {
if (writer) {
writer->Flush();
VLOG(3) << "writer auto flush";
}
delete writer;
};
std::unique_ptr<framework::ChannelWriter<DataItem>, decltype(deleter)> writer(new framework::ChannelWriter<DataItem>(data_channel.get()), deleter);
DataItem data_item; DataItem data_item;
const int file_list_size = file_list.size();
int file_list_size = file_list.size();
std::atomic<bool> is_failed(false); std::atomic<bool> is_failed(false);
const int max_threads = omp_get_max_threads();
std::vector<framework::ChannelWriter<DataItem>> writers; // writer is not thread safe
writers.reserve(max_threads);
for (int i = 0; i < max_threads; ++i) {
writers.emplace_back(data_channel.get());
}
VLOG(5) << "file_list: " << string::join_strings(file_list, ' ');
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < file_list_size; ++i) { for (int i = 0; i < file_list_size; ++i) {
if (is_failed) {
continue;
}
const int thread_num = omp_get_thread_num();
framework::ChannelWriter<DataItem> *writer = nullptr;
if (thread_num < max_threads) {
writer = &writers[thread_num];
}
const auto& filepath = file_list[i]; const auto& filepath = file_list[i];
if (!is_failed) { std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd); if (fin == nullptr) {
if (fin == nullptr) { VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd; is_failed = true;
is_failed = true; continue;
}
char *buffer = nullptr;
size_t buffer_size = 0;
ssize_t line_len = 0;
while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
// 去掉行位回车
if (line_len > 0 && buffer[line_len - 1] == '\n') {
buffer[--line_len] = '\0';
}
// 忽略空行
if (line_len <= 0) {
continue; continue;
} }
char *buffer = nullptr; if (_parser->parse(buffer, line_len, data_item) == 0) {
size_t buffer_size = 0; VLOG(5) << "parse data: " << data_item.id << " " << data_item.data << ", filename: " << filepath << ", thread_num: " << thread_num << ", max_threads: " << max_threads;
ssize_t line_len = 0; if (writer == nullptr) {
while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) { if (!data_channel->Put(std::move(data_item))) {
if (line_len > 0 && buffer[line_len - 1] == '\n') { VLOG(2) << "fail to put data, thread_num: " << thread_num;
buffer[--line_len] = '\0'; }
} } else {
if (line_len <= 0) {
continue;
}
if (_parser->parse(buffer, line_len, data_item) == 0) {
(*writer) << std::move(data_item); (*writer) << std::move(data_item);
} }
} }
if (buffer != nullptr) { }
free(buffer); if (buffer != nullptr) {
buffer = nullptr; free(buffer);
buffer_size = 0; buffer = nullptr;
} buffer_size = 0;
if (ferror(fin.get()) != 0) { }
VLOG(2) << "fail to read file: " << filepath; if (ferror(fin.get()) != 0) {
is_failed = true; VLOG(2) << "fail to read file: " << filepath;
continue; is_failed = true;
} continue;
} }
if (_file_system->err_no() != 0) { if (_file_system->err_no() != 0) {
_file_system->reset_err_no(); _file_system->reset_err_no();
...@@ -185,10 +198,14 @@ public: ...@@ -185,10 +198,14 @@ public:
continue; continue;
} }
} }
writer->Flush(); // omp end
if (!(*writer)) {
VLOG(2) << "fail when write to channel"; for (int i = 0; i < max_threads; ++i) {
is_failed = true; writers[i].Flush();
if (!writers[i]) {
VLOG(2) << "writer " << i << " is failed";
is_failed = true;
}
} }
data_channel->Close(); data_channel->Close();
return is_failed ? -1 : 0; return is_failed ? -1 : 0;
......
...@@ -38,6 +38,9 @@ class DataReaderOmpTest : public testing::Test { ...@@ -38,6 +38,9 @@ class DataReaderOmpTest : public testing::Test {
public: public:
static void SetUpTestCase() { static void SetUpTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
if (fs->exists(test_data_dir)) {
fs->remove(test_data_dir);
}
fs->mkdir(test_data_dir); fs->mkdir(test_data_dir);
shell_set_verbose(true); shell_set_verbose(true);
std_items.clear(); std_items.clear();
...@@ -92,11 +95,12 @@ public: ...@@ -92,11 +95,12 @@ public:
} }
static void read_all(framework::Channel<DataItem>& channel, std::vector<DataItem>& items) { static void read_all(framework::Channel<DataItem>& channel, std::vector<DataItem>& items) {
framework::ChannelReader<DataItem> reader(channel.get()); channel->ReadAll(items);
DataItem data_item; // framework::ChannelReader<DataItem> reader(channel.get());
while (reader >> data_item) { // DataItem data_item;
items.push_back(std::move(data_item)); // while (reader >> data_item) {
} // items.push_back(std::move(data_item));
// }
} }
static bool is_same_with_std_items(const std::vector<DataItem>& items) { static bool is_same_with_std_items(const std::vector<DataItem>& items) {
...@@ -107,6 +111,14 @@ public: ...@@ -107,6 +111,14 @@ public:
return is_same(items, sorted_std_items); return is_same(items, sorted_std_items);
} }
static std::string to_string(const std::vector<DataItem>& items) {
std::string items_str = "";
for (const auto& item : items) {
items_str.append(item.id);
}
return items_str;
}
static std::vector<DataItem> std_items; static std::vector<DataItem> std_items;
static std::vector<DataItem> sorted_std_items; static std::vector<DataItem> sorted_std_items;
std::shared_ptr<TrainerContext> context_ptr; std::shared_ptr<TrainerContext> context_ptr;
...@@ -137,7 +149,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { ...@@ -137,7 +149,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir, std_items[i].id.c_str()), data_file_list[i]); ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir, std_items[i].id.c_str()), data_file_list[i]);
} }
int same_count = 0;
for (int i = 0; i < n_run; ++i) { for (int i = 0; i < n_run; ++i) {
auto channel = framework::MakeChannel<DataItem>(128); auto channel = framework::MakeChannel<DataItem>(128);
ASSERT_NE(nullptr, channel); ASSERT_NE(nullptr, channel);
...@@ -146,13 +157,8 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { ...@@ -146,13 +157,8 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
std::vector<DataItem> items; std::vector<DataItem> items;
read_all(channel, items); read_all(channel, items);
if (is_same_with_std_items(items)) { ASSERT_TRUE(is_same_with_std_items(items));
++same_count;
}
} }
// n_run 次都相同
ASSERT_EQ(n_run, same_count);
} }
TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) { TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
...@@ -188,36 +194,32 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) { ...@@ -188,36 +194,32 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
omp_set_num_threads(4); omp_set_num_threads(4);
channel->SetBlockSize(1);
ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel)); ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
std::vector<DataItem> items; std::vector<DataItem> items;
read_all(channel, items); read_all(channel, items);
ASSERT_EQ(std_items_size, items.size());
if (is_same_with_std_items(items)) { if (is_same_with_std_items(items)) {
++same_count; ++same_count;
} }
VLOG(5) << "before sort items: " << to_string(items);
std::sort(items.begin(), items.end(), [] (const DataItem& a, const DataItem& b) { std::sort(items.begin(), items.end(), [] (const DataItem& a, const DataItem& b) {
return a.id < b.id; return a.id < b.id;
}); });
bool is_same_with_std = is_same_with_sorted_std_items(items);
if (is_same_with_sorted_std_items(items)) { if (!is_same_with_std) {
++sort_same_count; VLOG(5) << "after sort items: " << to_string(items);
} else {
std::string items_str = "";
for (const auto& item: items) {
items_str.append(item.id);
}
VLOG(2) << "items: " << items_str;
} }
// 排序后都是相同的
ASSERT_TRUE(is_same_with_std);
} }
// n_run次有不同的(证明是多线程) // n_run次有不同的(证明是多线程)
ASSERT_EQ(4, omp_get_max_threads()); ASSERT_EQ(4, omp_get_max_threads());
ASSERT_GT(n_run, same_count); ASSERT_GT(n_run, same_count);
// 但排序后都是相同的
ASSERT_EQ(n_run, sort_same_count);
} }
} // namespace feed } // namespace feed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册