提交 d5f4d04b 编写于 作者: R rensilin

data_reader_ut

Change-Id: Ibf99267a4d5b7196832d438a4f964f2625b616c8
上级 17e0cb7c
......@@ -94,4 +94,4 @@ Application('feed_trainer', Sources('paddle/fluid/train/custom_trainer/feed/main
#feed unit test
UT_MAIN = UT_FILE('main.cc')
UTApplication('test_executor', Sources(UT_MAIN, UT_FILE('test_executor.cc'), custom_trainer_src), CppFlags(CPPFLAGS_STR), CFlags(CFLAGS_STR), CxxFlags(CXXFLAGS_STR), Libs(src_libs=['paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so']))
UTApplication('unit_test', Sources(UT_MAIN, UT_FILE('test_executor.cc'), UT_FILE('test_datareader.cc'), custom_trainer_src), CppFlags(CPPFLAGS_STR), CFlags(CFLAGS_STR), CxxFlags(CXXFLAGS_STR), Libs(src_libs=['paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so']))
......@@ -332,7 +332,7 @@ class ChannelReader {
}
if (cursor_ >= buffer_.size()) {
cursor_ = 0;
if (channel_->read(buffer_) == 0) {
if (channel_->Read(buffer_) == 0) {
failed_ = true;
return *this;
}
......
......@@ -22,13 +22,14 @@ public:
virtual int parse(const char* str, size_t len, DataItem& data) const {
size_t pos = 0;
while (str[pos] != ' ') {
while (pos < len && str[pos] != ' ') {
++pos;
}
if (pos >= len) {
VLOG(2) << "fail to parse line, strlen: " << len;
VLOG(2) << "fail to parse line" << std::string(str, len) << ", strlen: " << len;
return -1;
}
++pos;
}
VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
data.id.assign(str, pos);
data.data.assign(str + pos + 1, len - pos - 1);
return 0;
......@@ -36,13 +37,14 @@ public:
virtual int parse(const char* str, DataItem& data) const {
size_t pos = 0;
while (str[pos] != ' ') {
while (str[pos] != '\0' && str[pos] != ' ') {
++pos;
}
if (str[pos] == '\0') {
VLOG(2) << "fail to parse line, get '\\0' at pos: " << pos;
VLOG(2) << "fail to parse line" << str << ", get '\\0' at pos: " << pos;
return -1;
}
++pos;
}
VLOG(5) << "getline: " << str << " , pos: " << pos;
data.id.assign(str, pos);
data.data.assign(str + pos + 1);
return 0;
......@@ -106,6 +108,7 @@ public:
int err_no;
std::shared_ptr<FILE> fin = ::paddle::framework::fs_open_read(filename, &err_no, _pipeline_cmd);
if (err_no != 0) {
VLOG(2) << "fail to open file: " << filename << ", with cmd: " << _pipeline_cmd;
return -1;
}
while (fgets(_buffer.get(), _buffer_size, fin.get())) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <fstream>
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
namespace {
const char test_data_dir[] = "test_data";
}
class DataReaderTest : public testing::Test
{
public:
static void SetUpTestCase()
{
::paddle::framework::localfs_mkdir(test_data_dir);
{
std::ofstream fout(::paddle::framework::fs_path_join(test_data_dir, "a.txt"));
fout << "abc 123456" << std::endl;
fout << "def 234567" << std::endl;
fout.close();
}
{
std::ofstream fout(::paddle::framework::fs_path_join(test_data_dir, "b.txt"));
fout << "ghi 345678" << std::endl;
fout << "jkl 456789" << std::endl;
fout.close();
}
}
static void TearDownTestCase()
{
::paddle::framework::localfs_remove(test_data_dir);
}
virtual void SetUp()
{
context_ptr.reset(new TrainerContext());
}
virtual void TearDown()
{
context_ptr = nullptr;
}
std::shared_ptr<TrainerContext> context_ptr;
};
TEST_F(DataReaderTest, LineDataParser) {
std::unique_ptr<DataParser> data_parser(CREATE_CLASS(DataParser, "LineDataParser"));
ASSERT_NE(nullptr, data_parser);
auto config = YAML::Load("");
ASSERT_EQ(0, data_parser->initialize(config, context_ptr));
DataItem data_item;
ASSERT_NE(0, data_parser->parse(std::string("1abcd123456"), data_item));
ASSERT_EQ(0, data_parser->parse(std::string("2abc 123456"), data_item));
ASSERT_STREQ("2abc", data_item.id.c_str());
ASSERT_STREQ("123456", data_item.data.c_str());
ASSERT_NE(0, data_parser->parse("3abcd123456", data_item));
ASSERT_EQ(0, data_parser->parse("4abc 123456", data_item));
ASSERT_STREQ("4abc", data_item.id.c_str());
ASSERT_STREQ("123456", data_item.data.c_str());
ASSERT_NE(0, data_parser->parse("5abc 123456", 4, data_item));
ASSERT_EQ(0, data_parser->parse("6abc 123456", 5, data_item));
ASSERT_STREQ("6abc", data_item.id.c_str());
ASSERT_STREQ("", data_item.data.c_str());
ASSERT_EQ(0, data_parser->parse("7abc 123456", 8, data_item));
ASSERT_STREQ("7abc", data_item.id.c_str());
ASSERT_STREQ("123", data_item.data.c_str());
}
TEST_F(DataReaderTest, LineDataReader) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
YAML::Node config = YAML::Load("parser:\n"
" class: LineDataParser\n"
"pipeline_cmd: cat\n"
"done_file: done_file");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
config = YAML::Load("parser:\n"
" class: LineDataParser\n"
"pipeline_cmd: cat\n"
"done_file: done_file\n"
"buffer_size: 128");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
ASSERT_FALSE(data_reader->is_data_ready(test_data_dir));
std::ofstream fout(::paddle::framework::fs_path_join(test_data_dir, "done_file"));
fout << "done";
fout.close();
ASSERT_TRUE(data_reader->is_data_ready(test_data_dir));
auto channel = ::paddle::framework::MakeChannel<DataItem>();
ASSERT_NE(nullptr, channel);
ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item;
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("abc", data_item.id.c_str());
ASSERT_STREQ("123456", data_item.data.c_str());
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("def", data_item.id.c_str());
ASSERT_STREQ("234567", data_item.data.c_str());
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("ghi", data_item.id.c_str());
ASSERT_STREQ("345678", data_item.data.c_str());
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("jkl", data_item.id.c_str());
ASSERT_STREQ("456789", data_item.data.c_str());
reader >> data_item;
ASSERT_FALSE(reader);
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -20,14 +20,17 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
namespace {
const char test_data_dir[] = "test_data";
const char main_program_path[] = "test_data/main_program";
const char startup_program_path[] = "test_data/startup_program";
}
class SimpleExecutorTest : public testing::Test
{
......@@ -82,18 +85,20 @@ public:
TEST_F(SimpleExecutorTest, initialize) {
std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor"));
ASSERT_NE(nullptr, executor);
YAML::Node config = YAML::Load("[1, 2, 3]");
ASSERT_NE(0, executor->initialize(config, context_ptr));
config = YAML::Load(std::string() + "{startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
config = YAML::Load(string::format_string("{startup_program: %s, main_program: %s}", startup_program_path, main_program_path));
ASSERT_EQ(0, executor->initialize(config, context_ptr));
config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path));
ASSERT_EQ(0, executor->initialize(config, context_ptr));
}
TEST_F(SimpleExecutorTest, run) {
std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor"));
ASSERT_NE(nullptr, executor);
auto config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
auto config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path));
ASSERT_EQ(0, executor->initialize(config, context_ptr));
auto x_var = executor->mutable_var<::paddle::framework::LoDTensor>("x");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册