data_reader.h 2.7 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8 9 10
/* DataReader
 * 对指定数据的读取
 */
#pragma once
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
X
xiexionghang 已提交
11
#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
X
xiexionghang 已提交
12
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
X
xiexionghang 已提交
13 14 15 16 17 18 19 20 21 22

namespace paddle {
namespace custom_trainer {
namespace feed {

class TrainerContext;

struct FeatureItem {
    uint64_t feature_sign;
    uint16_t slot_id;
X
xiexionghang 已提交
23 24
    std::vector<float> weights;
    std::vector<float> gradients;
X
xiexionghang 已提交
25 26 27 28
};

struct SampleInstance {
    std::string id;
X
xiexionghang 已提交
29
    std::vector<float> labels;
X
xiexionghang 已提交
30 31 32 33 34 35 36 37 38 39 40 41
    std::vector<FeatureItem> features;
    std::vector<float> embedx;
};

class DataItem {
public:
    DataItem() {}
    virtual ~DataItem() {}
    std::string id;  //样本id标识,可用于shuffle
    std::string data;//样本数据, maybe压缩格式
};

X
xiexionghang 已提交
42 43 44 45 46
typedef std::shared_ptr<Pipeline<DataItem, SampleInstance>> SampleInstancePipe;
inline SampleInstancePipe make_sample_instance_channel() {
    return std::make_shared<Pipeline<DataItem, SampleInstance>>();
}

X
xiexionghang 已提交
47 48 49 50 51 52
class DataParser {
public:
    DataParser() {}
    virtual ~DataParser() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
    virtual int parse(const std::string& str, DataItem& data) const {
R
rensilin 已提交
53
        return parse(str.c_str(), data);
X
xiexionghang 已提交
54 55
    }
    virtual int parse(const char* str, size_t len, DataItem& data) const = 0;
R
rensilin 已提交
56
    virtual int parse(const char* str, DataItem& data) const = 0;
X
xiexionghang 已提交
57 58
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;  
};
X
xiexionghang 已提交
59
REGIST_REGISTERER(DataParser);
X
xiexionghang 已提交
60 61 62 63 64

class DataReader {
public:
    DataReader() {}
    virtual ~DataReader() {}
R
rensilin 已提交
65
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context);
X
xiexionghang 已提交
66 67
    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) = 0;
X
xiexionghang 已提交
68
    //读取dir下文件列表
R
rensilin 已提交
69
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) = 0;
X
xiexionghang 已提交
70
    //读取目录下数据到样本流中
X
xiexionghang 已提交
71
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
X
xiexionghang 已提交
72
    //读取指定文件列表的数据到样本流中
X
xiexionghang 已提交
73
    virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem> data_channel) = 0;
X
xiexionghang 已提交
74 75 76
    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
77
protected:
X
xiexionghang 已提交
78 79
    std::shared_ptr<DataParser> _parser;//数据格式转换
    std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入
X
xiexionghang 已提交
80
};
X
xiexionghang 已提交
81
REGIST_REGISTERER(DataReader);
X
xiexionghang 已提交
82 83 84 85

}//namespace feed
}//namespace custom_trainer
}//namespace paddle