diff --git a/paddle/fluid/framework/io/fs.cc b/paddle/fluid/framework/io/fs.cc index d5bc5df2565b0f25bc29f2fce37c1bd8626a0dbc..fc6e9f40b9dab44bc94f7730a2e2c21b16a5a872 100644 --- a/paddle/fluid/framework/io/fs.cc +++ b/paddle/fluid/framework/io/fs.cc @@ -149,7 +149,7 @@ std::vector localfs_list(const std::string& path) { std::shared_ptr pipe; int err_no = 0; pipe = shell_popen( - string::format_string("find %s -type f -maxdepth 1", path.c_str()), "r", + string::format_string("find %s -maxdepth 1 -type f", path.c_str()), "r", &err_no); string::LineFileReader reader; std::vector list; @@ -452,5 +452,24 @@ void fs_mkdir(const std::string& path) { LOG(FATAL) << "Not supported"; } } + +std::string fs_path_join(const std::string& dir, const std::string &path) { + if (dir.empty()) { + return path; + } + if (dir.back() == '/') { + return dir + path; + } + return dir + '/' + path; +} + +std::pair fs_path_split(const std::string &path) { + size_t pos = path.find_last_of('/'); + if (pos == std::string::npos) { + return {".", path}; + } + return {path.substr(0, pos), path.substr(pos + 1)}; +} + } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/io/fs.h b/paddle/fluid/framework/io/fs.h index 3f0174701c24cc5a3eac38d12792650bdbd9463b..3f9e787348426a29ddd639174814c5c337b455cc 100644 --- a/paddle/fluid/framework/io/fs.h +++ b/paddle/fluid/framework/io/fs.h @@ -97,5 +97,9 @@ extern std::string fs_tail(const std::string& path); extern bool fs_exists(const std::string& path); extern void fs_mkdir(const std::string& path); + +extern std::string fs_path_join(const std::string& dir, const std::string &path); + +extern std::pair fs_path_split(const std::string &path); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc new file mode 100644 index 0000000000000000000000000000000000000000..8c857ad8f7ea2c669d11820b37124ab0e96ac64d --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc @@ -0,0 +1,142 @@ +#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h" + +#include + +#include + +#include "paddle/fluid/framework/io/fs.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +class LineDataParser : public DataParser{ +public: + LineDataParser() {} + + virtual ~LineDataParser() {} + + virtual int initialize(const YAML::Node& config, std::shared_ptr context) { + return 0; + } + + virtual int parse(const char* str, size_t len, DataItem& data) const { + size_t pos = 0; + while (str[pos] != ' ') { + if (pos >= len) { + VLOG(2) << "fail to parse line, strlen: " << len; + return -1; + } + ++pos; + } + data.id.assign(str, pos); + data.data.assign(str + pos + 1, len - pos - 1); + return 0; + } + + virtual int parse(const char* str, DataItem& data) const { + size_t pos = 0; + while (str[pos] != ' ') { + if (str[pos] == '\0') { + VLOG(2) << "fail to parse line, get '\\0' at pos: " << pos; + return -1; + } + ++pos; + } + data.id.assign(str, pos); + data.data.assign(str + pos + 1); + return 0; + } + + virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const { + return 0; + } +}; +REGISTER_CLASS(DataParser, LineDataParser); + +int DataReader::initialize(const YAML::Node& config, std::shared_ptr context) { + _parser.reset(CREATE_CLASS(DataParser, config["parser"]["class"].as())); + if (_parser == nullptr) { + VLOG(2) << "fail to get parser: " << config["parser"]["class"].as(); + return -1; + } + if (_parser->initialize(config["parser"], context) != 0) { + VLOG(2) << "fail to initialize parser" << config["parser"]["class"].as(); + return -1; + } + _pipeline_cmd = config["pipeline_cmd"].as(); + return 0; +} + +class LineDataReader : public DataReader { +public: + LineDataReader() {} + virtual ~LineDataReader() {} + virtual int initialize(const YAML::Node& config, std::shared_ptr context) { + if (DataReader::initialize(config, context) != 0) { + return -1; + } + _done_file_name = config["done_file"].as(); + _buffer_size = config["buffer_size"].as(1024); + _buffer.reset(new char[_buffer_size]); + return 0; + } + + //判断样本数据是否已就绪,就绪表明可以开始download + virtual bool is_data_ready(const std::string& data_dir) { + auto done_file_path = ::paddle::framework::fs_path_join(data_dir, _done_file_name); + if (::paddle::framework::fs_exists(done_file_path)) { + return true; + } + return false; + } + + //读取数据样本流中 + virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel data_channel) { + ::paddle::framework::ChannelWriter writer(data_channel.get()); + DataItem data_item; + if (_buffer_size <= 0 || _buffer == nullptr) { + VLOG(2) << "no buffer"; + return -1; + } + for (const auto& filename : ::paddle::framework::fs_list(data_dir)) { + if (::paddle::framework::fs_path_split(filename).second == _done_file_name) { + continue; + } + int err_no; + std::shared_ptr fin = ::paddle::framework::fs_open_read(filename, &err_no, _pipeline_cmd); + if (err_no != 0) { + return -1; + } + while (fgets(_buffer.get(), _buffer_size, fin.get())) { + if (_parser->parse(_buffer.get(), data_item) != 0) { + return -1; + } + writer << std::move(data_item); + } + if (ferror(fin.get()) != 0) { + VLOG(2) << "fail to read file: " << filename; + return -1; + } + } + writer.Flush(); + if (!writer) { + VLOG(2) << "fail when write to channel"; + return -1; + } + return 0; + } + + virtual const DataParser* get_parser() { + return _parser.get(); + } +private: + std::string _done_file_name; // without data_dir + int _buffer_size = 0; + std::unique_ptr _buffer; +}; +REGISTER_CLASS(DataReader, LineDataReader); + +}//namespace feed +}//namespace custom_trainer +}//namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h index a10275dd04148f00a1319160e5a773b952ceb4a7..f2548c959359bdf3157750a513d9c58ff7867e6c 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h @@ -42,9 +42,10 @@ public: virtual ~DataParser() {} virtual int initialize(const YAML::Node& config, std::shared_ptr context) = 0; virtual int parse(const std::string& str, DataItem& data) const { - return parse(str.c_str(), str.size(), data); + return parse(str.c_str(), data); } virtual int parse(const char* str, size_t len, DataItem& data) const = 0; + virtual int parse(const char* str, DataItem& data) const = 0; virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0; }; REGISTER_REGISTERER(DataParser); @@ -53,7 +54,7 @@ class DataReader { public: DataReader() {} virtual ~DataReader() {} - virtual int initialize(const YAML::Node& config, std::shared_ptr context) = 0; + virtual int initialize(const YAML::Node& config, std::shared_ptr context); //判断样本数据是否已就绪,就绪表明可以开始download virtual bool is_data_ready(const std::string& data_dir) = 0; //读取数据样本流中 @@ -61,17 +62,12 @@ public: virtual const DataParser* get_parser() { return _parser.get(); } -private: +protected: std::shared_ptr _parser;//数据格式转换 std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入 }; REGISTER_REGISTERER(DataReader); - -//TODO -//可读取HDFS/DISK上数据的Reader,数据按行分隔 -//HDFS/DISK - FileLineReader - }//namespace feed }//namespace custom_trainer }//namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so b/paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so index dff4760f3c761a51379a5fd2821d241163a5c396..9fef632a0ba7d03d101ada2eadb97105c2b9cdd3 100755 Binary files a/paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so and b/paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so differ