data_reader.cc 8.1 KB
Newer Older
R
rensilin 已提交
1 2 3
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"

#include <cstdio>
R
openmp  
rensilin 已提交
4
#include <atomic>
R
rensilin 已提交
5 6

#include <glog/logging.h>
R
rensilin 已提交
7
#include <omp.h>
R
rensilin 已提交
8

R
rensilin 已提交
9
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
R
rensilin 已提交
10 11 12 13 14

namespace paddle {
namespace custom_trainer {
namespace feed {

R
rensilin 已提交
15
class LineDataParser : public DataParser {
R
rensilin 已提交
16 17 18 19 20 21 22 23 24 25 26
public:
    LineDataParser() {}

    virtual ~LineDataParser() {}

    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }

    virtual int parse(const char* str, size_t len, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
27
        while (pos < len && str[pos] != ' ') {
R
rensilin 已提交
28 29
            ++pos;
        }
R
rensilin 已提交
30
        if (pos >= len) {
R
rensilin 已提交
31
            VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
R
rensilin 已提交
32 33
            return -1;
        }
R
rensilin 已提交
34
        VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
R
rensilin 已提交
35 36 37 38 39 40 41
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1, len - pos - 1);
        return 0;
    }

    virtual int parse(const char* str, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
42
        while (str[pos] != '\0' && str[pos] != ' ') {
R
rensilin 已提交
43 44
            ++pos;
        }
R
rensilin 已提交
45
        if (str[pos] == '\0') {
R
rensilin 已提交
46
            VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos;
R
rensilin 已提交
47 48
            return -1;
        }
R
rensilin 已提交
49
        VLOG(5) << "getline: " << str << " , pos: " << pos;
R
rensilin 已提交
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1);
        return 0;
    }

    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};
REGISTER_CLASS(DataParser, LineDataParser);

// Creates the parser named by config["parser"]["class"], initializes it with
// config["parser"], and stores config["pipeline_cmd"] in _pipeline_cmd.
// Returns 0 on success, -1 when the parser cannot be created or its
// initialize() reports a non-zero status.
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    // Look the class name up once; it is needed for creation and for both logs.
    const std::string parser_class = config["parser"]["class"].as<std::string>();
    _parser.reset(CREATE_CLASS(DataParser, parser_class));
    if (_parser == nullptr) {
        VLOG(2) << "fail to get parser: " << parser_class;
        return -1;
    }
    if (_parser->initialize(config["parser"], context) != 0) {
        VLOG(2) << "fail to initialize parser: " << parser_class;
        return -1;
    }
    _pipeline_cmd = config["pipeline_cmd"].as<std::string>();
    return 0;
}

class LineDataReader : public DataReader {
public:
    LineDataReader() {}
    virtual ~LineDataReader() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        if (DataReader::initialize(config, context) != 0) {
            return -1;
        }
        _done_file_name = config["done_file"].as<std::string>();
R
rensilin 已提交
84
        _filename_prefix = config["filename_prefix"].as<std::string>("");
R
rensilin 已提交
85 86 87 88 89 90 91 92 93 94

        if (config["file_system"] && config["file_system"]["class"]) {
            _file_system.reset(
                    CREATE_CLASS(FileSystem, config["file_system"]["class"].as<std::string>()));
            if (_file_system == nullptr ||
                _file_system->initialize(config["file_system"], context) != 0) {
                VLOG(2) << "fail to create class: "
                        << config["file_system"]["class"].as<std::string>();
                return -1;
            }
X
xiexionghang 已提交
95 96
        } else if (context->file_system != nullptr) { 
            _file_system = context->file_system;
R
rensilin 已提交
97 98 99 100 101 102 103
        } else {
            _file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
            if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
                VLOG(2) << "fail to init file system";
                return -1;
            }
        }
R
rensilin 已提交
104 105 106 107 108
        return 0;
    }

    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) {
R
rensilin 已提交
109 110
        auto done_file_path = _file_system->path_join(data_dir, _done_file_name);
        if (_file_system->exists(done_file_path)) {
R
rensilin 已提交
111 112 113 114 115
            return true;
        }
        return false;
    }

R
rensilin 已提交
116 117
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
        std::vector<std::string> data_files;
R
rensilin 已提交
118 119
        for (auto& filepath : _file_system->list(data_dir)) {
            auto filename = _file_system->path_split(filepath).second;
R
rensilin 已提交
120 121
            if (filename != _done_file_name &&
                string::begin_with(filename, _filename_prefix)) {
R
rensilin 已提交
122 123 124 125 126 127
                data_files.push_back(std::move(filepath));
            }
        }
        return data_files;
    }

R
rensilin 已提交
128
    //读取数据样本流中
R
rensilin 已提交
129
    virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
X
xiexionghang 已提交
130 131 132 133
        auto file_list = data_file_list(data_dir);
        return read_all(file_list, data_channel);
    }
    virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) {
R
fs_bug  
rensilin 已提交
134
        const int file_list_size = file_list.size();
R
openmp  
rensilin 已提交
135
        std::atomic<bool> is_failed(false);
R
rensilin 已提交
136

R
fs_bug  
rensilin 已提交
137 138 139 140 141 142 143
        const int max_threads = omp_get_max_threads();
        std::vector<framework::ChannelWriter<DataItem>> writers; // writer is not thread safe
        writers.reserve(max_threads);
        for (int i = 0; i < max_threads; ++i) {
            writers.emplace_back(data_channel.get());
        }
        VLOG(5) << "file_list: " << string::join_strings(file_list, ' ');
R
rensilin 已提交
144 145
        #pragma omp parallel for
        for (int i = 0; i < file_list_size; ++i) {
R
fs_bug  
rensilin 已提交
146 147 148 149 150 151 152 153
            if (is_failed) {
                continue;
            }
            const int thread_num = omp_get_thread_num();
            framework::ChannelWriter<DataItem> *writer = nullptr;
            if (thread_num < max_threads) {
                writer = &writers[thread_num];
            }
R
rensilin 已提交
154
            const auto& filepath = file_list[i];
R
fs_bug  
rensilin 已提交
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
            std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
            if (fin == nullptr) {
                VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
                is_failed = true;
                continue;
            }
            char *buffer = nullptr;
            size_t buffer_size = 0;
            ssize_t line_len = 0;
            while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
                // 去掉行位回车
                if (line_len > 0 && buffer[line_len - 1] == '\n') {
                    buffer[--line_len] = '\0';
                }
                // 忽略空行
                if (line_len <= 0) {
R
openmp  
rensilin 已提交
171
                    continue;
R
rensilin 已提交
172
                }
R
fix  
rensilin 已提交
173
                DataItem data_item;
R
fs_bug  
rensilin 已提交
174 175 176 177 178 179 180
                if (_parser->parse(buffer, line_len, data_item) == 0) {
                    VLOG(5) << "parse data: " << data_item.id << " " << data_item.data << ", filename: " << filepath << ", thread_num: " << thread_num << ", max_threads: " << max_threads;
                    if (writer == nullptr) {
                        if (!data_channel->Put(std::move(data_item))) {
                            VLOG(2) << "fail to put data, thread_num: " << thread_num;
                        }
                    } else {
R
rensilin 已提交
181
                        (*writer) << std::move(data_item);
R
rensilin 已提交
182
                    }
R
rensilin 已提交
183
                }
R
fs_bug  
rensilin 已提交
184 185 186 187 188 189 190 191 192 193
            }
            if (buffer != nullptr) {
                free(buffer);
                buffer = nullptr;
                buffer_size = 0;
            }
            if (ferror(fin.get()) != 0) {
                VLOG(2) << "fail to read file: " << filepath;
                is_failed = true;
                continue;
R
rensilin 已提交
194
            }
R
fix  
rensilin 已提交
195
            if (_file_system->err_no() != 0) {
R
rensilin 已提交
196
                _file_system->reset_err_no();
R
openmp  
rensilin 已提交
197 198
                is_failed = true;
                continue;
R
rensilin 已提交
199 200
            }
        }
R
fs_bug  
rensilin 已提交
201 202 203 204 205 206 207 208
        // omp end

        for (int i = 0; i < max_threads; ++i) {
            writers[i].Flush();
            if (!writers[i]) {
                VLOG(2) << "writer " << i << " is failed";
                is_failed = true;
            }
R
rensilin 已提交
209
        }
R
rensilin 已提交
210
        data_channel->Close();
R
openmp  
rensilin 已提交
211
        return is_failed ? -1 : 0;
R
rensilin 已提交
212 213 214 215 216
    }

    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
217

R
rensilin 已提交
218
private:
R
openmp  
rensilin 已提交
219
    std::string _done_file_name;  // without data_dir
R
rensilin 已提交
220
    std::string _filename_prefix;
X
xiexionghang 已提交
221
    std::shared_ptr<FileSystem> _file_system;
R
rensilin 已提交
222 223 224
};
REGISTER_CLASS(DataReader, LineDataReader);

R
rensilin 已提交
225 226 227
}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle