data_reader.h 2.0 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8 9 10
/* DataReader
 * Reads the specified data.
 */
#pragma once
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
X
xiexionghang 已提交
11
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
X
xiexionghang 已提交
12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76

namespace paddle {
namespace custom_trainer {
namespace feed {

class TrainerContext;

// A single feature entry: a 64-bit feature sign paired with a 16-bit slot id.
//
// Default construction zero-initialises both fields so that a freshly created
// item never carries indeterminate values. The two-argument constructor keeps
// `FeatureItem{sign, slot}` and `FeatureItem item = {sign, slot}` compiling now
// that the struct declares constructors (and is therefore not an aggregate).
struct FeatureItem {
    FeatureItem() : feature_sign(0), slot_id(0) {}
    FeatureItem(uint64_t sign, uint16_t slot) : feature_sign(sign), slot_id(slot) {}

    // NOTE(review): presumably the hashed signature identifying the feature — confirm.
    uint64_t feature_sign;
    // NOTE(review): presumably the id of the slot this feature belongs to — confirm.
    uint16_t slot_id;
};

// Structured form of one sample. This is the output-parameter type of
// DataParser::parse_to_sample().
struct SampleInstance {
    // Identifier of the sample.
    std::string id;
    // NOTE(review): "lables" looks like a misspelling of "labels"; the name is
    // part of the public interface, so it is left unchanged here.
    std::vector<float> lables;
    // Feature entries attached to the sample.
    std::vector<FeatureItem> features;
    // NOTE(review): presumably extended embedding values for the sample — confirm.
    std::vector<float> embedx;
};

// Raw sample record: an id string plus a payload string.
class DataItem {
public:
    DataItem() = default;
    virtual ~DataItem() = default;

    // Sample id; can be used for shuffling.
    std::string id;
    // Sample payload; may be stored in a compressed format.
    std::string data;
};

// Abstract interface for turning raw text into DataItem objects and DataItem
// objects into SampleInstance objects. All entry points return an int status.
// NOTE(review): the meaning of the status value (presumably 0 on success) is
// not visible in this header — confirm against implementations.
class DataParser {
public:
    DataParser() = default;
    virtual ~DataParser() = default;

    // Configures the parser from a YAML node and the shared trainer context.
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;

    // Convenience overload: forwards the string's buffer and length to the
    // (pointer, length) overload below.
    virtual int parse(const std::string& str, DataItem& data) const {
        const char* raw = str.c_str();
        const size_t raw_len = str.size();
        return parse(raw, raw_len, data);
    }

    // Parses `len` bytes starting at `str` into `data`.
    virtual int parse(const char* str, size_t len, DataItem& data) const = 0;

    // Converts an already parsed DataItem into a SampleInstance.
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;
};
// Registers DataParser with the project's registerer facility (see registerer.h).
REGISTER_REGISTERER(DataParser);

// Abstract interface for reading sample data from a data directory into a
// channel of DataItem objects.
class DataReader {
public:
    DataReader() {}
    virtual ~DataReader() {}

    // Configures the reader from a YAML node and the shared trainer context.
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;

    // Returns whether the sample data under `data_dir` is ready; ready means
    // the download may begin.
    virtual bool is_data_ready(const std::string& data_dir) = 0;

    // Reads the samples under `data_dir` into `data_channel`.
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;

    // Returns a non-owning pointer to the parser held by this reader, or
    // nullptr when no parser has been set.
    virtual const DataParser* get_parser() {
        return _parser.get();
    }

protected:
    // These members are protected rather than private: this base class offers
    // no setter, so concrete readers must be able to assign them (typically in
    // initialize()); otherwise get_parser() could only ever return nullptr.
    std::shared_ptr<DataParser> _parser;
    // NOTE(review): presumably a shell command used to convert the data while
    // reading — confirm against the concrete readers.
    std::string _data_converter_shell;
};
// Registers DataReader with the project's registerer facility (see registerer.h).
REGISTER_REGISTERER(DataReader);


// TODO: add HDFS/DISK reader implementations.

}//namespace feed
}//namespace custom_trainer
}//namespace paddle