data_reader.h 2.6 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8 9 10
/* DataReader
 * 对指定数据的读取
 */
#pragma once
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
X
xiexionghang 已提交
11
#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
X
xiexionghang 已提交
12
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
X
xiexionghang 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39

namespace paddle {
namespace custom_trainer {
namespace feed {

class TrainerContext;

struct FeatureItem {
    uint64_t feature_sign;
    uint16_t slot_id;
};

struct SampleInstance {
    std::string id;
    std::vector<float> lables;
    std::vector<FeatureItem> features;
    std::vector<float> embedx;
};

class DataItem {
public:
    DataItem() {}
    virtual ~DataItem() {}
    std::string id;  //样本id标识,可用于shuffle
    std::string data;//样本数据, maybe压缩格式
};

X
xiexionghang 已提交
40 41 42 43 44
typedef std::shared_ptr<Pipeline<DataItem, SampleInstance>> SampleInstancePipe;
inline SampleInstancePipe make_sample_instance_channel() {
    return std::make_shared<Pipeline<DataItem, SampleInstance>>();
}

X
xiexionghang 已提交
45 46 47 48 49 50
class DataParser {
public:
    DataParser() {}
    virtual ~DataParser() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
    virtual int parse(const std::string& str, DataItem& data) const {
R
rensilin 已提交
51
        return parse(str.c_str(), data);
X
xiexionghang 已提交
52 53
    }
    virtual int parse(const char* str, size_t len, DataItem& data) const = 0;
R
rensilin 已提交
54
    virtual int parse(const char* str, DataItem& data) const = 0;
X
xiexionghang 已提交
55 56
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;  
};
X
xiexionghang 已提交
57
REGIST_REGISTERER(DataParser);
X
xiexionghang 已提交
58 59 60 61 62

class DataReader {
public:
    DataReader() {}
    virtual ~DataReader() {}
R
rensilin 已提交
63
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context);
X
xiexionghang 已提交
64 65
    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) = 0;
X
xiexionghang 已提交
66
    //读取dir下文件列表
R
rensilin 已提交
67
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) = 0;
X
xiexionghang 已提交
68
    //读取目录下数据到样本流中
X
xiexionghang 已提交
69
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
X
xiexionghang 已提交
70
    //读取指定文件列表的数据到样本流中
X
xiexionghang 已提交
71
    virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem> data_channel) = 0;
X
xiexionghang 已提交
72 73 74
    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
75
protected:
X
xiexionghang 已提交
76 77
    std::shared_ptr<DataParser> _parser;//数据格式转换
    std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入
X
xiexionghang 已提交
78
};
X
xiexionghang 已提交
79
REGIST_REGISTERER(DataReader);
X
xiexionghang 已提交
80 81 82 83

}//namespace feed
}//namespace custom_trainer
}//namespace paddle