diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc index 932485adc2ebdedad2a16bbd5a33d3ba7d245244..629af98973b93ff72218e34a973835a4879a2141 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc @@ -86,35 +86,50 @@ public: } _done_file_name = config["done_file"].as(); _buffer_size = config["buffer_size"].as(1024); + _filename_prefix = config["filename_prefix"].as(""); _buffer.reset(new char[_buffer_size]); return 0; } //判断样本数据是否已就绪,就绪表明可以开始download virtual bool is_data_ready(const std::string& data_dir) { - auto done_file_path = ::paddle::framework::fs_path_join(data_dir, _done_file_name); - if (::paddle::framework::fs_exists(done_file_path)) { + auto done_file_path = framework::fs_path_join(data_dir, _done_file_name); + if (framework::fs_exists(done_file_path)) { return true; } return false; } + virtual std::vector data_file_list(const std::string& data_dir) { + if (_filename_prefix.empty()) { + return framework::fs_list(data_dir); + } + std::vector data_files; + for (auto& filepath : framework::fs_list(data_dir)) { + auto filename = framework::fs_path_split(filepath).second; + if (filename.size() >= _filename_prefix.size() && filename.substr(0, _filename_prefix.size()) == _filename_prefix) { + data_files.push_back(std::move(filepath)); + } + } + return data_files; + } + //读取数据样本流中 - virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel data_channel) { - ::paddle::framework::ChannelWriter writer(data_channel.get()); + virtual int read_all(const std::string& data_dir, framework::Channel data_channel) { + framework::ChannelWriter writer(data_channel.get()); DataItem data_item; if (_buffer_size <= 0 || _buffer == nullptr) { VLOG(2) << "no buffer"; return -1; } - for (const auto& filename : ::paddle::framework::fs_list(data_dir)) { - if (::paddle::framework::fs_path_split(filename).second == _done_file_name) { + for (const auto& filepath : data_file_list(data_dir)) { + if (framework::fs_path_split(filepath).second == _done_file_name) { continue; } int err_no = 0; - std::shared_ptr fin = ::paddle::framework::fs_open_read(filename, &err_no, _pipeline_cmd); + std::shared_ptr fin = framework::fs_open_read(filepath, &err_no, _pipeline_cmd); if (err_no != 0) { - VLOG(2) << "fail to open file: " << filename << ", with cmd: " << _pipeline_cmd; + VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd; return -1; } while (fgets(_buffer.get(), _buffer_size, fin.get())) { @@ -124,7 +139,7 @@ public: writer << std::move(data_item); } if (ferror(fin.get()) != 0) { - VLOG(2) << "fail to read file: " << filename; + VLOG(2) << "fail to read file: " << filepath; return -1; } } @@ -144,6 +159,7 @@ private: std::string _done_file_name; // without data_dir int _buffer_size = 0; std::unique_ptr _buffer; + std::string _filename_prefix; }; REGISTER_CLASS(DataReader, LineDataReader); diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h index f2548c959359bdf3157750a513d9c58ff7867e6c..a7ab6ea0191bf1222da6cdc05aa6b80d8a829a39 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h @@ -59,6 +59,7 @@ public: virtual bool is_data_ready(const std::string& data_dir) = 0; //读取数据样本流中 virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel data_channel) = 0; + virtual std::vector data_file_list(const std::string& data_dir) = 0; virtual const DataParser* get_parser() { return _parser.get(); } diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc index a1114e042b04d48b981211e0393dbc539e513973..47807b2092f59b78abff270e1f163017949eed71 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc @@ -105,18 +105,16 @@ TEST_F(DataReaderTest, LineDataReader) { std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); - YAML::Node config = YAML::Load("parser:\n" - " class: LineDataParser\n" - "pipeline_cmd: cat\n" - "done_file: done_file"); - ASSERT_EQ(0, data_reader->initialize(config, context_ptr)); - - config = YAML::Load("parser:\n" + auto config = YAML::Load("parser:\n" " class: LineDataParser\n" "pipeline_cmd: cat\n" "done_file: done_file\n" "buffer_size: 128"); ASSERT_EQ(0, data_reader->initialize(config, context_ptr)); + auto data_file_list = data_reader->data_file_list(test_data_dir); + ASSERT_EQ(2, data_file_list.size()); + ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]); + ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "b.txt"), data_file_list[1]); ASSERT_FALSE(data_reader->is_data_ready(test_data_dir)); std::ofstream fout(framework::fs_path_join(test_data_dir, "done_file")); @@ -155,6 +153,40 @@ TEST_F(DataReaderTest, LineDataReader) { ASSERT_FALSE(reader); } +TEST_F(DataReaderTest, LineDataReader_filename_prefix) { + std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + ASSERT_NE(nullptr, data_reader); + auto config = YAML::Load("parser:\n" + " class: LineDataParser\n" + "pipeline_cmd: cat\n" + "done_file: done_file\n" + "filename_prefix: a"); + ASSERT_EQ(0, data_reader->initialize(config, context_ptr)); + auto data_file_list = data_reader->data_file_list(test_data_dir); + ASSERT_EQ(1, data_file_list.size()); + ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]); + + auto channel = framework::MakeChannel(128); + ASSERT_NE(nullptr, channel); + ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel)); + + framework::ChannelReader reader(channel.get()); + DataItem data_item; + + reader >> data_item; + ASSERT_TRUE(reader); + ASSERT_STREQ("abc", data_item.id.c_str()); + ASSERT_STREQ("123456", data_item.data.c_str()); + + reader >> data_item; + ASSERT_TRUE(reader); + ASSERT_STREQ("def", data_item.id.c_str()); + ASSERT_STREQ("234567", data_item.data.c_str()); + + reader >> data_item; + ASSERT_FALSE(reader); +} + } // namespace feed } // namespace custom_trainer } // namespace paddle