From d5f4d04b3a9d94b5a82656ea12ad2adf6bbefefe Mon Sep 17 00:00:00 2001 From: rensilin Date: Fri, 2 Aug 2019 20:55:34 +0800 Subject: [PATCH] data_reader_ut Change-Id: Ibf99267a4d5b7196832d438a4f964f2625b616c8 --- BCLOUD | 2 +- paddle/fluid/framework/channel.h | 2 +- .../feed/dataset/data_reader.cc | 23 +-- .../feed/unit_test/test_datareader.cc | 159 ++++++++++++++++++ .../feed/unit_test/test_executor.cc | 11 +- 5 files changed, 182 insertions(+), 15 deletions(-) create mode 100644 paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc diff --git a/BCLOUD b/BCLOUD index 4162dcbe..61cdb29b 100644 --- a/BCLOUD +++ b/BCLOUD @@ -94,4 +94,4 @@ Application('feed_trainer', Sources('paddle/fluid/train/custom_trainer/feed/main #feed unit test UT_MAIN = UT_FILE('main.cc') -UTApplication('test_executor', Sources(UT_MAIN, UT_FILE('test_executor.cc'), custom_trainer_src), CppFlags(CPPFLAGS_STR), CFlags(CFLAGS_STR), CxxFlags(CXXFLAGS_STR), Libs(src_libs=['paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so'])) +UTApplication('unit_test', Sources(UT_MAIN, UT_FILE('test_executor.cc'), UT_FILE('test_datareader.cc'), custom_trainer_src), CppFlags(CPPFLAGS_STR), CFlags(CFLAGS_STR), CxxFlags(CXXFLAGS_STR), Libs(src_libs=['paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so'])) diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h index 644f60db..f0658ed1 100644 --- a/paddle/fluid/framework/channel.h +++ b/paddle/fluid/framework/channel.h @@ -332,7 +332,7 @@ class ChannelReader { } if (cursor_ >= buffer_.size()) { cursor_ = 0; - if (channel_->read(buffer_) == 0) { + if (channel_->Read(buffer_) == 0) { failed_ = true; return *this; } diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc index 8c857ad8..5ffe20b6 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc @@ -22,13 +22,14 @@ public: virtual int parse(const char* str, size_t len, DataItem& data) const { size_t pos = 0; - while (str[pos] != ' ') { - if (pos >= len) { - VLOG(2) << "fail to parse line, strlen: " << len; - return -1; - } + while (pos < len && str[pos] != ' ') { ++pos; } + if (pos >= len) { + VLOG(2) << "fail to parse line" << std::string(str, len) << ", strlen: " << len; + return -1; + } + VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len; data.id.assign(str, pos); data.data.assign(str + pos + 1, len - pos - 1); return 0; @@ -36,13 +37,14 @@ public: virtual int parse(const char* str, DataItem& data) const { size_t pos = 0; - while (str[pos] != ' ') { - if (str[pos] == '\0') { - VLOG(2) << "fail to parse line, get '\\0' at pos: " << pos; - return -1; - } + while (str[pos] != '\0' && str[pos] != ' ') { ++pos; } + if (str[pos] == '\0') { + VLOG(2) << "fail to parse line" << str << ", get '\\0' at pos: " << pos; + return -1; + } + VLOG(5) << "getline: " << str << " , pos: " << pos; data.id.assign(str, pos); data.data.assign(str + pos + 1); return 0; @@ -106,6 +108,7 @@ public: int err_no; std::shared_ptr fin = ::paddle::framework::fs_open_read(filename, &err_no, _pipeline_cmd); if (err_no != 0) { + VLOG(2) << "fail to open file: " << filename << ", with cmd: " << _pipeline_cmd; return -1; } while (fgets(_buffer.get(), _buffer_size, fin.get())) { diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc new file mode 100644 index 00000000..4834f851 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc @@ -0,0 +1,159 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/string/string_helper.h" +#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +namespace { +const char test_data_dir[] = "test_data"; +} + +class DataReaderTest : public testing::Test +{ +public: + static void SetUpTestCase() + { + ::paddle::framework::localfs_mkdir(test_data_dir); + + { + std::ofstream fout(::paddle::framework::fs_path_join(test_data_dir, "a.txt")); + fout << "abc 123456" << std::endl; + fout << "def 234567" << std::endl; + fout.close(); + } + + { + std::ofstream fout(::paddle::framework::fs_path_join(test_data_dir, "b.txt")); + fout << "ghi 345678" << std::endl; + fout << "jkl 456789" << std::endl; + fout.close(); + } + } + + static void TearDownTestCase() + { + ::paddle::framework::localfs_remove(test_data_dir); + } + + virtual void SetUp() + { + context_ptr.reset(new TrainerContext()); + } + + virtual void TearDown() + { + context_ptr = nullptr; + } + + std::shared_ptr context_ptr; +}; + +TEST_F(DataReaderTest, LineDataParser) { + std::unique_ptr data_parser(CREATE_CLASS(DataParser, "LineDataParser")); + + ASSERT_NE(nullptr, data_parser); + auto config = YAML::Load(""); + + ASSERT_EQ(0, data_parser->initialize(config, context_ptr)); + + DataItem data_item; + ASSERT_NE(0, data_parser->parse(std::string("1abcd123456"), data_item)); + ASSERT_EQ(0, data_parser->parse(std::string("2abc 123456"), data_item)); + ASSERT_STREQ("2abc", data_item.id.c_str()); + ASSERT_STREQ("123456", data_item.data.c_str()); + + ASSERT_NE(0, data_parser->parse("3abcd123456", data_item)); + ASSERT_EQ(0, data_parser->parse("4abc 123456", data_item)); + ASSERT_STREQ("4abc", data_item.id.c_str()); + ASSERT_STREQ("123456", data_item.data.c_str()); + + ASSERT_NE(0, data_parser->parse("5abc 123456", 4, data_item)); + ASSERT_EQ(0, data_parser->parse("6abc 123456", 5, data_item)); + ASSERT_STREQ("6abc", data_item.id.c_str()); + ASSERT_STREQ("", data_item.data.c_str()); + + ASSERT_EQ(0, data_parser->parse("7abc 123456", 8, data_item)); + ASSERT_STREQ("7abc", data_item.id.c_str()); + ASSERT_STREQ("123", data_item.data.c_str()); +} + +TEST_F(DataReaderTest, LineDataReader) { + std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + ASSERT_NE(nullptr, data_reader); + + YAML::Node config = YAML::Load("parser:\n" + " class: LineDataParser\n" + "pipeline_cmd: cat\n" + "done_file: done_file"); + ASSERT_EQ(0, data_reader->initialize(config, context_ptr)); + + config = YAML::Load("parser:\n" + " class: LineDataParser\n" + "pipeline_cmd: cat\n" + "done_file: done_file\n" + "buffer_size: 128"); + ASSERT_EQ(0, data_reader->initialize(config, context_ptr)); + + ASSERT_FALSE(data_reader->is_data_ready(test_data_dir)); + std::ofstream fout(::paddle::framework::fs_path_join(test_data_dir, "done_file")); + fout << "done"; + fout.close(); + ASSERT_TRUE(data_reader->is_data_ready(test_data_dir)); + + auto channel = ::paddle::framework::MakeChannel(); + ASSERT_NE(nullptr, channel); + ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel)); + + framework::ChannelReader reader(channel.get()); + DataItem data_item; + + reader >> data_item; + ASSERT_TRUE(reader); + ASSERT_STREQ("abc", data_item.id.c_str()); + ASSERT_STREQ("123456", data_item.data.c_str()); + + reader >> data_item; + ASSERT_TRUE(reader); + ASSERT_STREQ("def", data_item.id.c_str()); + ASSERT_STREQ("234567", data_item.data.c_str()); + + reader >> data_item; + ASSERT_TRUE(reader); + ASSERT_STREQ("ghi", data_item.id.c_str()); + ASSERT_STREQ("345678", data_item.data.c_str()); + + reader >> data_item; + ASSERT_TRUE(reader); + ASSERT_STREQ("jkl", data_item.id.c_str()); + ASSERT_STREQ("456789", data_item.data.c_str()); + + reader >> data_item; + ASSERT_FALSE(reader); +} + +} // namespace feed +} // namespace custom_trainer +} // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc index ec5cac4b..5410cb9a 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc @@ -20,14 +20,17 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/string/string_helper.h" namespace paddle { namespace custom_trainer { namespace feed { +namespace { const char test_data_dir[] = "test_data"; const char main_program_path[] = "test_data/main_program"; const char startup_program_path[] = "test_data/startup_program"; +} class SimpleExecutorTest : public testing::Test { @@ -82,18 +85,20 @@ public: TEST_F(SimpleExecutorTest, initialize) { std::unique_ptr executor(CREATE_CLASS(Executor, "SimpleExecutor")); + ASSERT_NE(nullptr, executor); YAML::Node config = YAML::Load("[1, 2, 3]"); ASSERT_NE(0, executor->initialize(config, context_ptr)); - config = YAML::Load(std::string() + "{startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}"); + config = YAML::Load(string::format_string("{startup_program: %s, main_program: %s}", startup_program_path, main_program_path)); ASSERT_EQ(0, executor->initialize(config, context_ptr)); - config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}"); + config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path)); ASSERT_EQ(0, executor->initialize(config, context_ptr)); } TEST_F(SimpleExecutorTest, run) { std::unique_ptr executor(CREATE_CLASS(Executor, "SimpleExecutor")); + ASSERT_NE(nullptr, executor); - auto config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}"); + auto config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path)); ASSERT_EQ(0, executor->initialize(config, context_ptr)); auto x_var = executor->mutable_var<::paddle::framework::LoDTensor>("x"); -- GitLab