// dataset_container.cc (recovered from a blame view; lines 1-9 committed by xiexionghang)
/* DatasetContainer
 * 保存一个数据源的样本,并驱动样本的异步加载
 * Holds the samples of one data source and drives their asynchronous loading.
 */
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/string/string_helper.h"
// (blame: lines 10-11 committed by xiexionghang)
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
// (blame: line 12 committed by rensilin)
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
// (blame: line 13 committed by xiexionghang)
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"
// (blame: line 14 committed by xiexionghang)
#include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h"
// (blame: lines 15-18 committed by xiexionghang)

namespace paddle {
namespace custom_trainer {
namespace feed {
// (blame: lines 19-26 committed by xiexionghang)

/**
 * @brief Configure the container from its YAML section.
 *
 * Reads the prefetch depth, data root paths, split interval and path format,
 * then instantiates the shuffler and the data reader named in the config.
 *
 * @param config  dataset section of the trainer config.
 * @param context trainer context. Only a raw pointer is kept
 *                (NOTE(review): the caller must keep the context alive for
 *                the lifetime of this container).
 * @return 0 on success; -1 if `prefetch_num` or `data_spit_interval` is not
 *         positive, or if the shuffler / data reader could not be created;
 *         otherwise the status returned by the data reader's initialize().
 */
int DatasetContainer::initialize(
        const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    _dataset_config = config;
    _trainer_context = context.get();

    // Number of epochs prefetched ahead. It is also the size of the
    // ring buffer that dataset() indexes with `% _prefetch_num`, so it must
    // be positive.
    _prefetch_num = config["prefetch_num"].as<int>();
    if (_prefetch_num <= 0) {
        return -1;
    }
    _dataset_list.resize(_prefetch_num);
    for (auto& dataset_info : _dataset_list) {
        dataset_info.reset(new DatasetInfo);
    }

    _data_root_paths = config["root_path"].as<std::vector<std::string>>();
    // NOTE: the config key really is spelled "data_spit_interval". It is kept
    // as-is for compatibility with existing configs.
    _data_split_interval = config["data_spit_interval"].as<int>();
    // Used as a divisor when rounding epoch timestamps to data directories.
    if (_data_split_interval <= 0) {
        return -1;
    }
    _data_path_formater = config["data_path_formater"].as<std::string>();

    std::string shuffler = config["shuffler"]["name"].as<std::string>();
    _shuffler.reset(CREATE_INSTANCE(Shuffler, shuffler));
    // Guard against a failed creation (presumably nullptr for an unregistered
    // class name) instead of dereferencing it below.
    if (_shuffler == nullptr) {
        return -1;
    }
    // NOTE(review): the shuffler's initialize() status is not checked here.
    _shuffler->initialize(config, context);

    std::string data_reader_class = config["data_reader"].as<std::string>();
    DataReader* data_reader = CREATE_INSTANCE(DataReader, data_reader_class);
    if (data_reader == nullptr) {
        return -1;
    }
    _data_reader.reset(data_reader);
    return _data_reader->initialize(config, context);
}

/**
 * @brief Ring-buffer slot used for the data of `timestamp`.
 *
 * The timestamp is turned into an epoch index (divided by the epoch interval)
 * and wrapped modulo _prefetch_num. Timestamps that are _prefetch_num epochs
 * apart therefore share a slot. The returned slot may still hold another
 * epoch's data; its `timestamp` field tells which (see data_status()).
 */
std::shared_ptr<DatasetInfo> DatasetContainer::dataset(uint64_t timestamp) {
    const auto interval = _trainer_context->epoch_accessor->epoch_time_interval();
    const auto slot_index = (timestamp / interval) % _prefetch_num;
    return _dataset_list[slot_index];
}
// (blame: lines 48-62 committed by xiexionghang)
/**
 * @brief List the data directories that feed one epoch.
 *
 * The epoch's start timestamp is rounded up to the next multiple of
 * _data_split_interval. From there, data_num_for_train() consecutive split
 * directories are formatted with _data_path_formater and joined onto every
 * configured root path.
 *
 * @param epoch_id epoch to resolve.
 * @return directory paths, grouped by root path in configuration order.
 */
std::vector<std::string> DatasetContainer::epoch_data_path(uint64_t epoch_id) {
    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
    size_t data_num = data_num_for_train(timestamp, epoch_accessor->epoch_time_interval(), _data_split_interval);
    // Round up to the first split boundary at or after the epoch start.
    uint64_t data_timestamp = timestamp % _data_split_interval == 0
        ? timestamp
        : (timestamp / _data_split_interval + 1) * _data_split_interval;

    std::vector<std::string> results;
    results.reserve(_data_root_paths.size() * data_num);
    // size_t indices avoid the signed/unsigned comparison against data_num.
    for (const auto& root_path : _data_root_paths) {
        for (size_t j = 0; j < data_num; ++j) {
            std::string path_suffix = format_timestamp(data_timestamp + j * _data_split_interval, _data_path_formater);
            results.emplace_back(_trainer_context->file_system->path_join(root_path, path_suffix));
        }
    }
    return results;
}
// (blame: lines 63-71 committed by xiexionghang)

void DatasetContainer::pre_detect_data(uint64_t epoch_id) {
    int status = 0;
    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
    if (timestamp % epoch_accessor->epoch_time_interval() != 0) {
        LOG(FATAL) << "timestamp:" << timestamp << " don't match interval:" << epoch_accessor->epoch_time_interval();
        return;
    }
X
xiexionghang 已提交
72 73 74 75
    if (_downloader_thread == nullptr) {
        _downloader_thread.reset(new std::thread([this, timestamp](){
            async_download_data(timestamp);
        }));
X
xiexionghang 已提交
76
    }
X
xiexionghang 已提交
77
    for (int detect_idx = 0 ; detect_idx < _prefetch_num; ++detect_idx, ++epoch_id) {
X
xiexionghang 已提交
78 79 80 81 82 83 84 85 86
        if (DatasetStatus::Empty != data_status(timestamp)) {
            continue;
        }
        size_t data_num = data_num_for_train(timestamp, epoch_accessor->epoch_time_interval(), _data_split_interval);
        uint64_t data_timestamp = timestamp % _data_split_interval == 0 ? timestamp : (timestamp / _data_split_interval + 1) * _data_split_interval;
        std::vector<std::string> data_path_list;
        for (int i = 0; i < _data_root_paths.size() && status == 0; ++i) {
            for (int j = 0; j < data_num && status == 0; ++j) {
                std::string path_suffix = format_timestamp(data_timestamp + j * _data_split_interval, _data_path_formater);
R
rensilin 已提交
87
                std::string data_dir = _trainer_context->file_system->path_join(_data_root_paths[i], path_suffix);
X
xiexionghang 已提交
88 89 90 91 92 93 94 95
                status = read_data_list(data_dir, data_path_list);
            }
        }
        if (status == 0) {
            auto dataset_info = dataset(timestamp);
            dataset_info->timestamp = timestamp;
            dataset_info->file_path_list = std::move(data_path_list);
            dataset_info->status = DatasetStatus::Detected;
X
xiexionghang 已提交
96
            VLOG(2) << epoch_accessor->text(epoch_id) << ", data is detected";
X
xiexionghang 已提交
97 98
        }
        timestamp += epoch_accessor->epoch_time_interval();
X
xiexionghang 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
    }
    return;
}

int DatasetContainer::read_data_list(const std::string& data_dir, std::vector<std::string>& data_list) {
    auto* environment = _trainer_context->environment.get();
    
    // 检查数据Ready
    int data_status = -1;
    if (environment->is_master_node(EnvironmentRole::WORKER)) {
        if (_data_reader->is_data_ready(data_dir)) {
            data_status = 0;
        }
    }
    paddle::framework::BinaryArchive ar;
    ar << data_status; 
    environment->bcast(ar, 0, EnvironmentRole::WORKER);
    ar >> data_status;
    if (data_status != 0) {
        return -1;
    } 
    
    // 读取文件列表
    ar.Clear();
    std::vector<std::string> data_path_list;
    if (environment->is_master_node(EnvironmentRole::WORKER)) {
         data_path_list = _data_reader->data_file_list(data_dir);
        ar << data_path_list;
    }
    environment->bcast(ar, 0, EnvironmentRole::WORKER);
    ar >> data_path_list;
    auto worker_id = environment->rank_id(EnvironmentRole::WORKER);
    auto worker_num = environment->node_num(EnvironmentRole::WORKER); 
    for (int i = worker_id; i < data_path_list.size(); i+=worker_num) {
        data_list.push_back(data_path_list[i]);
    }
    environment->barrier(EnvironmentRole::WORKER);
    return 0;
}

/**
 * @brief Status of the data for `epoch_id`.
 *
 * Converts the epoch id to its start timestamp and delegates to
 * data_status().
 */
DatasetStatus DatasetContainer::epoch_data_status(uint64_t epoch_id) {
    const time_t epoch_start = _trainer_context->epoch_accessor->epoch_timestamp(epoch_id);
    return data_status(epoch_start);
}

/**
 * @brief Status of the data for `timestamp`.
 *
 * Looks up the ring-buffer slot for `timestamp`. If the slot currently holds
 * exactly this timestamp, its status is returned. Otherwise the slot belongs
 * to a different epoch and the result is DatasetStatus::Empty.
 */
DatasetStatus DatasetContainer::data_status(uint64_t timestamp) {
    const auto slot = dataset(timestamp);
    if (slot->timestamp == timestamp) {
        return slot->status;
    }
    return DatasetStatus::Empty;
}
// (blame: lines 152-153 committed by xiexionghang)
/**
 * @brief Channel with the downloaded and shuffled samples of `epoch_id`.
 *
 * @return the slot's data channel when the epoch's data is Ready; otherwise
 *         a default-constructed Channel.
 */
paddle::framework::Channel<DataItem> DatasetContainer::fetch(uint64_t epoch_id) {
    auto* accessor = _trainer_context->epoch_accessor.get();
    const time_t timestamp = accessor->epoch_timestamp(epoch_id);
    if (data_status(timestamp) == DatasetStatus::Ready) {
        return dataset(timestamp)->data_channel;
    }
    return paddle::framework::Channel<DataItem>();
}

// (blame: lines 164-169 committed by xiexionghang)
void DatasetContainer::async_download_data(uint64_t start_timestamp) {
    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
    if (start_timestamp % epoch_accessor->epoch_time_interval() != 0) {
        LOG(FATAL) << "timestamp:" << start_timestamp << " don't match interval:" << epoch_accessor->epoch_time_interval();
        return;
    }
X
xiexionghang 已提交
170
    while (!_stop_download) {
X
xiexionghang 已提交
171
        auto dataset_info = dataset(start_timestamp);
X
xiexionghang 已提交
172
        while (data_status(start_timestamp) == DatasetStatus::Empty) {
X
xiexionghang 已提交
173 174
            sleep(30);
        }
X
xiexionghang 已提交
175 176 177 178
        dataset_info->status = DatasetStatus::Downloding;

        VLOG(2) << "Start download data, data_timestap:" << start_timestamp
            << ", for epoch:" << epoch_accessor->text(start_timestamp);
X
xiexionghang 已提交
179 180 181 182
        const auto& file_list = dataset_info->file_path_list;
        dataset_info->data_channel->Clear();
        while (_data_reader->read_all(file_list, dataset_info->data_channel) != 0) {
            dataset_info->data_channel->Clear();
X
xiexionghang 已提交
183 184
            VLOG(0) << "Failed download data, data_timestap:" << start_timestamp
                << ", for epoch:" << epoch_accessor->text(start_timestamp) << ", Retry it";
X
xiexionghang 已提交
185 186
            sleep(30); 
        }
X
xiexionghang 已提交
187 188
        VLOG(2) << "End download data num:" << dataset_info->data_channel->Size()
            << ", data_timestap:" << start_timestamp
X
xiexionghang 已提交
189 190 191
            << ", for epoch:" << epoch_accessor->text(start_timestamp) << ", Start shuffle";
        _shuffler->shuffle(dataset_info->data_channel);
        VLOG(2) << "Shuffle done";
X
xiexionghang 已提交
192
        dataset_info->status = DatasetStatus::Ready;
X
xiexionghang 已提交
193
        start_timestamp += epoch_accessor->epoch_time_interval();
X
xiexionghang 已提交
194 195 196
    }
}

// (blame: lines 197-199 committed by xiexionghang)
} // namespace feed
} // namespace custom_trainer
} // namespace paddle