提交 147fe8cb 编写于 作者: X xiexionghang

add dataset

上级 8b7e1ed1
/* DataReader
* 对指定数据的读取
*/
#pragma once
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
class TrainerContext;
struct FeatureItem {
uint64_t feature_sign;
uint16_t slot_id;
};
struct SampleInstance {
std::string id;
std::vector<float> lables;
std::vector<FeatureItem> features;
std::vector<float> embedx;
};
class DataItem {
public:
DataItem() {}
virtual ~DataItem() {}
std::string id; //样本id标识,可用于shuffle
std::string data;//样本数据, maybe压缩格式
};
class DataParser {
public:
DataParser() {}
virtual ~DataParser() {}
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
virtual int parse(const std::string& str, DataItem& data) const {
return parse(str.c_str(), str.size(), data);
}
virtual int parse(const char* str, size_t len, DataItem& data) const = 0;
virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;
};
REGISTER_REGISTERER(DataParser);
class DataReader {
public:
DataReader() {}
virtual ~DataReader() {}
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
//判断样本数据是否已就绪,就绪表明可以开始download
virtual bool is_data_ready(const std::string& data_dir) = 0;
//读取数据样本流中
virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
virtual const DataParser* get_parser() {
return _parser.get();
}
private:
std::shared_ptr<DataParser> _parser;
std::string _data_converter_shell;
};
REGISTER_REGISTERER(DataReader);
//TODO
//HDFS/DISK Reader
}//namespace feed
}//namespace custom_trainer
}//namespace paddle
......@@ -9,27 +9,20 @@
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
//单条样本的原始数据
class DataItem {
public:
DataItem() {}
virtual ~DataItem() {}
std::string id; //样本id标识,可用于shuffle
std::string data;//样本完整数据
};
class DatasetContainer {
public:
DatasetContainer() {}
virtual ~DatasetContainer() {}
virtual int initialize(const YAML::Node& config) {
_dataset_config = config;
//预取n轮样本数据
_prefetch_num = config["prefetch_num"].as<int>();
_data_root_path = config["root_path"].as<std::string>();
_data_path_generater = config["_data_path_generater"].as<std::string>();
......@@ -54,7 +47,7 @@ protected:
uint32_t _current_dataset_idx; //当前样本数据idx
int _current_epoch_id = -1;
int _ready_epoch_id = -1; //已下载完成的epoch_id
std::vector<std::shared_ptr<::paddle::framework::Dataset>> _dataset_list;
std::vector<std::shared_ptr<::paddle::framework::Dataset>> _dataset_list;//预取的数据列表
};
}//namespace feed
......
......@@ -123,6 +123,7 @@ int SimpleExecutor::run() {
}
return 0;
}
REGISTER_CLASS(Executor, SimpleExecutor);
} // namespace feed
} // namespace custom_trainer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册