/* data_reader.h
 * DataReader: reads the specified data.
 */
#pragma once
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

class TrainerContext;

// A single feature entry: a 64-bit feature sign paired with a 16-bit slot id.
// NOTE(review): presumably feature_sign is a hashed feature key and slot_id
// identifies the slot the feature came from — confirm against the parsers.
// Neither member has a default initializer, so a default-initialized
// FeatureItem holds indeterminate values until both fields are assigned.
struct FeatureItem {
    uint64_t feature_sign;
    uint16_t slot_id;
};

// One sample in structured form; this is the output type of
// DataParser::parse_to_sample.
struct SampleInstance {
    std::string id;                     // sample identifier
    std::vector<float> lables;          // label values; "lables" (sic) is the public field name callers use
    std::vector<FeatureItem> features;  // feature entries of this sample
    std::vector<float> embedx;          // NOTE(review): presumably embedding values — confirm with producers
};

// A raw sample record: an identifier plus its (possibly compressed) payload.
// The destructor is virtual so the type can be extended and destroyed
// through a DataItem pointer.
class DataItem {
public:
    DataItem() = default;
    virtual ~DataItem() = default;

    // Sample id; can be used as the key for shuffling.
    std::string id;
    // Sample payload; may be in a compressed format.
    std::string data;
};

class DataParser {
public:
    DataParser() {}
    virtual ~DataParser() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
    virtual int parse(const std::string& str, DataItem& data) const {
R
rensilin 已提交
45
        return parse(str.c_str(), data);
X
xiexionghang 已提交
46 47
    }
    virtual int parse(const char* str, size_t len, DataItem& data) const = 0;
R
rensilin 已提交
48
    virtual int parse(const char* str, DataItem& data) const = 0;
X
xiexionghang 已提交
49 50 51 52 53 54 55 56
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;  
};
REGISTER_REGISTERER(DataParser);

class DataReader {
public:
    DataReader() {}
    virtual ~DataReader() {}
R
rensilin 已提交
57
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context);
X
xiexionghang 已提交
58 59 60 61 62 63 64
    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) = 0;
    //读取数据样本流中
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
65
protected:
X
xiexionghang 已提交
66 67
    std::shared_ptr<DataParser> _parser;//数据格式转换
    std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入
X
xiexionghang 已提交
68 69 70 71 72 73
};
REGISTER_REGISTERER(DataReader);

}//namespace feed
}//namespace custom_trainer
}//namespace paddle