From fae7e884c0fea246df7358a9973a3a4979e04d26 Mon Sep 17 00:00:00 2001 From: rensilin Date: Mon, 5 Aug 2019 12:36:30 +0800 Subject: [PATCH] finish datareader ut Change-Id: I1b0544bb2b844a47d7434963c21b05c134e1da80 --- .../custom_trainer/feed/dataset/data_reader.cc | 13 ++++++++++--- .../feed/unit_test/test_datareader.cc | 13 +++++++------ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc index 5ffe20b6..932485ad 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc @@ -26,12 +26,15 @@ public: ++pos; } if (pos >= len) { - VLOG(2) << "fail to parse line" << std::string(str, len) << ", strlen: " << len; + VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len; return -1; } VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len; data.id.assign(str, pos); data.data.assign(str + pos + 1, len - pos - 1); + if (!data.data.empty() && data.data.back() == '\n') { + data.data.pop_back(); + } return 0; } @@ -41,12 +44,15 @@ public: ++pos; } if (str[pos] == '\0') { - VLOG(2) << "fail to parse line" << str << ", get '\\0' at pos: " << pos; + VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos; return -1; } VLOG(5) << "getline: " << str << " , pos: " << pos; data.id.assign(str, pos); data.data.assign(str + pos + 1); + if (!data.data.empty() && data.data.back() == '\n') { + data.data.pop_back(); + } return 0; } @@ -105,7 +111,7 @@ public: if (::paddle::framework::fs_path_split(filename).second == _done_file_name) { continue; } - int err_no; + int err_no = 0; std::shared_ptr fin = ::paddle::framework::fs_open_read(filename, &err_no, _pipeline_cmd); if (err_no != 0) { VLOG(2) << "fail to open file: " << filename << ", with cmd: " << _pipeline_cmd; @@ -127,6 +133,7 @@ public: VLOG(2) << "fail when write to channel"; return -1; } + data_channel->Close(); return 0; } diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc index 4834f851..a1114e04 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc @@ -36,17 +36,18 @@ class DataReaderTest : public testing::Test public: static void SetUpTestCase() { - ::paddle::framework::localfs_mkdir(test_data_dir); + framework::shell_set_verbose(true); + framework::localfs_mkdir(test_data_dir); { - std::ofstream fout(::paddle::framework::fs_path_join(test_data_dir, "a.txt")); + std::ofstream fout(framework::fs_path_join(test_data_dir, "a.txt")); fout << "abc 123456" << std::endl; fout << "def 234567" << std::endl; fout.close(); } { - std::ofstream fout(::paddle::framework::fs_path_join(test_data_dir, "b.txt")); + std::ofstream fout(framework::fs_path_join(test_data_dir, "b.txt")); fout << "ghi 345678" << std::endl; fout << "jkl 456789" << std::endl; fout.close(); @@ -55,7 +56,7 @@ public: static void TearDownTestCase() { - ::paddle::framework::localfs_remove(test_data_dir); + framework::localfs_remove(test_data_dir); } virtual void SetUp() @@ -118,12 +119,12 @@ TEST_F(DataReaderTest, LineDataReader) { ASSERT_EQ(0, data_reader->initialize(config, context_ptr)); ASSERT_FALSE(data_reader->is_data_ready(test_data_dir)); - std::ofstream fout(::paddle::framework::fs_path_join(test_data_dir, "done_file")); + std::ofstream fout(framework::fs_path_join(test_data_dir, "done_file")); fout << "done"; fout.close(); ASSERT_TRUE(data_reader->is_data_ready(test_data_dir)); - auto channel = ::paddle::framework::MakeChannel(); + auto channel = framework::MakeChannel(128); ASSERT_NE(nullptr, channel); ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel)); -- GitLab