/* data_reader.h
 * DataReader: reads the specified data.
 */
#pragma once
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <yaml-cpp/yaml.h>

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

class TrainerContext;

/**
 * A single feature entry: a 64-bit sign paired with a 16-bit slot.
 *
 * The sign is stored in a raw byte array instead of a uint64_t member
 * (presumably to avoid the alignment and padding a uint64_t member would
 * impose on the struct -- TODO confirm with the owners).
 *
 * NOTE(review): sign() hands out a uint64_t reference that aliases the byte
 * array, and the array is not guaranteed to be aligned for uint64_t. The
 * reference-returning interface is kept because callers assign through it
 * (e.g. `item.sign() = x`).
 */
struct FeatureItem {
public:
    /// Creates an item whose sign and slot are both zero.
    FeatureItem() {
    }

    /// Creates an item holding the given sign and slot.
    FeatureItem(uint64_t sign_, uint16_t slot_) {
        sign() = sign_;
        slot() = slot_;
    }

    /// Mutable access to the sign stored in the byte buffer.
    uint64_t& sign() {
        return *reinterpret_cast<uint64_t*>(_sign);
    }

    /// Read-only access to the sign stored in the byte buffer.
    const uint64_t& sign() const {
        return *reinterpret_cast<const uint64_t*>(_sign);
    }

    /// Mutable access to the slot.
    uint16_t& slot() {
        return _slot;
    }

    /// Read-only access to the slot.
    const uint16_t& slot() const {
        return _slot;
    }

private:
    // Zero-initialized so a default-constructed item never exposes
    // indeterminate bytes through sign() / slot().
    char _sign[sizeof(uint64_t)] = {};
    uint16_t _slot = 0;
};

struct SampleInstance {
    std::string id;
Y
yaopenghui 已提交
54
    std::vector<float> predicts;
W
wangyihong01 已提交
55
    std::vector<float> labels;
X
xiexionghang 已提交
56 57 58 59 60 61 62 63 64 65 66 67
    std::vector<FeatureItem> features;
    std::vector<float> embedx;
};

// Raw sample record: an id string plus a data payload string.
// Has a virtual destructor, so it can be deleted through a base pointer.
class DataItem {
public:
    DataItem() = default;
    virtual ~DataItem() = default;

    // Sample id; can be used for shuffle.
    std::string id;
    // Sample data, maybe in a compressed format.
    std::string data;
};

// Shared handle to a Pipeline<DataItem, SampleInstance>.
using SampleInstancePipe = std::shared_ptr<Pipeline<DataItem, SampleInstance>>;

// Allocates a default-constructed Pipeline<DataItem, SampleInstance> and
// returns the shared handle that owns it.
inline SampleInstancePipe make_sample_instance_channel() {
    SampleInstancePipe pipe = std::make_shared<Pipeline<DataItem, SampleInstance>>();
    return pipe;
}

class DataParser {
public:
    DataParser() {}
    virtual ~DataParser() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
    virtual int parse(const std::string& str, DataItem& data) const {
R
rensilin 已提交
79
        return parse(str.c_str(), data);
X
xiexionghang 已提交
80 81
    }
    virtual int parse(const char* str, size_t len, DataItem& data) const = 0;
R
rensilin 已提交
82
    virtual int parse(const char* str, DataItem& data) const = 0;
X
xiexionghang 已提交
83 84 85 86 87 88 89 90
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;  
};
REGISTER_REGISTERER(DataParser);

class DataReader {
public:
    DataReader() {}
    virtual ~DataReader() {}
R
rensilin 已提交
91
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context);
X
xiexionghang 已提交
92 93
    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) = 0;
X
xiexionghang 已提交
94
    //读取dir下文件列表
R
rensilin 已提交
95
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) = 0;
X
xiexionghang 已提交
96
    //读取目录下数据到样本流中
X
xiexionghang 已提交
97
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
X
xiexionghang 已提交
98
    //读取指定文件列表的数据到样本流中
X
xiexionghang 已提交
99
    virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem> data_channel) = 0;
X
xiexionghang 已提交
100 101 102
    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
103
protected:
X
xiexionghang 已提交
104 105
    std::shared_ptr<DataParser> _parser;//数据格式转换
    std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入
X
xiexionghang 已提交
106 107 108 109 110 111
};
REGISTER_REGISTERER(DataReader);

}//namespace feed
}//namespace custom_trainer
}//namespace paddle