/* DataReader
 * Reading of the specified data.
 */
#pragma once
#include <time.h>

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <yaml-cpp/yaml.h>

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

class TrainerContext;

/// One feature of a sample: a 64-bit feature sign and the slot it belongs to,
/// plus the public float vectors `weights` and `gradients`.
struct FeatureItem {
    std::vector<float> weights;
    std::vector<float> gradients;

public:
    /// Default-constructs an item with sign 0, slot 0 and empty vectors.
    FeatureItem() = default;

    /// Constructs an item with the given sign and slot; vectors start empty.
    FeatureItem(uint64_t sign_, uint16_t slot_) : _sign(sign_), _slot(slot_) {}

    /// Mutable / read-only access to the feature sign.
    uint64_t& sign() {
        return _sign;
    }
    const uint64_t& sign() const {
        return _sign;
    }

    /// Mutable / read-only access to the feature slot.
    uint16_t& slot() {
        return _slot;
    }
    const uint16_t& slot() const {
        return _slot;
    }

private:
    // Stored as a real uint64_t: reading a char[8] buffer through a uint64_t*
    // (as a previous version did) violates strict aliasing and is undefined behavior.
    uint64_t _sign = 0;
    uint16_t _slot = 0;
};

struct SampleInstance {
    std::string id;
X
xiexionghang 已提交
57
    std::vector<float> predicts;
X
xiexionghang 已提交
58
    std::vector<float> labels;
X
xiexionghang 已提交
59 60 61 62 63 64 65 66 67 68 69 70
    std::vector<FeatureItem> features;
    std::vector<float> embedx;
};

/// A raw sample record: an id plus an opaque data payload.
class DataItem {
public:
    DataItem() = default;
    virtual ~DataItem() = default;

    // A user-declared destructor suppresses the implicit move operations, which
    // made every "move" of a DataItem deep-copy both strings. Default them
    // explicitly, together with the copy operations so DataItem stays copyable.
    DataItem(const DataItem&) = default;
    DataItem& operator=(const DataItem&) = default;
    DataItem(DataItem&&) = default;
    DataItem& operator=(DataItem&&) = default;

    std::string id;    // sample id, can be used for shuffling
    std::string data;  // sample data, possibly in a compressed format
};

/// Shared handle to a Pipeline that takes DataItem input and yields SampleInstance output.
using SampleInstancePipe = std::shared_ptr<Pipeline<DataItem, SampleInstance>>;

/// Creates a new, default-constructed pipeline wrapped in a SampleInstancePipe.
inline SampleInstancePipe make_sample_instance_channel() {
    using PipelineType = SampleInstancePipe::element_type;
    return std::make_shared<PipelineType>();
}

class DataParser {
public:
    DataParser() {}
    virtual ~DataParser() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
    virtual int parse(const std::string& str, DataItem& data) const {
X
xiexionghang 已提交
82
        return parse(str.c_str(), str.size(), data);
X
xiexionghang 已提交
83 84 85 86
    }
    virtual int parse(const char* str, size_t len, DataItem& data) const = 0;
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;  
};
X
xiexionghang 已提交
87
REGIST_REGISTERER(DataParser);

class DataReader {
public:
    DataReader() {}
    virtual ~DataReader() {}
R
rensilin 已提交
93
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context);
X
xiexionghang 已提交
94 95
    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) = 0;
X
xiexionghang 已提交
96
    //读取dir下文件列表
R
rensilin 已提交
97
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) = 0;
X
xiexionghang 已提交
98
    //读取目录下数据到样本流中
X
xiexionghang 已提交
99
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
X
xiexionghang 已提交
100
    //读取指定文件列表的数据到样本流中
X
xiexionghang 已提交
101
    virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem> data_channel) = 0;
X
xiexionghang 已提交
102 103 104
    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
105
protected:
X
xiexionghang 已提交
106 107
    std::shared_ptr<DataParser> _parser;//数据格式转换
    std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入
X
xiexionghang 已提交
108
};
X
xiexionghang 已提交
109
REGIST_REGISTERER(DataReader);
class LineDataParser : public DataParser {
public:
    LineDataParser() {}

    virtual ~LineDataParser() {}

    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }

    virtual int parse(const char* str, size_t len, DataItem& data) const;

    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};

}//namespace feed
}//namespace custom_trainer
}//namespace paddle