data_reader.cc 8.1 KB
Newer Older
R
rensilin 已提交
1 2 3
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"

#include <cstdio>
R
openmp  
rensilin 已提交
4
#include <atomic>
R
rensilin 已提交
5 6

#include <glog/logging.h>
R
rensilin 已提交
7
#include <omp.h>
R
rensilin 已提交
8

R
rensilin 已提交
9
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
R
rensilin 已提交
10 11 12 13 14

namespace paddle {
namespace custom_trainer {
namespace feed {

R
rensilin 已提交
15
class LineDataParser : public DataParser {
R
rensilin 已提交
16 17 18 19 20 21 22 23 24 25 26
public:
    LineDataParser() {}

    virtual ~LineDataParser() {}

    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }

    virtual int parse(const char* str, size_t len, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
27
        while (pos < len && str[pos] != ' ') {
R
rensilin 已提交
28 29
            ++pos;
        }
R
rensilin 已提交
30
        if (pos >= len) {
R
rensilin 已提交
31
            VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
R
rensilin 已提交
32 33
            return -1;
        }
R
rensilin 已提交
34
        VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
R
rensilin 已提交
35 36 37 38 39 40 41
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1, len - pos - 1);
        return 0;
    }

    virtual int parse(const char* str, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
42
        while (str[pos] != '\0' && str[pos] != ' ') {
R
rensilin 已提交
43 44
            ++pos;
        }
R
rensilin 已提交
45
        if (str[pos] == '\0') {
R
rensilin 已提交
46
            VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos;
R
rensilin 已提交
47 48
            return -1;
        }
R
rensilin 已提交
49
        VLOG(5) << "getline: " << str << " , pos: " << pos;
R
rensilin 已提交
50 51 52 53 54 55 56 57 58
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1);
        return 0;
    }

    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};
X
xiexionghang 已提交
59
REGIST_CLASS(DataParser, LineDataParser);
R
rensilin 已提交
60 61

// Creates the parser named by config["parser"]["class"], initializes it with
// config["parser"], and stores config["pipeline_cmd"] in _pipeline_cmd.
// Returns 0 on success, -1 when the parser cannot be created or initialized.
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    // Read the class name once; it is needed for creation and for both error logs.
    const std::string parser_class = config["parser"]["class"].as<std::string>();
    _parser.reset(CREATE_INSTANCE(DataParser, parser_class));
    if (_parser == nullptr) {
        VLOG(2) << "fail to get parser: " << parser_class;
        return -1;
    }
    if (_parser->initialize(config["parser"], context) != 0) {
        VLOG(2) << "fail to initialize parser: " << parser_class;
        return -1;
    }
    _pipeline_cmd = config["pipeline_cmd"].as<std::string>();
    return 0;
}

class LineDataReader : public DataReader {
public:
    LineDataReader() {}
    virtual ~LineDataReader() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        if (DataReader::initialize(config, context) != 0) {
            return -1;
        }
        _done_file_name = config["done_file"].as<std::string>();
R
rensilin 已提交
84
        _filename_prefix = config["filename_prefix"].as<std::string>("");
R
rensilin 已提交
85 86 87

        if (config["file_system"] && config["file_system"]["class"]) {
            _file_system.reset(
X
xiexionghang 已提交
88
                    CREATE_INSTANCE(FileSystem, config["file_system"]["class"].as<std::string>()));
R
rensilin 已提交
89 90 91 92 93 94
            if (_file_system == nullptr ||
                _file_system->initialize(config["file_system"], context) != 0) {
                VLOG(2) << "fail to create class: "
                        << config["file_system"]["class"].as<std::string>();
                return -1;
            }
X
xiexionghang 已提交
95 96
        } else if (context->file_system != nullptr) { 
            _file_system = context->file_system;
R
rensilin 已提交
97
        } else {
X
xiexionghang 已提交
98
            _file_system.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
R
rensilin 已提交
99 100 101 102 103
            if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
                VLOG(2) << "fail to init file system";
                return -1;
            }
        }
R
rensilin 已提交
104 105 106 107 108
        return 0;
    }

    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) {
R
rensilin 已提交
109 110
        auto done_file_path = _file_system->path_join(data_dir, _done_file_name);
        if (_file_system->exists(done_file_path)) {
R
rensilin 已提交
111 112 113 114 115
            return true;
        }
        return false;
    }

R
rensilin 已提交
116 117
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
        std::vector<std::string> data_files;
R
rensilin 已提交
118 119
        for (auto& filepath : _file_system->list(data_dir)) {
            auto filename = _file_system->path_split(filepath).second;
R
rensilin 已提交
120 121
            if (filename != _done_file_name &&
                string::begin_with(filename, _filename_prefix)) {
R
rensilin 已提交
122 123 124 125 126 127
                data_files.push_back(std::move(filepath));
            }
        }
        return data_files;
    }

R
rensilin 已提交
128
    //读取数据样本流中
R
rensilin 已提交
129
    virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
X
xiexionghang 已提交
130 131 132 133
        auto file_list = data_file_list(data_dir);
        return read_all(file_list, data_channel);
    }
    virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) {
R
fs_bug  
rensilin 已提交
134
        const int file_list_size = file_list.size();
R
openmp  
rensilin 已提交
135
        std::atomic<bool> is_failed(false);
R
rensilin 已提交
136

R
fs_bug  
rensilin 已提交
137 138 139 140 141 142 143
        const int max_threads = omp_get_max_threads();
        std::vector<framework::ChannelWriter<DataItem>> writers; // writer is not thread safe
        writers.reserve(max_threads);
        for (int i = 0; i < max_threads; ++i) {
            writers.emplace_back(data_channel.get());
        }
        VLOG(5) << "file_list: " << string::join_strings(file_list, ' ');
R
rensilin 已提交
144 145
        #pragma omp parallel for
        for (int i = 0; i < file_list_size; ++i) {
R
fs_bug  
rensilin 已提交
146 147 148 149 150 151 152 153
            if (is_failed) {
                continue;
            }
            const int thread_num = omp_get_thread_num();
            framework::ChannelWriter<DataItem> *writer = nullptr;
            if (thread_num < max_threads) {
                writer = &writers[thread_num];
            }
R
rensilin 已提交
154
            const auto& filepath = file_list[i];
R
fs_bug  
rensilin 已提交
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
            std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
            if (fin == nullptr) {
                VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
                is_failed = true;
                continue;
            }
            char *buffer = nullptr;
            size_t buffer_size = 0;
            ssize_t line_len = 0;
            while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
                // 去掉行位回车
                if (line_len > 0 && buffer[line_len - 1] == '\n') {
                    buffer[--line_len] = '\0';
                }
                // 忽略空行
                if (line_len <= 0) {
R
openmp  
rensilin 已提交
171
                    continue;
R
rensilin 已提交
172
                }
R
fix  
rensilin 已提交
173
                DataItem data_item;
R
fs_bug  
rensilin 已提交
174 175 176 177 178 179 180
                if (_parser->parse(buffer, line_len, data_item) == 0) {
                    VLOG(5) << "parse data: " << data_item.id << " " << data_item.data << ", filename: " << filepath << ", thread_num: " << thread_num << ", max_threads: " << max_threads;
                    if (writer == nullptr) {
                        if (!data_channel->Put(std::move(data_item))) {
                            VLOG(2) << "fail to put data, thread_num: " << thread_num;
                        }
                    } else {
R
rensilin 已提交
181
                        (*writer) << std::move(data_item);
R
rensilin 已提交
182
                    }
R
rensilin 已提交
183
                }
R
fs_bug  
rensilin 已提交
184 185 186 187 188 189 190 191 192 193
            }
            if (buffer != nullptr) {
                free(buffer);
                buffer = nullptr;
                buffer_size = 0;
            }
            if (ferror(fin.get()) != 0) {
                VLOG(2) << "fail to read file: " << filepath;
                is_failed = true;
                continue;
R
rensilin 已提交
194
            }
R
fix  
rensilin 已提交
195
            if (_file_system->err_no() != 0) {
R
rensilin 已提交
196
                _file_system->reset_err_no();
R
openmp  
rensilin 已提交
197 198
                is_failed = true;
                continue;
R
rensilin 已提交
199 200
            }
        }
R
fs_bug  
rensilin 已提交
201 202 203 204 205 206 207 208
        // omp end

        for (int i = 0; i < max_threads; ++i) {
            writers[i].Flush();
            if (!writers[i]) {
                VLOG(2) << "writer " << i << " is failed";
                is_failed = true;
            }
R
rensilin 已提交
209
        }
R
rensilin 已提交
210
        data_channel->Close();
R
openmp  
rensilin 已提交
211
        return is_failed ? -1 : 0;
R
rensilin 已提交
212 213 214 215 216
    }

    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
217

R
rensilin 已提交
218
private:
R
openmp  
rensilin 已提交
219
    std::string _done_file_name;  // without data_dir
R
rensilin 已提交
220
    std::string _filename_prefix;
X
xiexionghang 已提交
221
    std::shared_ptr<FileSystem> _file_system;
R
rensilin 已提交
222
};
X
xiexionghang 已提交
223
REGIST_CLASS(DataReader, LineDataReader);
R
rensilin 已提交
224

R
rensilin 已提交
225 226 227
}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle