data_reader.h 4.0 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8
/* DataReader
 * Reading of the specified data: interfaces for data readers and data parsers.
 */
#pragma once
#include <map>
#include <string>
#include <vector>
#include <memory>
X
xiexionghang 已提交
9
#include <time.h>
X
xiexionghang 已提交
10 11
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
X
xiexionghang 已提交
12
#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
W
wangyihong01 已提交
13 14
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
X
xiexionghang 已提交
15
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
X
xiexionghang 已提交
16 17 18 19 20 21 22 23

namespace paddle {
namespace custom_trainer {
namespace feed {

class TrainerContext;

/// One feature of a sample: a 64-bit sign, the slot it belongs to, and
/// weight / gradient vectors (empty on construction; filled elsewhere).
struct FeatureItem {
    std::vector<float> weights;
    std::vector<float> gradients;
public:
    /// Creates an item whose sign and slot are both 0.
    FeatureItem() {}

    /// Creates an item with the given sign and slot.
    FeatureItem(uint64_t sign_, uint16_t slot_) : _sign(sign_), _slot(slot_) {}

    uint64_t& sign() {
        return _sign;
    }
    const uint64_t& sign() const {
        return _sign;
    }
    uint16_t& slot() {
        return _slot;
    }
    const uint16_t& slot() const {
        return _slot;
    }

private:
    // Stored as a real uint64_t. The previous char[8] buffer was accessed
    // through a uint64_t* cast, which violates the strict-aliasing rule
    // (undefined behaviour). It also had no initializer, so a
    // default-constructed item exposed an indeterminate sign. At this offset
    // (after two vectors) the member layout is the same as before.
    uint64_t _sign = 0;
    uint16_t _slot = 0;
};

struct SampleInstance {
    std::string id;
X
xiexionghang 已提交
57
    std::vector<float> predicts;
X
xiexionghang 已提交
58
    std::vector<float> labels;
X
xiexionghang 已提交
59 60 61 62 63 64 65 66 67 68 69 70
    std::vector<FeatureItem> features;
    std::vector<float> embedx;
};

/// A raw, not-yet-parsed sample record.
class DataItem {
public:
    DataItem() = default;
    virtual ~DataItem() = default;

    // Declaring a destructor suppresses the implicitly-declared move
    // operations. Without the lines below, every std::move of a DataItem
    // silently fell back to copying `data`. Defaulting all special members
    // (Rule of Five) restores cheap, noexcept moves and keeps copyability.
    DataItem(const DataItem&) = default;
    DataItem& operator=(const DataItem&) = default;
    DataItem(DataItem&&) = default;
    DataItem& operator=(DataItem&&) = default;

    std::string id;    // sample id; can be used for shuffling
    std::string data;  // sample payload, possibly in a compressed format
};

R
rensilin 已提交
71 72 73 74 75 76 77 78 79 80
// Deserialize a DataItem: reads `id`, then `data` -- the same order in which
// operator<< below writes them.
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar, DataItem& x) {
    ar >> x.id;
    ar >> x.data;
    return ar;
}

// Serialize a DataItem: writes `id`, then `data`.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar, const DataItem& x) {
    ar << x.id;
    ar << x.data;
    return ar;
}

X
xiexionghang 已提交
81 82 83 84 85
// Shared handle to a Pipeline parameterized on <DataItem, SampleInstance>
// (presumably raw records in, parsed samples out).
using SampleInstancePipe = std::shared_ptr<Pipeline<DataItem, SampleInstance>>;

// Allocate a fresh, default-constructed pipeline and return a shared handle to it.
inline SampleInstancePipe make_sample_instance_channel() {
    SampleInstancePipe pipe = std::make_shared<Pipeline<DataItem, SampleInstance>>();
    return pipe;
}

X
xiexionghang 已提交
86 87 88 89 90 91
class DataParser {
public:
    DataParser() {}
    virtual ~DataParser() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
    virtual int parse(const std::string& str, DataItem& data) const {
X
xiexionghang 已提交
92
        return parse(str.c_str(), str.size(), data);
X
xiexionghang 已提交
93 94 95 96
    }
    virtual int parse(const char* str, size_t len, DataItem& data) const = 0;
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;  
};
X
xiexionghang 已提交
97
REGIST_REGISTERER(DataParser);
X
xiexionghang 已提交
98 99 100 101 102

class DataReader {
public:
    DataReader() {}
    virtual ~DataReader() {}
R
rensilin 已提交
103
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context);
X
xiexionghang 已提交
104 105
    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) = 0;
X
xiexionghang 已提交
106
    //读取dir下文件列表
R
rensilin 已提交
107
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) = 0;
X
xiexionghang 已提交
108
    //读取目录下数据到样本流中
X
xiexionghang 已提交
109
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
X
xiexionghang 已提交
110
    //读取指定文件列表的数据到样本流中
X
xiexionghang 已提交
111
    virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem> data_channel) = 0;
X
xiexionghang 已提交
112 113 114
    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
115
protected:
X
xiexionghang 已提交
116 117
    std::shared_ptr<DataParser> _parser;//数据格式转换
    std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入
X
xiexionghang 已提交
118
};
X
xiexionghang 已提交
119
REGIST_REGISTERER(DataReader);
X
xiexionghang 已提交
120

X
xiexionghang 已提交
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
class LineDataParser : public DataParser {
public:
    LineDataParser() {}

    virtual ~LineDataParser() {}

    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }

    virtual int parse(const char* str, size_t len, DataItem& data) const;

    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};
X
xiexionghang 已提交
137 138 139 140

}//namespace feed
}//namespace custom_trainer
}//namespace paddle