提交 17e0cb7c 编写于 作者: R rensilin

data_reader

Change-Id: Id829354b599d9d98824785be2a883480c94b9ffe
上级 1fcde8e9
...@@ -149,7 +149,7 @@ std::vector<std::string> localfs_list(const std::string& path) { ...@@ -149,7 +149,7 @@ std::vector<std::string> localfs_list(const std::string& path) {
std::shared_ptr<FILE> pipe; std::shared_ptr<FILE> pipe;
int err_no = 0; int err_no = 0;
pipe = shell_popen( pipe = shell_popen(
string::format_string("find %s -type f -maxdepth 1", path.c_str()), "r", string::format_string("find %s -maxdepth 1 -type f", path.c_str()), "r",
&err_no); &err_no);
string::LineFileReader reader; string::LineFileReader reader;
std::vector<std::string> list; std::vector<std::string> list;
...@@ -452,5 +452,24 @@ void fs_mkdir(const std::string& path) { ...@@ -452,5 +452,24 @@ void fs_mkdir(const std::string& path) {
LOG(FATAL) << "Not supported"; LOG(FATAL) << "Not supported";
} }
} }
std::string fs_path_join(const std::string& dir, const std::string &path) {
if (dir.empty()) {
return path;
}
if (dir.back() == '/') {
return dir + path;
}
return dir + '/' + path;
}
std::pair<std::string, std::string> fs_path_split(const std::string &path) {
size_t pos = path.find_last_of('/');
if (pos == std::string::npos) {
return {".", path};
}
return {path.substr(0, pos), path.substr(pos + 1)};
}
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -97,5 +97,9 @@ extern std::string fs_tail(const std::string& path); ...@@ -97,5 +97,9 @@ extern std::string fs_tail(const std::string& path);
extern bool fs_exists(const std::string& path); extern bool fs_exists(const std::string& path);
extern void fs_mkdir(const std::string& path); extern void fs_mkdir(const std::string& path);
extern std::string fs_path_join(const std::string& dir, const std::string &path);
extern std::pair<std::string, std::string> fs_path_split(const std::string &path);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include <cstdio>
#include <glog/logging.h>
#include "paddle/fluid/framework/io/fs.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
class LineDataParser : public DataParser{
public:
LineDataParser() {}
virtual ~LineDataParser() {}
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
return 0;
}
virtual int parse(const char* str, size_t len, DataItem& data) const {
size_t pos = 0;
while (str[pos] != ' ') {
if (pos >= len) {
VLOG(2) << "fail to parse line, strlen: " << len;
return -1;
}
++pos;
}
data.id.assign(str, pos);
data.data.assign(str + pos + 1, len - pos - 1);
return 0;
}
virtual int parse(const char* str, DataItem& data) const {
size_t pos = 0;
while (str[pos] != ' ') {
if (str[pos] == '\0') {
VLOG(2) << "fail to parse line, get '\\0' at pos: " << pos;
return -1;
}
++pos;
}
data.id.assign(str, pos);
data.data.assign(str + pos + 1);
return 0;
}
virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
return 0;
}
};
REGISTER_CLASS(DataParser, LineDataParser);
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
_parser.reset(CREATE_CLASS(DataParser, config["parser"]["class"].as<std::string>()));
if (_parser == nullptr) {
VLOG(2) << "fail to get parser: " << config["parser"]["class"].as<std::string>();
return -1;
}
if (_parser->initialize(config["parser"], context) != 0) {
VLOG(2) << "fail to initialize parser" << config["parser"]["class"].as<std::string>();
return -1;
}
_pipeline_cmd = config["pipeline_cmd"].as<std::string>();
return 0;
}
class LineDataReader : public DataReader {
public:
LineDataReader() {}
virtual ~LineDataReader() {}
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
if (DataReader::initialize(config, context) != 0) {
return -1;
}
_done_file_name = config["done_file"].as<std::string>();
_buffer_size = config["buffer_size"].as<int>(1024);
_buffer.reset(new char[_buffer_size]);
return 0;
}
//判断样本数据是否已就绪,就绪表明可以开始download
virtual bool is_data_ready(const std::string& data_dir) {
auto done_file_path = ::paddle::framework::fs_path_join(data_dir, _done_file_name);
if (::paddle::framework::fs_exists(done_file_path)) {
return true;
}
return false;
}
//读取数据样本流中
virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) {
::paddle::framework::ChannelWriter<DataItem> writer(data_channel.get());
DataItem data_item;
if (_buffer_size <= 0 || _buffer == nullptr) {
VLOG(2) << "no buffer";
return -1;
}
for (const auto& filename : ::paddle::framework::fs_list(data_dir)) {
if (::paddle::framework::fs_path_split(filename).second == _done_file_name) {
continue;
}
int err_no;
std::shared_ptr<FILE> fin = ::paddle::framework::fs_open_read(filename, &err_no, _pipeline_cmd);
if (err_no != 0) {
return -1;
}
while (fgets(_buffer.get(), _buffer_size, fin.get())) {
if (_parser->parse(_buffer.get(), data_item) != 0) {
return -1;
}
writer << std::move(data_item);
}
if (ferror(fin.get()) != 0) {
VLOG(2) << "fail to read file: " << filename;
return -1;
}
}
writer.Flush();
if (!writer) {
VLOG(2) << "fail when write to channel";
return -1;
}
return 0;
}
virtual const DataParser* get_parser() {
return _parser.get();
}
private:
std::string _done_file_name; // without data_dir
int _buffer_size = 0;
std::unique_ptr<char[]> _buffer;
};
REGISTER_CLASS(DataReader, LineDataReader);
}//namespace feed
}//namespace custom_trainer
}//namespace paddle
...@@ -42,9 +42,10 @@ public: ...@@ -42,9 +42,10 @@ public:
virtual ~DataParser() {} virtual ~DataParser() {}
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0; virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
virtual int parse(const std::string& str, DataItem& data) const { virtual int parse(const std::string& str, DataItem& data) const {
return parse(str.c_str(), str.size(), data); return parse(str.c_str(), data);
} }
virtual int parse(const char* str, size_t len, DataItem& data) const = 0; virtual int parse(const char* str, size_t len, DataItem& data) const = 0;
virtual int parse(const char* str, DataItem& data) const = 0;
virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0; virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;
}; };
REGISTER_REGISTERER(DataParser); REGISTER_REGISTERER(DataParser);
...@@ -53,7 +54,7 @@ class DataReader { ...@@ -53,7 +54,7 @@ class DataReader {
public: public:
DataReader() {} DataReader() {}
virtual ~DataReader() {} virtual ~DataReader() {}
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0; virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context);
//判断样本数据是否已就绪,就绪表明可以开始download //判断样本数据是否已就绪,就绪表明可以开始download
virtual bool is_data_ready(const std::string& data_dir) = 0; virtual bool is_data_ready(const std::string& data_dir) = 0;
//读取数据样本流中 //读取数据样本流中
...@@ -61,17 +62,12 @@ public: ...@@ -61,17 +62,12 @@ public:
virtual const DataParser* get_parser() { virtual const DataParser* get_parser() {
return _parser.get(); return _parser.get();
} }
private: protected:
std::shared_ptr<DataParser> _parser;//数据格式转换 std::shared_ptr<DataParser> _parser;//数据格式转换
std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入 std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入
}; };
REGISTER_REGISTERER(DataReader); REGISTER_REGISTERER(DataReader);
//TODO
//可读取HDFS/DISK上数据的Reader,数据按行分隔
//HDFS/DISK - FileLineReader
}//namespace feed }//namespace feed
}//namespace custom_trainer }//namespace custom_trainer
}//namespace paddle }//namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册