/* DataReader
 * Reads the specified data.
 */
#pragma once
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <yaml-cpp/yaml.h>

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

class TrainerContext;

// A single feature entry of a sample: a 64-bit feature sign paired with the
// 16-bit id of the slot it is stored under.
//
// The fixed-width types are spelled with the std:: qualifier and come from
// <cstdint>, so this header does not depend on another header pulling the
// global typedefs in transitively.
//
// Members are deliberately left without default initializers so the struct
// stays an aggregate under every language standard; value-initialize it
// (FeatureItem item{};) when zeroed fields are required.
struct FeatureItem {
    // 64-bit feature sign. NOTE(review): presumably a hashed feature id — confirm with the parser.
    std::uint64_t feature_sign;
    // Id of the slot this feature is stored under.
    std::uint16_t slot_id;
};

// A parsed sample; this is the output type filled in by
// DataParser::parse_to_sample().
struct SampleInstance {
    // Identifier of the sample. NOTE(review): presumably carried over from DataItem::id — confirm.
    std::string id;
    // Float label values. The name appears to be a misspelling of "labels";
    // it is a public member, so renaming it would break existing users.
    std::vector<float> lables;
    // Feature entries (feature sign + slot id) belonging to this sample.
    std::vector<FeatureItem> features;
    // Float values. NOTE(review): presumably extended embedding data — confirm with the producer.
    std::vector<float> embedx;
};

// One raw sample record: an id plus an opaque data payload.
// The destructor is virtual, so the type may be used as a polymorphic base.
class DataItem {
public:
    DataItem() = default;
    virtual ~DataItem() = default;

    // Sample id; can be used for shuffle.
    std::string id;
    // Sample data; may be in a compressed format.
    std::string data;
};

// Shared-ownership handle to a Pipeline instantiated with DataItem and
// SampleInstance.
using SampleInstancePipe = std::shared_ptr<Pipeline<DataItem, SampleInstance>>;

// Creates a default-constructed Pipeline<DataItem, SampleInstance> and returns
// the shared handle to it. Despite "channel" in the name, the object created
// here is a Pipeline.
inline SampleInstancePipe make_sample_instance_channel() {
    return std::make_shared<SampleInstancePipe::element_type>();
}

// Abstract parser interface: turns raw text into a DataItem, and a DataItem
// into a SampleInstance.
//
// All methods return an int status code.
// NOTE(review): presumably 0 means success — confirm against the implementations.
//
// Note for implementers: overriding only one of the parse() overloads hides
// the other one in the derived class; add `using DataParser::parse;` there if
// both must stay callable on the derived type.
class DataParser {
public:
    DataParser() = default;
    virtual ~DataParser() = default;

    // Sets the parser up from a YAML configuration node and the trainer context.
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;

    // Convenience overload: forwards the string's character buffer and its
    // size to the (pointer, length) overload and returns that call's result.
    virtual int parse(const std::string& str, DataItem& data) const {
        const char* buffer = str.c_str();
        const size_t length = str.size();
        return parse(buffer, length, data);
    }

    // Parses `len` bytes starting at `str` into `data`.
    virtual int parse(const char* str, size_t len, DataItem& data) const = 0;

    // Converts an already parsed DataItem into a SampleInstance.
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;
};
// NOTE(review): presumably declares the registry used to create DataParser
// implementations by name — see registerer.h to confirm.
REGISTER_REGISTERER(DataParser);

class DataReader {
public:
    DataReader() {}
    virtual ~DataReader() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) = 0;
X
xiexionghang 已提交
65 66 67 68 69 70
    //读取dir下文件列表
    virtual std::vector<std::string> data_file_list(const std::string& data_dir);
    //读取目录下数据到样本流中
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem>& data_channel) = 0;
    //读取指定文件列表的数据到样本流中
    virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem>& data_channel) = 0;
X
xiexionghang 已提交
71 72 73 74
    virtual const DataParser* get_parser() {
        return _parser.get();
    }
private:
X
xiexionghang 已提交
75 76
    std::shared_ptr<DataParser> _parser;//数据格式转换
    std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入
X
xiexionghang 已提交
77 78 79 80 81
};
REGISTER_REGISTERER(DataReader);


// TODO: add a Reader that can read data stored on HDFS/DISK, with records separated by line.
// HDFS/DISK - FileLineReader

}//namespace feed
}//namespace custom_trainer
}//namespace paddle