提交 db4a4db0 编写于 作者: R rensilin

fs_bug

Change-Id: I7e92af98dc56e18b79640f070b13a26f9c94ab52
上级 6a15698f
......@@ -131,23 +131,28 @@ public:
return read_all(file_list, data_channel);
}
virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) {
auto deleter = [](framework::ChannelWriter<DataItem> *writer) {
if (writer) {
writer->Flush();
VLOG(3) << "writer auto flush";
}
delete writer;
};
std::unique_ptr<framework::ChannelWriter<DataItem>, decltype(deleter)> writer(new framework::ChannelWriter<DataItem>(data_channel.get()), deleter);
DataItem data_item;
int file_list_size = file_list.size();
const int file_list_size = file_list.size();
std::atomic<bool> is_failed(false);
const int max_threads = omp_get_max_threads();
std::vector<framework::ChannelWriter<DataItem>> writers; // writer is not thread safe
writers.reserve(max_threads);
for (int i = 0; i < max_threads; ++i) {
writers.emplace_back(data_channel.get());
}
VLOG(5) << "file_list: " << string::join_strings(file_list, ' ');
#pragma omp parallel for
for (int i = 0; i < file_list_size; ++i) {
if (is_failed) {
continue;
}
const int thread_num = omp_get_thread_num();
framework::ChannelWriter<DataItem> *writer = nullptr;
if (thread_num < max_threads) {
writer = &writers[thread_num];
}
const auto& filepath = file_list[i];
if (!is_failed) {
std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
if (fin == nullptr) {
VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
......@@ -158,16 +163,25 @@ public:
size_t buffer_size = 0;
ssize_t line_len = 0;
while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
// 去掉行尾换行符
if (line_len > 0 && buffer[line_len - 1] == '\n') {
buffer[--line_len] = '\0';
}
// 忽略空行
if (line_len <= 0) {
continue;
}
if (_parser->parse(buffer, line_len, data_item) == 0) {
VLOG(5) << "parse data: " << data_item.id << " " << data_item.data << ", filename: " << filepath << ", thread_num: " << thread_num << ", max_threads: " << max_threads;
if (writer == nullptr) {
if (!data_channel->Put(std::move(data_item))) {
VLOG(2) << "fail to put data, thread_num: " << thread_num;
}
} else {
(*writer) << std::move(data_item);
}
}
}
if (buffer != nullptr) {
free(buffer);
buffer = nullptr;
......@@ -178,18 +192,21 @@ public:
is_failed = true;
continue;
}
}
if (_file_system->err_no() != 0) {
_file_system->reset_err_no();
is_failed = true;
continue;
}
}
writer->Flush();
if (!(*writer)) {
VLOG(2) << "fail when write to channel";
// omp end
for (int i = 0; i < max_threads; ++i) {
writers[i].Flush();
if (!writers[i]) {
VLOG(2) << "writer " << i << " is failed";
is_failed = true;
}
}
data_channel->Close();
return is_failed ? -1 : 0;
}
......
......@@ -38,6 +38,9 @@ class DataReaderOmpTest : public testing::Test {
public:
static void SetUpTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
if (fs->exists(test_data_dir)) {
fs->remove(test_data_dir);
}
fs->mkdir(test_data_dir);
shell_set_verbose(true);
std_items.clear();
......@@ -92,11 +95,12 @@ public:
}
static void read_all(framework::Channel<DataItem>& channel, std::vector<DataItem>& items) {
framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item;
while (reader >> data_item) {
items.push_back(std::move(data_item));
}
channel->ReadAll(items);
// framework::ChannelReader<DataItem> reader(channel.get());
// DataItem data_item;
// while (reader >> data_item) {
// items.push_back(std::move(data_item));
// }
}
static bool is_same_with_std_items(const std::vector<DataItem>& items) {
......@@ -107,6 +111,14 @@ public:
return is_same(items, sorted_std_items);
}
static std::string to_string(const std::vector<DataItem>& items) {
std::string items_str = "";
for (const auto& item : items) {
items_str.append(item.id);
}
return items_str;
}
static std::vector<DataItem> std_items;
static std::vector<DataItem> sorted_std_items;
std::shared_ptr<TrainerContext> context_ptr;
......@@ -137,7 +149,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir, std_items[i].id.c_str()), data_file_list[i]);
}
int same_count = 0;
for (int i = 0; i < n_run; ++i) {
auto channel = framework::MakeChannel<DataItem>(128);
ASSERT_NE(nullptr, channel);
......@@ -146,13 +157,8 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
std::vector<DataItem> items;
read_all(channel, items);
if (is_same_with_std_items(items)) {
++same_count;
}
ASSERT_TRUE(is_same_with_std_items(items));
}
// n_run 次都相同
ASSERT_EQ(n_run, same_count);
}
TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
......@@ -188,36 +194,32 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
omp_set_num_threads(4);
channel->SetBlockSize(1);
ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
std::vector<DataItem> items;
read_all(channel, items);
ASSERT_EQ(std_items_size, items.size());
if (is_same_with_std_items(items)) {
++same_count;
}
VLOG(5) << "before sort items: " << to_string(items);
std::sort(items.begin(), items.end(), [] (const DataItem& a, const DataItem& b) {
return a.id < b.id;
});
if (is_same_with_sorted_std_items(items)) {
++sort_same_count;
} else {
std::string items_str = "";
for (const auto& item: items) {
items_str.append(item.id);
bool is_same_with_std = is_same_with_sorted_std_items(items);
if (!is_same_with_std) {
VLOG(5) << "after sort items: " << to_string(items);
}
VLOG(2) << "items: " << items_str;
}
// 排序后都是相同的
ASSERT_TRUE(is_same_with_std);
}
// n_run次有不同的(证明是多线程)
ASSERT_EQ(4, omp_get_max_threads());
ASSERT_GT(n_run, same_count);
// 但排序后都是相同的
ASSERT_EQ(n_run, sort_same_count);
}
} // namespace feed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册