dataset.cc 2.7 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

// Creates one DatasetContainer for every entry of the `data_list` map found in
// `config` and stores it in _data_containers under that entry's key.
//
// config:  dataset configuration; expected to hold a `data_list` map whose
//          keys are dataset names and whose values are passed to
//          DatasetContainer::initialize().
// context: forwarded unchanged to every container's initialize().
//
// Returns 0 when every container initialized, -1 when `data_list` is missing
// or not a map, or when any container's initialize() reports non-zero.
// An entry whose name was already registered replaces the earlier container.
int Dataset::initialize(
    const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    // Look the node up once and reuse it for both the check and the loop.
    const YAML::Node data_list = config["data_list"];
    // IsDefined() must come first: for a key that is absent from a const node
    // yaml-cpp returns an invalid node, and calling Type() on such a node
    // throws (YAML::InvalidNode) instead of reaching the diagnostic below.
    if (!data_list.IsDefined() || data_list.Type() != YAML::NodeType::Map) {
        LOG(FATAL) << "miss data_list config in dataset, or type error please check";
        return -1;
    }
    for (const auto& data_config : data_list) {
        const std::string name = data_config.first.as<std::string>();
        auto data_ptr = std::make_shared<DatasetContainer>();
        if (data_ptr->initialize(data_config.second, context) != 0) {
            LOG(FATAL) << "dataset initialize failed, name:" << name;
            return -1;
        }
        _data_containers[name] = data_ptr;
    }
    return 0;
}

X
xiexionghang 已提交
25 26 27 28 29 30
// Calls pre_detect_data(epoch_id) on every container held in _data_containers.
//
// Deliberately not `inline`: this definition lives in a .cc file, and an
// inline function has to be defined in every translation unit that uses it,
// so callers in other files could otherwise fail to link.
void Dataset::pre_detect_data(uint64_t epoch_id) {
    for (auto& entry : _data_containers) {
        entry.second->pre_detect_data(epoch_id);
    }
}
X
xiexionghang 已提交
31 32 33 34 35 36
// Calls pre_detect_data(epoch_id) on the container registered as `data_name`.
//
// The lookup uses at() rather than operator[]: operator[] would insert a null
// entry for an unknown name and then dereference it. Assuming
// _data_containers is a standard associative container (it is declared in the
// header), at() throws std::out_of_range for an unknown name instead.
//
// Deliberately not `inline`: the definition lives in a .cc file, so an inline
// definition would not be reachable from other translation units.
void Dataset::pre_detect_data(
    const std::string& data_name, uint64_t epoch_id) {
    _data_containers.at(data_name)->pre_detect_data(epoch_id);
}

X
xiexionghang 已提交
37 38 39 40 41 42 43 44 45
// Returns the combined status of all containers for `epoch_id`: the status
// whose integer value is the smallest among the containers, starting from
// DatasetStatus::Ready (which is therefore also the result when no container
// is registered or none reports a smaller value).
//
// Deliberately not `inline`: the definition lives in a .cc file, so an inline
// definition would not be reachable from other translation units.
DatasetStatus Dataset::epoch_data_status(uint64_t epoch_id) {
    int lowest = static_cast<int>(DatasetStatus::Ready);
    for (auto& entry : _data_containers) {
        const int current =
            static_cast<int>(entry.second->epoch_data_status(epoch_id));
        if (current < lowest) {
            lowest = current;
        }
    }
    return static_cast<DatasetStatus>(lowest);
}

X
xiexionghang 已提交
46 47 48 49 50
// Returns the status reported for `epoch_id` by the container registered as
// `data_name`.
//
// The lookup uses at() rather than operator[]: operator[] would insert a null
// entry for an unknown name and then dereference it. Assuming
// _data_containers is a standard associative container (it is declared in the
// header), at() throws std::out_of_range for an unknown name instead.
//
// Deliberately not `inline`: the definition lives in a .cc file, so an inline
// definition would not be reachable from other translation units.
DatasetStatus Dataset::epoch_data_status(
    const std::string& data_name, uint64_t epoch_id) {
    return _data_containers.at(data_name)->epoch_data_status(epoch_id);
}

X
xiexionghang 已提交
51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
// Returns the paths that the container registered as `data_name` reports for
// `epoch_id`.
//
// The lookup uses at() rather than operator[]: operator[] would insert a null
// entry for an unknown name and then dereference it. Assuming
// _data_containers is a standard associative container (it is declared in the
// header), at() throws std::out_of_range for an unknown name instead.
//
// Deliberately not `inline`: the definition lives in a .cc file, so an inline
// definition would not be reachable from other translation units.
std::vector<std::string> Dataset::epoch_data_path(
    const std::string& data_name, uint64_t epoch_id) {
    return _data_containers.at(data_name)->epoch_data_path(epoch_id);
}

// Returns the paths of all containers for `epoch_id`, concatenated in the
// iteration order of _data_containers.
//
// Deliberately not `inline`: the definition lives in a .cc file, so an inline
// definition would not be reachable from other translation units.
std::vector<std::string> Dataset::epoch_data_path(uint64_t epoch_id) {
    std::vector<std::string> results;
    for (auto& entry : _data_containers) {
        // `paths` is a local copy/temporary owned by this loop iteration, so
        // its strings can be moved into the result instead of copied.
        auto paths = entry.second->epoch_data_path(epoch_id);
        for (auto& path : paths) {
            results.emplace_back(std::move(path));
        }
    }
    return results;
}

X
xiexionghang 已提交
67 68 69 70 71
// Returns whatever fetch(epoch_id) yields on the container registered as
// `data_name`.
//
// The lookup uses at() rather than operator[]: operator[] would insert a null
// entry for an unknown name and then dereference it. Assuming
// _data_containers is a standard associative container (it is declared in the
// header), at() throws std::out_of_range for an unknown name instead.
//
// Deliberately not `inline`: the definition lives in a .cc file, so an inline
// definition would not be reachable from other translation units.
::paddle::framework::Channel<DataItem> Dataset::fetch_data(
    const std::string& data_name, uint64_t epoch_id) {
    return _data_containers.at(data_name)->fetch(epoch_id);
}

X
xiexionghang 已提交
72
// Returns the parser exposed by the container registered as `data_name`
// (the result of that container's data_parser()).
//
// The lookup uses at() rather than operator[]: operator[] would insert a null
// entry for an unknown name and then dereference it. Assuming
// _data_containers is a standard associative container (it is declared in the
// header), at() throws std::out_of_range for an unknown name instead.
//
// Deliberately not `inline`: the definition lives in a .cc file, so an inline
// definition would not be reachable from other translation units.
const DataParser* Dataset::data_parser(const std::string& data_name) {
    auto* data_container = _data_containers.at(data_name).get();
    return data_container->data_parser();
}
     

} // namespace feed
} // namespace custom_trainer
} // namespace paddle