From 147fe8cbb0efbddaa42d2bc5506dba301c0dea85 Mon Sep 17 00:00:00 2001 From: xiexionghang Date: Thu, 1 Aug 2019 18:34:19 +0800 Subject: [PATCH] add dataset --- .../custom_trainer/feed/dataset/data_reader.h | 75 +++++++++++++++++++ .../feed/dataset/dataset_container.h | 13 +--- .../custom_trainer/feed/executor/executor.cc | 1 + 3 files changed, 79 insertions(+), 10 deletions(-) create mode 100644 paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h new file mode 100644 index 00000000..d7841d35 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h @@ -0,0 +1,75 @@ +/* DataReader + * 对指定数据的读取 + */ +#pragma once +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/channel.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +class TrainerContext; + +struct FeatureItem { + uint64_t feature_sign; + uint16_t slot_id; +}; + +struct SampleInstance { + std::string id; + std::vector lables; + std::vector features; + std::vector embedx; +}; + +class DataItem { +public: + DataItem() {} + virtual ~DataItem() {} + std::string id; //样本id标识,可用于shuffle + std::string data;//样本数据, maybe压缩格式 +}; + +class DataParser { +public: + DataParser() {} + virtual ~DataParser() {} + virtual int initialize(const YAML::Node& config, std::shared_ptr context) = 0; + virtual int parse(const std::string& str, DataItem& data) const { + return parse(str.c_str(), str.size(), data); + } + virtual int parse(const char* str, size_t len, DataItem& data) const = 0; + virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0; +}; +REGISTER_REGISTERER(DataParser); + +class DataReader { +public: + DataReader() {} + virtual ~DataReader() {} + virtual int initialize(const YAML::Node& config, std::shared_ptr context) = 0; + //判断样本数据是否已就绪,就绪表明可以开始download + virtual bool is_data_ready(const std::string& data_dir) = 0; + //读取数据样本流中 + virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel data_channel) = 0; + virtual const DataParser* get_parser() { + return _parser.get(); + } +private: + std::shared_ptr _parser; + std::string _data_converter_shell; +}; +REGISTER_REGISTERER(DataReader); + + +//TODO +//HDFS/DISK Reader + +}//namespace feed +}//namespace custom_trainer +}//namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h index a681db9f..a7404a82 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h +++ b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h @@ -9,27 +9,20 @@ #include #include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/data_set.h" +#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h" #include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h" namespace paddle { namespace custom_trainer { namespace feed { -//单条样本的原始数据 -class DataItem { -public: - DataItem() {} - virtual ~DataItem() {} - std::string id; //样本id标识,可用于shuffle - std::string data;//样本完整数据 -}; - class DatasetContainer { public: DatasetContainer() {} virtual ~DatasetContainer() {} virtual int initialize(const YAML::Node& config) { _dataset_config = config; + //预取n轮样本数据 _prefetch_num = config["prefetch_num"].as(); _data_root_path = config["root_path"].as(); _data_path_generater = config["_data_path_generater"].as(); @@ -54,7 +47,7 @@ protected: uint32_t _current_dataset_idx; //当前样本数据idx int _current_epoch_id = -1; int _ready_epoch_id = -1; //已下载完成的epoch_id - std::vector> _dataset_list; + std::vector> _dataset_list;//预取的数据列表 }; }//namespace feed diff --git a/paddle/fluid/train/custom_trainer/feed/executor/executor.cc b/paddle/fluid/train/custom_trainer/feed/executor/executor.cc index 782ec620..caf5eaba 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/executor/executor.cc @@ -123,6 +123,7 @@ int SimpleExecutor::run() { } return 0; } +REGISTER_CLASS(Executor, SimpleExecutor); } // namespace feed } // namespace custom_trainer -- GitLab