dataset.cc 2.2 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

// Creates one DatasetContainer for every entry of the `data_list` map found in
// `config` and stores it in `_data_containers` under the entry's key.
//
// config:  YAML node expected to carry a `data_list` map; each key is used as
//          the dataset name and each value is handed to the container's
//          initialize().
// context: trainer context forwarded unchanged to every container.
//
// Returns 0 on success. Returns -1 when `data_list` is missing or is not a
// map, or as soon as one container fails to initialize (containers stored
// before the failure are left in `_data_containers`).
int Dataset::initialize(
    const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    const YAML::Node data_list = config["data_list"];
    // IsDefined() must be checked before Type(): in yaml-cpp a lookup of a
    // missing key on a const node yields an invalid node, and Type() throws on
    // such a node instead of letting us report the configuration error.
    if (!data_list.IsDefined() || data_list.Type() != YAML::NodeType::Map) {
        VLOG(0) << "miss data_list config in dataset, or type error please check";
        return -1;
    }
    for (const auto& data_config : data_list) {
        const std::string name = data_config.first.as<std::string>();
        auto data_ptr = std::make_shared<DatasetContainer>();
        if (data_ptr->initialize(data_config.second, context) != 0) {
            VLOG(0) << "dataset initialize failed, name:" << name;
            return -1;
        }
        _data_containers[name] = data_ptr;
    }
    return 0;
}

X
xiexionghang 已提交
25 26 27 28 29 30
// Forwards pre_detect_data(epoch_id) to every registered data container.
//
// Defined without `inline`: this is an out-of-class definition in a .cc file,
// so an inline definition would not be guaranteed to produce a symbol that
// other translation units can link against.
void Dataset::pre_detect_data(uint64_t epoch_id) {
    for (auto& entry : _data_containers) {
        entry.second->pre_detect_data(epoch_id);
    }
}
X
xiexionghang 已提交
31 32 33 34 35 36
inline void Dataset::pre_detect_data(
    const std::string& data_name, uint64_t epoch_id) {
    _data_containers[data_name]->pre_detect_data(epoch_id);
    return;
}

X
xiexionghang 已提交
37 38 39 40 41 42 43 44 45
// Returns the combined status of all containers for `epoch_id`: the status
// with the smallest underlying integer value among them. The scan starts from
// DatasetStatus::Ready, so that is also the result when no container is
// registered.
//
// Defined without `inline` so the symbol is available to other translation
// units.
DatasetStatus Dataset::epoch_data_status(uint64_t epoch_id) {
    int lowest = static_cast<int>(DatasetStatus::Ready);
    for (auto& entry : _data_containers) {
        const int current =
            static_cast<int>(entry.second->epoch_data_status(epoch_id));
        if (current < lowest) {
            lowest = current;
        }
    }
    return static_cast<DatasetStatus>(lowest);
}

X
xiexionghang 已提交
46 47 48 49 50 51 52 53 54 55
// Returns the status reported for `epoch_id` by the container registered as
// `data_name`.
//
// The lookup uses at() instead of operator[]: an unknown `data_name` now
// raises std::out_of_range rather than inserting an empty entry and
// dereferencing a null pointer.
//
// Defined without `inline` so the symbol is available to other translation
// units.
DatasetStatus Dataset::epoch_data_status(
    const std::string& data_name, uint64_t epoch_id) {
    return _data_containers.at(data_name)->epoch_data_status(epoch_id);
}

// Returns whatever fetch(epoch_id) yields on the container registered as
// `data_name`.
//
// The lookup uses at() instead of operator[]: an unknown `data_name` now
// raises std::out_of_range rather than inserting an empty entry and
// dereferencing a null pointer.
//
// Defined without `inline` so the symbol is available to other translation
// units.
::paddle::framework::Channel<DataItem> Dataset::fetch_data(
    const std::string& data_name, uint64_t epoch_id) {
    return _data_containers.at(data_name)->fetch(epoch_id);
}

X
xiexionghang 已提交
56
// Returns the parser exposed by the container registered as `data_name`, or
// nullptr (after logging) when no such container exists.
//
// The lookup uses find() rather than operator[] so that a bad name neither
// inserts an empty entry into `_data_containers` nor dereferences a null
// pointer.
//
// Defined without `inline` so the symbol is available to other translation
// units.
const DataParser* Dataset::data_parser(const std::string& data_name) {
    auto it = _data_containers.find(data_name);
    if (it == _data_containers.end()) {
        VLOG(0) << "dataset not found, name:" << data_name;
        return nullptr;
    }
    return it->second->data_parser();
}
     

} // namespace feed
} // namespace custom_trainer
} // namespace paddle