// dataset.h
// Committed by: xiexionghang
#pragma once
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <yaml-cpp/yaml.h>

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

class Dataset {
public:
    Dataset() {}
    virtual ~Dataset() {}
    
    virtual int initialize(
        const YAML::Node& config, std::shared_ptr<TrainerContext> context);

    //触发可预取的数据判断
X
xiexionghang 已提交
25
    virtual void pre_detect_data(uint64_t epoch_id);
X
xiexionghang 已提交
26 27 28
    virtual void pre_detect_data(const std::string& data_name, uint64_t epoch_id);

    //获取数据状态
X
xiexionghang 已提交
29
    virtual DatasetStatus epoch_data_status(uint64_t epoch_id);
X
xiexionghang 已提交
30 31
    virtual DatasetStatus epoch_data_status(const std::string& data_name, uint64_t epoch_id);

X
xiexionghang 已提交
32 33 34 35
    //获取数据路径
    virtual std::vector<std::string> epoch_data_path(uint64_t epoch_id);
    virtual std::vector<std::string> epoch_data_path(const std::string& data_name, uint64_t epoch_id);

X
xiexionghang 已提交
36 37 38 39
    //返回各DataContainer内的原始数据(maybe 压缩格式)
    virtual ::paddle::framework::Channel<DataItem> fetch_data(
            const std::string& data_name, uint64_t epoch_id);

X
xiexionghang 已提交
40 41 42
    //获取DataItem解析器
    virtual const DataParser* data_parser(const std::string& data_name);
    
X
xiexionghang 已提交
43 44 45 46 47 48 49
private: 
    std::unordered_map<std::string, std::shared_ptr<DatasetContainer>> _data_containers;
};

} // namespace feed
} // namespace custom_trainer
} // namespace paddle