提交 379235f4 编写于 作者: R rensilin

file_system_ut

Change-Id: I96c0bc535a0f49a92e8b987df5cc06c5eca4758e
上级 501c9f25
......@@ -4,13 +4,13 @@
#include <glog/logging.h>
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
class LineDataParser : public DataParser{
class LineDataParser : public DataParser {
public:
LineDataParser() {}
......@@ -29,7 +29,7 @@ public:
VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
return -1;
}
VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
data.id.assign(str, pos);
data.data.assign(str + pos + 1, len - pos - 1);
if (!data.data.empty() && data.data.back() == '\n') {
......@@ -47,7 +47,7 @@ public:
VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos;
return -1;
}
VLOG(5) << "getline: " << str << " , pos: " << pos;
VLOG(5) << "getline: " << str << " , pos: " << pos;
data.id.assign(str, pos);
data.data.assign(str + pos + 1);
if (!data.data.empty() && data.data.back() == '\n') {
......@@ -88,13 +88,30 @@ public:
_buffer_size = config["buffer_size"].as<int>(1024);
_filename_prefix = config["filename_prefix"].as<std::string>("");
_buffer.reset(new char[_buffer_size]);
if (config["file_system"] && config["file_system"]["class"]) {
_file_system.reset(
CREATE_CLASS(FileSystem, config["file_system"]["class"].as<std::string>()));
if (_file_system == nullptr ||
_file_system->initialize(config["file_system"], context) != 0) {
VLOG(2) << "fail to create class: "
<< config["file_system"]["class"].as<std::string>();
return -1;
}
} else {
_file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
VLOG(2) << "fail to init file system";
return -1;
}
}
return 0;
}
//判断样本数据是否已就绪,就绪表明可以开始download
virtual bool is_data_ready(const std::string& data_dir) {
auto done_file_path = framework::fs_path_join(data_dir, _done_file_name);
if (framework::fs_exists(done_file_path)) {
auto done_file_path = _file_system->path_join(data_dir, _done_file_name);
if (_file_system->exists(done_file_path)) {
return true;
}
return false;
......@@ -102,12 +119,13 @@ public:
virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
if (_filename_prefix.empty()) {
return framework::fs_list(data_dir);
return _file_system->list(data_dir);
}
std::vector<std::string> data_files;
for (auto& filepath : framework::fs_list(data_dir)) {
auto filename = framework::fs_path_split(filepath).second;
if (filename.size() >= _filename_prefix.size() && filename.substr(0, _filename_prefix.size()) == _filename_prefix) {
for (auto& filepath : _file_system->list(data_dir)) {
auto filename = _file_system->path_split(filepath).second;
if (filename.size() >= _filename_prefix.size() &&
filename.substr(0, _filename_prefix.size()) == _filename_prefix) {
data_files.push_back(std::move(filepath));
}
}
......@@ -116,35 +134,50 @@ public:
//读取数据样本流中
virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
framework::ChannelWriter<DataItem> writer(data_channel.get());
auto deleter = [](framework::ChannelWriter<DataItem> *writer) {
if (writer) {
writer->Flush();
VLOG(3) << "writer auto flush";
}
delete writer;
};
std::unique_ptr<framework::ChannelWriter<DataItem>, decltype(deleter)> writer(new framework::ChannelWriter<DataItem>(data_channel.get()), deleter);
DataItem data_item;
if (_buffer_size <= 0 || _buffer == nullptr) {
VLOG(2) << "no buffer";
return -1;
}
for (const auto& filepath : data_file_list(data_dir)) {
if (framework::fs_path_split(filepath).second == _done_file_name) {
if (_file_system->path_split(filepath).second == _done_file_name) {
continue;
}
int err_no = 0;
std::shared_ptr<FILE> fin = framework::fs_open_read(filepath, &err_no, _pipeline_cmd);
if (err_no != 0) {
VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
return -1;
}
while (fgets(_buffer.get(), _buffer_size, fin.get())) {
if (_parser->parse(_buffer.get(), data_item) != 0) {
{
std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
if (fin == nullptr) {
VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
return -1;
}
while (fgets(_buffer.get(), _buffer_size, fin.get())) {
if (_buffer[0] == '\n') {
continue;
}
if (_parser->parse(_buffer.get(), data_item) != 0) {
return -1;
}
(*writer) << std::move(data_item);
}
if (ferror(fin.get()) != 0) {
VLOG(2) << "fail to read file: " << filepath;
return -1;
}
writer << std::move(data_item);
}
if (ferror(fin.get()) != 0) {
VLOG(2) << "fail to read file: " << filepath;
if (!_file_system) {
_file_system->reset_err_no();
return -1;
}
}
writer.Flush();
if (!writer) {
writer->Flush();
if (!(*writer)) {
VLOG(2) << "fail when write to channel";
return -1;
}
......@@ -155,14 +188,16 @@ public:
virtual const DataParser* get_parser() {
return _parser.get();
}
private:
std::string _done_file_name; // without data_dir
std::string _done_file_name; // without data_dir
int _buffer_size = 0;
std::unique_ptr<char[]> _buffer;
std::string _filename_prefix;
std::unique_ptr<FileSystem> _file_system;
};
REGISTER_CLASS(DataReader, LineDataReader);
}//namespace feed
}//namespace custom_trainer
}//namespace paddle
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -29,8 +29,8 @@ class AutoFileSystem : public FileSystem {
public:
int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) override {
_file_system.clear();
if (config) {
for (auto& prefix_fs: config) {
if (config && config["file_systems"] && config["file_systems"].Type() == YAML::NodeType::Map) {
for (auto& prefix_fs: config["file_systems"]) {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>("")));
if (fs == nullptr) {
VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as<std::string>("");
......
......@@ -16,9 +16,11 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <tuple>
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/piece.h"
#include "glog/logging.h"
namespace paddle {
......@@ -31,8 +33,10 @@ public:
_buffer_size = config["buffer_size"].as<size_t>(0);
_hdfs_command = config["hdfs_command"].as<std::string>("hadoop fs");
_ugi.clear();
for (const auto& prefix_ugi : config["ugi"]) {
_ugi.emplace(prefix_ugi.first.as<std::string>(), prefix_ugi.second.as<std::string>());
if (config["ugis"] && config["ugis"].Type() == YAML::NodeType::Map) {
for (const auto& prefix_ugi : config["ugis"]) {
_ugi.emplace(prefix_ugi.first.as<std::string>(), prefix_ugi.second.as<std::string>());
}
}
if (_ugi.find("default") == _ugi.end()) {
VLOG(2) << "fail to load default ugi";
......@@ -48,8 +52,7 @@ public:
cmd = string::format_string(
"%s -text \"%s\"", hdfs_command(path).c_str(), path.c_str());
} else {
cmd = string::format_string(
"%s -cat \"%s\"", hdfs_command(path).c_str(), path.c_str());
cmd = string::format_string("%s -cat \"%s\"", hdfs_command(path).c_str(), path.c_str());
}
bool is_pipe = true;
......@@ -59,7 +62,8 @@ public:
std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter)
override {
std::string cmd = string::format_string("%s -put - \"%s\"", hdfs_command(path).c_str(), path.c_str());
std::string cmd =
string::format_string("%s -put - \"%s\"", hdfs_command(path).c_str(), path.c_str());
bool is_pipe = true;
if (string::end_with(path, ".gz\"")) {
......@@ -89,12 +93,8 @@ public:
if (path == "") {
return {};
}
auto paths = _split_path(path);
std::string prefix = "hdfs:";
if (string::begin_with(path, "afs:")) {
prefix = "afs:";
}
int err_no = 0;
std::vector<std::string> list;
do {
......@@ -115,7 +115,7 @@ public:
if (line.size() != 8) {
continue;
}
list.push_back(prefix + line[7]);
list.push_back(_get_prefix(paths) + line[7]);
}
} while (err_no == -1);
return list;
......@@ -146,30 +146,60 @@ public:
return;
}
shell_execute(
string::format_string("%s -mkdir %s; true", hdfs_command(path).c_str(), path.c_str()));
shell_execute(string::format_string(
"%s -mkdir %s; true", hdfs_command(path).c_str(), path.c_str()));
}
std::string hdfs_command(const std::string& path) {
auto start_pos = path.find_first_of(':');
auto end_pos = path.find_first_of('/');
if (start_pos != std::string::npos && end_pos != std::string::npos && start_pos < end_pos) {
auto fs_path = path.substr(start_pos + 1, end_pos - start_pos - 1);
auto ugi_it = _ugi.find(fs_path);
if (ugi_it != _ugi.end()) {
return hdfs_command_with_ugi(ugi_it->second);
}
auto paths = _split_path(path);
auto it = _ugi.find(std::get<1>(paths).ToString());
if (it != _ugi.end()) {
return hdfs_command_with_ugi(it->second);
}
VLOG(5) << "path: " << path << ", select default ugi";
return hdfs_command_with_ugi(_ugi["default"]);
}
std::string hdfs_command_with_ugi(std::string ugi) {
return string::format_string("%s -Dhadoop.job.ugi=\"%s\"", _hdfs_command.c_str(), ugi.c_str());
return string::format_string(
"%s -Dhadoop.job.ugi=\"%s\"", _hdfs_command.c_str(), ugi.c_str());
}
private:
std::string _get_prefix(const std::tuple<string::Piece, string::Piece, string::Piece>& paths) {
if (std::get<1>(paths).len() == 0) {
return std::get<0>(paths).ToString();
}
return std::get<0>(paths).ToString() + "//" + std::get<1>(paths).ToString();
}
std::tuple<string::Piece, string::Piece, string::Piece> _split_path(string::Piece path) {
// parse "xxx://abc.def:8756/user" as "xxx:", "abc.def:8756", "/user"
// parse "xxx:/user" as "xxx:", "", "/user"
// parse "xxx://abc.def:8756" as "xxx:", "abc.def:8756", ""
// parse "other" as "", "", "other"
std::tuple<string::Piece, string::Piece, string::Piece> result{string::SubStr(path, 0, 0), string::SubStr(path, 0, 0), path};
auto fs_pos = string::Find(path, ':', 0) + 1;
if (path.len() > fs_pos) {
std::get<0>(result) = string::SubStr(path, 0, fs_pos);
path = string::SkipPrefix(path, fs_pos);
if (string::HasPrefix(path, "//")) {
path = string::SkipPrefix(path, 2);
auto end_pos = string::Find(path, '/', 0);
if (end_pos != string::Piece::npos) {
std::get<1>(result) = string::SubStr(path, 0, end_pos);
std::get<2>(result) = string::SkipPrefix(path, end_pos);
} else {
std::get<1>(result) = path;
}
} else {
std::get<2>(result) = path;
}
}
return result;
}
size_t _buffer_size = 0;
std::string _hdfs_command;
std::unordered_map<std::string, std::string> _ugi;
......
......@@ -19,7 +19,8 @@ limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
......@@ -31,50 +32,50 @@ namespace {
const char test_data_dir[] = "test_data";
}
class DataReaderTest : public testing::Test
{
class DataReaderTest : public testing::Test {
public:
static void SetUpTestCase()
{
framework::shell_set_verbose(true);
framework::localfs_mkdir(test_data_dir);
static void SetUpTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir);
shell_set_verbose(true);
{
std::ofstream fout(framework::fs_path_join(test_data_dir, "a.txt"));
std::ofstream fout(fs->path_join(test_data_dir, "a.txt"));
fout << "abc 123456" << std::endl;
fout << "def 234567" << std::endl;
fout.close();
}
{
std::ofstream fout(framework::fs_path_join(test_data_dir, "b.txt"));
std::ofstream fout(fs->path_join(test_data_dir, "b.txt"));
fout << "ghi 345678" << std::endl;
fout << "jkl 456789" << std::endl;
fout.close();
}
}
static void TearDownTestCase()
{
framework::localfs_remove(test_data_dir);
static void TearDownTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir);
}
virtual void SetUp()
{
virtual void SetUp() {
fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
context_ptr.reset(new TrainerContext());
}
virtual void TearDown()
{
virtual void TearDown() {
fs = nullptr;
context_ptr = nullptr;
}
std::shared_ptr<TrainerContext> context_ptr;
std::unique_ptr<FileSystem> fs;
};
TEST_F(DataReaderTest, LineDataParser) {
std::unique_ptr<DataParser> data_parser(CREATE_CLASS(DataParser, "LineDataParser"));
ASSERT_NE(nullptr, data_parser);
auto config = YAML::Load("");
......@@ -105,11 +106,12 @@ TEST_F(DataReaderTest, LineDataReader) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load("parser:\n"
" class: LineDataParser\n"
"pipeline_cmd: cat\n"
"done_file: done_file\n"
"buffer_size: 128");
auto config = YAML::Load(
"parser:\n"
" class: LineDataParser\n"
"pipeline_cmd: cat\n"
"done_file: done_file\n"
"buffer_size: 128");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
auto data_file_list = data_reader->data_file_list(test_data_dir);
ASSERT_EQ(2, data_file_list.size());
......@@ -117,7 +119,7 @@ TEST_F(DataReaderTest, LineDataReader) {
ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "b.txt"), data_file_list[1]);
ASSERT_FALSE(data_reader->is_data_ready(test_data_dir));
std::ofstream fout(framework::fs_path_join(test_data_dir, "done_file"));
std::ofstream fout(fs->path_join(test_data_dir, "done_file"));
fout << "done";
fout.close();
ASSERT_TRUE(data_reader->is_data_ready(test_data_dir));
......@@ -128,7 +130,7 @@ TEST_F(DataReaderTest, LineDataReader) {
framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item;
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("abc", data_item.id.c_str());
......@@ -156,23 +158,24 @@ TEST_F(DataReaderTest, LineDataReader) {
TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load("parser:\n"
" class: LineDataParser\n"
"pipeline_cmd: cat\n"
"done_file: done_file\n"
"filename_prefix: a");
auto config = YAML::Load(
"parser:\n"
" class: LineDataParser\n"
"pipeline_cmd: cat\n"
"done_file: done_file\n"
"filename_prefix: a");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
auto data_file_list = data_reader->data_file_list(test_data_dir);
ASSERT_EQ(1, data_file_list.size());
ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);
auto channel = framework::MakeChannel<DataItem>(128);
ASSERT_NE(nullptr, channel);
ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item;
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("abc", data_item.id.c_str());
......@@ -187,6 +190,84 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
ASSERT_FALSE(reader);
}
TEST_F(DataReaderTest, LineDataReader_FileSystem) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load(
"parser:\n"
" class: LineDataParser\n"
"pipeline_cmd: cat\n"
"done_file: done_file\n"
"filename_prefix: a\n"
"file_system:\n"
" class: AutoFileSystem\n"
" file_systems:\n"
" 'afs:': &HDFS \n"
" class: HadoopFileSystem\n"
" hdfs_command: 'hadoop fs'\n"
" ugis:\n"
" 'default': 'feed_video,D3a0z8'\n"
" 'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'\n"
" \n"
" 'hdfs:': *HDFS\n");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
{
auto data_file_list = data_reader->data_file_list(test_data_dir);
ASSERT_EQ(1, data_file_list.size());
ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);
auto channel = framework::MakeChannel<DataItem>(128);
ASSERT_NE(nullptr, channel);
ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item;
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("abc", data_item.id.c_str());
ASSERT_STREQ("123456", data_item.data.c_str());
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("def", data_item.id.c_str());
ASSERT_STREQ("234567", data_item.data.c_str());
reader >> data_item;
ASSERT_FALSE(reader);
}
{
char test_hadoop_dir[] = "afs://xingtian.afs.baidu.com:9902/user/feed_video/user/rensilin/paddle_trainer_test_dir";
ASSERT_TRUE(data_reader->is_data_ready(test_hadoop_dir));
auto data_file_list = data_reader->data_file_list(test_hadoop_dir);
ASSERT_EQ(1, data_file_list.size());
ASSERT_EQ(string::format_string("%s/%s", test_hadoop_dir, "a.txt"), data_file_list[0]);
auto channel = framework::MakeChannel<DataItem>(128);
ASSERT_NE(nullptr, channel);
ASSERT_EQ(0, data_reader->read_all(test_hadoop_dir, channel));
framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item;
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("hello", data_item.id.c_str());
ASSERT_STREQ("world", data_item.data.c_str());
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("hello", data_item.id.c_str());
ASSERT_STREQ("hadoop", data_item.data.c_str());
reader >> data_item;
ASSERT_FALSE(reader);
}
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -19,7 +19,8 @@ limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
......@@ -37,7 +38,9 @@ class SimpleExecutorTest : public testing::Test
public:
static void SetUpTestCase()
{
::paddle::framework::localfs_mkdir(test_data_dir);
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir);
shell_set_verbose(true);
{
std::unique_ptr<paddle::framework::ProgramDesc> startup_program(
......@@ -67,7 +70,8 @@ public:
static void TearDownTestCase()
{
::paddle::framework::localfs_remove(test_data_dir);
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir);
}
virtual void SetUp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册