提交 379235f4 编写于 作者: R rensilin

file_system_ut

Change-Id: I96c0bc535a0f49a92e8b987df5cc06c5eca4758e
上级 501c9f25
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
#include <glog/logging.h> #include <glog/logging.h>
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
namespace paddle { namespace paddle {
namespace custom_trainer { namespace custom_trainer {
namespace feed { namespace feed {
class LineDataParser : public DataParser{ class LineDataParser : public DataParser {
public: public:
LineDataParser() {} LineDataParser() {}
...@@ -29,7 +29,7 @@ public: ...@@ -29,7 +29,7 @@ public:
VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len; VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
return -1; return -1;
} }
VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len; VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
data.id.assign(str, pos); data.id.assign(str, pos);
data.data.assign(str + pos + 1, len - pos - 1); data.data.assign(str + pos + 1, len - pos - 1);
if (!data.data.empty() && data.data.back() == '\n') { if (!data.data.empty() && data.data.back() == '\n') {
...@@ -47,7 +47,7 @@ public: ...@@ -47,7 +47,7 @@ public:
VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos; VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos;
return -1; return -1;
} }
VLOG(5) << "getline: " << str << " , pos: " << pos; VLOG(5) << "getline: " << str << " , pos: " << pos;
data.id.assign(str, pos); data.id.assign(str, pos);
data.data.assign(str + pos + 1); data.data.assign(str + pos + 1);
if (!data.data.empty() && data.data.back() == '\n') { if (!data.data.empty() && data.data.back() == '\n') {
...@@ -88,13 +88,30 @@ public: ...@@ -88,13 +88,30 @@ public:
_buffer_size = config["buffer_size"].as<int>(1024); _buffer_size = config["buffer_size"].as<int>(1024);
_filename_prefix = config["filename_prefix"].as<std::string>(""); _filename_prefix = config["filename_prefix"].as<std::string>("");
_buffer.reset(new char[_buffer_size]); _buffer.reset(new char[_buffer_size]);
if (config["file_system"] && config["file_system"]["class"]) {
_file_system.reset(
CREATE_CLASS(FileSystem, config["file_system"]["class"].as<std::string>()));
if (_file_system == nullptr ||
_file_system->initialize(config["file_system"], context) != 0) {
VLOG(2) << "fail to create class: "
<< config["file_system"]["class"].as<std::string>();
return -1;
}
} else {
_file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
VLOG(2) << "fail to init file system";
return -1;
}
}
return 0; return 0;
} }
//判断样本数据是否已就绪,就绪表明可以开始download //判断样本数据是否已就绪,就绪表明可以开始download
virtual bool is_data_ready(const std::string& data_dir) { virtual bool is_data_ready(const std::string& data_dir) {
auto done_file_path = framework::fs_path_join(data_dir, _done_file_name); auto done_file_path = _file_system->path_join(data_dir, _done_file_name);
if (framework::fs_exists(done_file_path)) { if (_file_system->exists(done_file_path)) {
return true; return true;
} }
return false; return false;
...@@ -102,12 +119,13 @@ public: ...@@ -102,12 +119,13 @@ public:
virtual std::vector<std::string> data_file_list(const std::string& data_dir) { virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
if (_filename_prefix.empty()) { if (_filename_prefix.empty()) {
return framework::fs_list(data_dir); return _file_system->list(data_dir);
} }
std::vector<std::string> data_files; std::vector<std::string> data_files;
for (auto& filepath : framework::fs_list(data_dir)) { for (auto& filepath : _file_system->list(data_dir)) {
auto filename = framework::fs_path_split(filepath).second; auto filename = _file_system->path_split(filepath).second;
if (filename.size() >= _filename_prefix.size() && filename.substr(0, _filename_prefix.size()) == _filename_prefix) { if (filename.size() >= _filename_prefix.size() &&
filename.substr(0, _filename_prefix.size()) == _filename_prefix) {
data_files.push_back(std::move(filepath)); data_files.push_back(std::move(filepath));
} }
} }
...@@ -116,35 +134,50 @@ public: ...@@ -116,35 +134,50 @@ public:
//读取数据样本流中 //读取数据样本流中
virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) { virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
framework::ChannelWriter<DataItem> writer(data_channel.get()); auto deleter = [](framework::ChannelWriter<DataItem> *writer) {
if (writer) {
writer->Flush();
VLOG(3) << "writer auto flush";
}
delete writer;
};
std::unique_ptr<framework::ChannelWriter<DataItem>, decltype(deleter)> writer(new framework::ChannelWriter<DataItem>(data_channel.get()), deleter);
DataItem data_item; DataItem data_item;
if (_buffer_size <= 0 || _buffer == nullptr) { if (_buffer_size <= 0 || _buffer == nullptr) {
VLOG(2) << "no buffer"; VLOG(2) << "no buffer";
return -1; return -1;
} }
for (const auto& filepath : data_file_list(data_dir)) { for (const auto& filepath : data_file_list(data_dir)) {
if (framework::fs_path_split(filepath).second == _done_file_name) { if (_file_system->path_split(filepath).second == _done_file_name) {
continue; continue;
} }
int err_no = 0; {
std::shared_ptr<FILE> fin = framework::fs_open_read(filepath, &err_no, _pipeline_cmd); std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
if (err_no != 0) { if (fin == nullptr) {
VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd; VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
return -1; return -1;
} }
while (fgets(_buffer.get(), _buffer_size, fin.get())) { while (fgets(_buffer.get(), _buffer_size, fin.get())) {
if (_parser->parse(_buffer.get(), data_item) != 0) { if (_buffer[0] == '\n') {
continue;
}
if (_parser->parse(_buffer.get(), data_item) != 0) {
return -1;
}
(*writer) << std::move(data_item);
}
if (ferror(fin.get()) != 0) {
VLOG(2) << "fail to read file: " << filepath;
return -1; return -1;
} }
writer << std::move(data_item);
} }
if (ferror(fin.get()) != 0) { if (!_file_system) {
VLOG(2) << "fail to read file: " << filepath; _file_system->reset_err_no();
return -1; return -1;
} }
} }
writer.Flush(); writer->Flush();
if (!writer) { if (!(*writer)) {
VLOG(2) << "fail when write to channel"; VLOG(2) << "fail when write to channel";
return -1; return -1;
} }
...@@ -155,14 +188,16 @@ public: ...@@ -155,14 +188,16 @@ public:
virtual const DataParser* get_parser() { virtual const DataParser* get_parser() {
return _parser.get(); return _parser.get();
} }
private: private:
std::string _done_file_name; // without data_dir std::string _done_file_name; // without data_dir
int _buffer_size = 0; int _buffer_size = 0;
std::unique_ptr<char[]> _buffer; std::unique_ptr<char[]> _buffer;
std::string _filename_prefix; std::string _filename_prefix;
std::unique_ptr<FileSystem> _file_system;
}; };
REGISTER_CLASS(DataReader, LineDataReader); REGISTER_CLASS(DataReader, LineDataReader);
}//namespace feed } // namespace feed
}//namespace custom_trainer } // namespace custom_trainer
}//namespace paddle } // namespace paddle
...@@ -29,8 +29,8 @@ class AutoFileSystem : public FileSystem { ...@@ -29,8 +29,8 @@ class AutoFileSystem : public FileSystem {
public: public:
int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) override { int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) override {
_file_system.clear(); _file_system.clear();
if (config) { if (config && config["file_systems"] && config["file_systems"].Type() == YAML::NodeType::Map) {
for (auto& prefix_fs: config) { for (auto& prefix_fs: config["file_systems"]) {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>(""))); std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>("")));
if (fs == nullptr) { if (fs == nullptr) {
VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as<std::string>(""); VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as<std::string>("");
......
...@@ -16,9 +16,11 @@ limitations under the License. */ ...@@ -16,9 +16,11 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <tuple>
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h" #include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/piece.h"
#include "glog/logging.h" #include "glog/logging.h"
namespace paddle { namespace paddle {
...@@ -31,8 +33,10 @@ public: ...@@ -31,8 +33,10 @@ public:
_buffer_size = config["buffer_size"].as<size_t>(0); _buffer_size = config["buffer_size"].as<size_t>(0);
_hdfs_command = config["hdfs_command"].as<std::string>("hadoop fs"); _hdfs_command = config["hdfs_command"].as<std::string>("hadoop fs");
_ugi.clear(); _ugi.clear();
for (const auto& prefix_ugi : config["ugi"]) { if (config["ugis"] && config["ugis"].Type() == YAML::NodeType::Map) {
_ugi.emplace(prefix_ugi.first.as<std::string>(), prefix_ugi.second.as<std::string>()); for (const auto& prefix_ugi : config["ugis"]) {
_ugi.emplace(prefix_ugi.first.as<std::string>(), prefix_ugi.second.as<std::string>());
}
} }
if (_ugi.find("default") == _ugi.end()) { if (_ugi.find("default") == _ugi.end()) {
VLOG(2) << "fail to load default ugi"; VLOG(2) << "fail to load default ugi";
...@@ -48,8 +52,7 @@ public: ...@@ -48,8 +52,7 @@ public:
cmd = string::format_string( cmd = string::format_string(
"%s -text \"%s\"", hdfs_command(path).c_str(), path.c_str()); "%s -text \"%s\"", hdfs_command(path).c_str(), path.c_str());
} else { } else {
cmd = string::format_string( cmd = string::format_string("%s -cat \"%s\"", hdfs_command(path).c_str(), path.c_str());
"%s -cat \"%s\"", hdfs_command(path).c_str(), path.c_str());
} }
bool is_pipe = true; bool is_pipe = true;
...@@ -59,7 +62,8 @@ public: ...@@ -59,7 +62,8 @@ public:
std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter) std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter)
override { override {
std::string cmd = string::format_string("%s -put - \"%s\"", hdfs_command(path).c_str(), path.c_str()); std::string cmd =
string::format_string("%s -put - \"%s\"", hdfs_command(path).c_str(), path.c_str());
bool is_pipe = true; bool is_pipe = true;
if (string::end_with(path, ".gz\"")) { if (string::end_with(path, ".gz\"")) {
...@@ -89,12 +93,8 @@ public: ...@@ -89,12 +93,8 @@ public:
if (path == "") { if (path == "") {
return {}; return {};
} }
auto paths = _split_path(path);
std::string prefix = "hdfs:";
if (string::begin_with(path, "afs:")) {
prefix = "afs:";
}
int err_no = 0; int err_no = 0;
std::vector<std::string> list; std::vector<std::string> list;
do { do {
...@@ -115,7 +115,7 @@ public: ...@@ -115,7 +115,7 @@ public:
if (line.size() != 8) { if (line.size() != 8) {
continue; continue;
} }
list.push_back(prefix + line[7]); list.push_back(_get_prefix(paths) + line[7]);
} }
} while (err_no == -1); } while (err_no == -1);
return list; return list;
...@@ -146,30 +146,60 @@ public: ...@@ -146,30 +146,60 @@ public:
return; return;
} }
shell_execute( shell_execute(string::format_string(
string::format_string("%s -mkdir %s; true", hdfs_command(path).c_str(), path.c_str())); "%s -mkdir %s; true", hdfs_command(path).c_str(), path.c_str()));
} }
std::string hdfs_command(const std::string& path) { std::string hdfs_command(const std::string& path) {
auto start_pos = path.find_first_of(':'); auto paths = _split_path(path);
auto end_pos = path.find_first_of('/'); auto it = _ugi.find(std::get<1>(paths).ToString());
if (start_pos != std::string::npos && end_pos != std::string::npos && start_pos < end_pos) { if (it != _ugi.end()) {
auto fs_path = path.substr(start_pos + 1, end_pos - start_pos - 1); return hdfs_command_with_ugi(it->second);
auto ugi_it = _ugi.find(fs_path);
if (ugi_it != _ugi.end()) {
return hdfs_command_with_ugi(ugi_it->second);
}
} }
VLOG(5) << "path: " << path << ", select default ugi"; VLOG(5) << "path: " << path << ", select default ugi";
return hdfs_command_with_ugi(_ugi["default"]); return hdfs_command_with_ugi(_ugi["default"]);
} }
std::string hdfs_command_with_ugi(std::string ugi) { std::string hdfs_command_with_ugi(std::string ugi) {
return string::format_string("%s -Dhadoop.job.ugi=\"%s\"", _hdfs_command.c_str(), ugi.c_str()); return string::format_string(
"%s -Dhadoop.job.ugi=\"%s\"", _hdfs_command.c_str(), ugi.c_str());
} }
private: private:
// Rebuilds the leading part of a path from the tuple produced by _split_path:
// element 0 alone when element 1 is empty, otherwise element 0 + "//" + element 1.
// Element 2 of the tuple is not used here.
std::string _get_prefix(const std::tuple<string::Piece, string::Piece, string::Piece>& paths) {
    const string::Piece& scheme = std::get<0>(paths);
    const string::Piece& authority = std::get<1>(paths);
    const bool has_authority = authority.len() != 0;
    return has_authority ? scheme.ToString() + "//" + authority.ToString()
                         : scheme.ToString();
}
// Splits `path` into three pieces: (scheme including ':', authority, remainder).
//   "xxx://abc.def:8756/user" -> "xxx:", "abc.def:8756", "/user"
//   "xxx:/user"               -> "xxx:", "",             "/user"
//   "xxx://abc.def:8756"      -> "xxx:", "abc.def:8756", ""
//   "other"                   -> "",     "",             "other"
// The result starts out as ("", "", path) and is only refined when a scheme is found.
std::tuple<string::Piece, string::Piece, string::Piece> _split_path(string::Piece path) {
    std::tuple<string::Piece, string::Piece, string::Piece> result{
            string::SubStr(path, 0, 0), string::SubStr(path, 0, 0), path};
    // NOTE(review): presumably Find returns Piece::npos when ':' is absent, so the +1
    // wraps to 0 and a scheme-less path takes the final else branch — confirm in piece.h.
    auto fs_pos = string::Find(path, ':', 0) + 1;
    if (path.len() > fs_pos) {
        std::get<0>(result) = string::SubStr(path, 0, fs_pos);
        path = string::SkipPrefix(path, fs_pos);
        if (string::HasPrefix(path, "//")) {
            path = string::SkipPrefix(path, 2);
            auto end_pos = string::Find(path, '/', 0);
            if (end_pos != string::Piece::npos) {
                std::get<1>(result) = string::SubStr(path, 0, end_pos);
                std::get<2>(result) = string::SkipPrefix(path, end_pos);
            } else {
                std::get<1>(result) = path;
                // No path after the authority: clear the remainder, which would
                // otherwise still hold the whole original input from initialisation.
                std::get<2>(result) = string::SubStr(path, 0, 0);
            }
        } else {
            std::get<2>(result) = path;
        }
    }
    return result;
}
size_t _buffer_size = 0; size_t _buffer_size = 0;
std::string _hdfs_command; std::string _hdfs_command;
std::unordered_map<std::string, std::string> _ugi; std::unordered_map<std::string, std::string> _ugi;
......
...@@ -19,7 +19,8 @@ limitations under the License. */ ...@@ -19,7 +19,8 @@ limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h" #include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
...@@ -31,50 +32,50 @@ namespace { ...@@ -31,50 +32,50 @@ namespace {
const char test_data_dir[] = "test_data"; const char test_data_dir[] = "test_data";
} }
class DataReaderTest : public testing::Test class DataReaderTest : public testing::Test {
{
public: public:
static void SetUpTestCase() static void SetUpTestCase() {
{ std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
framework::shell_set_verbose(true); fs->mkdir(test_data_dir);
framework::localfs_mkdir(test_data_dir); shell_set_verbose(true);
{ {
std::ofstream fout(framework::fs_path_join(test_data_dir, "a.txt")); std::ofstream fout(fs->path_join(test_data_dir, "a.txt"));
fout << "abc 123456" << std::endl; fout << "abc 123456" << std::endl;
fout << "def 234567" << std::endl; fout << "def 234567" << std::endl;
fout.close(); fout.close();
} }
{ {
std::ofstream fout(framework::fs_path_join(test_data_dir, "b.txt")); std::ofstream fout(fs->path_join(test_data_dir, "b.txt"));
fout << "ghi 345678" << std::endl; fout << "ghi 345678" << std::endl;
fout << "jkl 456789" << std::endl; fout << "jkl 456789" << std::endl;
fout.close(); fout.close();
} }
} }
static void TearDownTestCase() static void TearDownTestCase() {
{ std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
framework::localfs_remove(test_data_dir); fs->remove(test_data_dir);
} }
virtual void SetUp() virtual void SetUp() {
{ fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
context_ptr.reset(new TrainerContext()); context_ptr.reset(new TrainerContext());
} }
virtual void TearDown() virtual void TearDown() {
{ fs = nullptr;
context_ptr = nullptr; context_ptr = nullptr;
} }
std::shared_ptr<TrainerContext> context_ptr; std::shared_ptr<TrainerContext> context_ptr;
std::unique_ptr<FileSystem> fs;
}; };
TEST_F(DataReaderTest, LineDataParser) { TEST_F(DataReaderTest, LineDataParser) {
std::unique_ptr<DataParser> data_parser(CREATE_CLASS(DataParser, "LineDataParser")); std::unique_ptr<DataParser> data_parser(CREATE_CLASS(DataParser, "LineDataParser"));
ASSERT_NE(nullptr, data_parser); ASSERT_NE(nullptr, data_parser);
auto config = YAML::Load(""); auto config = YAML::Load("");
...@@ -105,11 +106,12 @@ TEST_F(DataReaderTest, LineDataReader) { ...@@ -105,11 +106,12 @@ TEST_F(DataReaderTest, LineDataReader) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader")); std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader); ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load("parser:\n" auto config = YAML::Load(
" class: LineDataParser\n" "parser:\n"
"pipeline_cmd: cat\n" " class: LineDataParser\n"
"done_file: done_file\n" "pipeline_cmd: cat\n"
"buffer_size: 128"); "done_file: done_file\n"
"buffer_size: 128");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr)); ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
auto data_file_list = data_reader->data_file_list(test_data_dir); auto data_file_list = data_reader->data_file_list(test_data_dir);
ASSERT_EQ(2, data_file_list.size()); ASSERT_EQ(2, data_file_list.size());
...@@ -117,7 +119,7 @@ TEST_F(DataReaderTest, LineDataReader) { ...@@ -117,7 +119,7 @@ TEST_F(DataReaderTest, LineDataReader) {
ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "b.txt"), data_file_list[1]); ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "b.txt"), data_file_list[1]);
ASSERT_FALSE(data_reader->is_data_ready(test_data_dir)); ASSERT_FALSE(data_reader->is_data_ready(test_data_dir));
std::ofstream fout(framework::fs_path_join(test_data_dir, "done_file")); std::ofstream fout(fs->path_join(test_data_dir, "done_file"));
fout << "done"; fout << "done";
fout.close(); fout.close();
ASSERT_TRUE(data_reader->is_data_ready(test_data_dir)); ASSERT_TRUE(data_reader->is_data_ready(test_data_dir));
...@@ -128,7 +130,7 @@ TEST_F(DataReaderTest, LineDataReader) { ...@@ -128,7 +130,7 @@ TEST_F(DataReaderTest, LineDataReader) {
framework::ChannelReader<DataItem> reader(channel.get()); framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item; DataItem data_item;
reader >> data_item; reader >> data_item;
ASSERT_TRUE(reader); ASSERT_TRUE(reader);
ASSERT_STREQ("abc", data_item.id.c_str()); ASSERT_STREQ("abc", data_item.id.c_str());
...@@ -156,23 +158,24 @@ TEST_F(DataReaderTest, LineDataReader) { ...@@ -156,23 +158,24 @@ TEST_F(DataReaderTest, LineDataReader) {
TEST_F(DataReaderTest, LineDataReader_filename_prefix) { TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader")); std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader); ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load("parser:\n" auto config = YAML::Load(
" class: LineDataParser\n" "parser:\n"
"pipeline_cmd: cat\n" " class: LineDataParser\n"
"done_file: done_file\n" "pipeline_cmd: cat\n"
"filename_prefix: a"); "done_file: done_file\n"
"filename_prefix: a");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr)); ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
auto data_file_list = data_reader->data_file_list(test_data_dir); auto data_file_list = data_reader->data_file_list(test_data_dir);
ASSERT_EQ(1, data_file_list.size()); ASSERT_EQ(1, data_file_list.size());
ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]); ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);
auto channel = framework::MakeChannel<DataItem>(128); auto channel = framework::MakeChannel<DataItem>(128);
ASSERT_NE(nullptr, channel); ASSERT_NE(nullptr, channel);
ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel)); ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
framework::ChannelReader<DataItem> reader(channel.get()); framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item; DataItem data_item;
reader >> data_item; reader >> data_item;
ASSERT_TRUE(reader); ASSERT_TRUE(reader);
ASSERT_STREQ("abc", data_item.id.c_str()); ASSERT_STREQ("abc", data_item.id.c_str());
...@@ -187,6 +190,84 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) { ...@@ -187,6 +190,84 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
ASSERT_FALSE(reader); ASSERT_FALSE(reader);
} }
// Exercises LineDataReader when a `file_system` section is present in its config.
// The same reader instance is used first on the local test_data directory and then
// on an afs:// directory.
TEST_F(DataReaderTest, LineDataReader_FileSystem) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
// The config selects AutoFileSystem and maps the 'afs:' and 'hdfs:' prefixes to the
// same HadoopFileSystem settings through a YAML anchor (&HDFS) and alias (*HDFS).
// NOTE(review): the ugi values below look like real account,password pairs committed
// in test source — confirm and move them out of the repository if so.
auto config = YAML::Load(
"parser:\n"
" class: LineDataParser\n"
"pipeline_cmd: cat\n"
"done_file: done_file\n"
"filename_prefix: a\n"
"file_system:\n"
" class: AutoFileSystem\n"
" file_systems:\n"
" 'afs:': &HDFS \n"
" class: HadoopFileSystem\n"
" hdfs_command: 'hadoop fs'\n"
" ugis:\n"
" 'default': 'feed_video,D3a0z8'\n"
" 'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'\n"
" \n"
" 'hdfs:': *HDFS\n");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
// Local directory: with filename_prefix "a" only a.txt is expected to be listed, and
// read_all must yield its two records (abc/123456, def/234567) and then end the stream.
// NOTE(review): presumably AutoFileSystem handles prefix-less paths locally — confirm.
{
auto data_file_list = data_reader->data_file_list(test_data_dir);
ASSERT_EQ(1, data_file_list.size());
ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);
auto channel = framework::MakeChannel<DataItem>(128);
ASSERT_NE(nullptr, channel);
ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item;
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("abc", data_item.id.c_str());
ASSERT_STREQ("123456", data_item.data.c_str());
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("def", data_item.id.c_str());
ASSERT_STREQ("234567", data_item.data.c_str());
reader >> data_item;
ASSERT_FALSE(reader);
}
// afs:// directory: expects the done file to exist, exactly one listed file (a.txt),
// and two records (hello/world, hello/hadoop).
// NOTE(review): this appears to need a reachable cluster, a working `hadoop fs`
// command and pre-populated remote data, so it is not hermetic — confirm whether it
// should run in the default unit-test target.
{
char test_hadoop_dir[] = "afs://xingtian.afs.baidu.com:9902/user/feed_video/user/rensilin/paddle_trainer_test_dir";
ASSERT_TRUE(data_reader->is_data_ready(test_hadoop_dir));
auto data_file_list = data_reader->data_file_list(test_hadoop_dir);
ASSERT_EQ(1, data_file_list.size());
ASSERT_EQ(string::format_string("%s/%s", test_hadoop_dir, "a.txt"), data_file_list[0]);
auto channel = framework::MakeChannel<DataItem>(128);
ASSERT_NE(nullptr, channel);
ASSERT_EQ(0, data_reader->read_all(test_hadoop_dir, channel));
framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item;
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("hello", data_item.id.c_str());
ASSERT_STREQ("world", data_item.data.c_str());
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("hello", data_item.id.c_str());
ASSERT_STREQ("hadoop", data_item.data.c_str());
reader >> data_item;
ASSERT_FALSE(reader);
}
}
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
} // namespace paddle } // namespace paddle
...@@ -19,7 +19,8 @@ limitations under the License. */ ...@@ -19,7 +19,8 @@ limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
...@@ -37,7 +38,9 @@ class SimpleExecutorTest : public testing::Test ...@@ -37,7 +38,9 @@ class SimpleExecutorTest : public testing::Test
public: public:
static void SetUpTestCase() static void SetUpTestCase()
{ {
::paddle::framework::localfs_mkdir(test_data_dir); std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir);
shell_set_verbose(true);
{ {
std::unique_ptr<paddle::framework::ProgramDesc> startup_program( std::unique_ptr<paddle::framework::ProgramDesc> startup_program(
...@@ -67,7 +70,8 @@ public: ...@@ -67,7 +70,8 @@ public:
static void TearDownTestCase() static void TearDownTestCase()
{ {
::paddle::framework::localfs_remove(test_data_dir); std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir);
} }
virtual void SetUp() virtual void SetUp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册