data_reader.cc 7.4 KB
Newer Older
R
rensilin 已提交
1 2 3
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"

#include <cstdio>
R
openmp  
rensilin 已提交
4
#include <atomic>
R
rensilin 已提交
5 6

#include <glog/logging.h>
R
rensilin 已提交
7
#include <omp.h>
R
rensilin 已提交
8

R
rensilin 已提交
9
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
R
rensilin 已提交
10 11 12 13 14

namespace paddle {
namespace custom_trainer {
namespace feed {

R
rensilin 已提交
15
class LineDataParser : public DataParser {
R
rensilin 已提交
16 17 18 19 20 21 22 23 24 25 26
public:
    LineDataParser() {}

    virtual ~LineDataParser() {}

    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }

    virtual int parse(const char* str, size_t len, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
27
        while (pos < len && str[pos] != ' ') {
R
rensilin 已提交
28 29
            ++pos;
        }
R
rensilin 已提交
30
        if (pos >= len) {
R
rensilin 已提交
31
            VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
R
rensilin 已提交
32 33
            return -1;
        }
R
rensilin 已提交
34
        VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
R
rensilin 已提交
35 36 37 38 39 40 41
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1, len - pos - 1);
        return 0;
    }

    virtual int parse(const char* str, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
42
        while (str[pos] != '\0' && str[pos] != ' ') {
R
rensilin 已提交
43 44
            ++pos;
        }
R
rensilin 已提交
45
        if (str[pos] == '\0') {
R
rensilin 已提交
46
            VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos;
R
rensilin 已提交
47 48
            return -1;
        }
R
rensilin 已提交
49
        VLOG(5) << "getline: " << str << " , pos: " << pos;
R
rensilin 已提交
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1);
        return 0;
    }

    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};
REGISTER_CLASS(DataParser, LineDataParser);

// Creates the parser named by config["parser"]["class"], initializes it with
// the config["parser"] node, and stores config["pipeline_cmd"].
// Returns 0 on success, -1 when the parser cannot be created or its
// initialize() returns non-zero.
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    const YAML::Node parser_config = config["parser"];
    const std::string parser_class = parser_config["class"].as<std::string>();

    _parser.reset(CREATE_CLASS(DataParser, parser_class));
    if (_parser == nullptr) {
        VLOG(2) << "fail to get parser: " << parser_class;
        return -1;
    }

    const int init_status = _parser->initialize(parser_config, context);
    if (init_status != 0) {
        VLOG(2) << "fail to initialize parser" << parser_class;
        return -1;
    }

    _pipeline_cmd = config["pipeline_cmd"].as<std::string>();
    return 0;
}

class LineDataReader : public DataReader {
public:
    LineDataReader() {}
    virtual ~LineDataReader() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        if (DataReader::initialize(config, context) != 0) {
            return -1;
        }
        _done_file_name = config["done_file"].as<std::string>();
R
rensilin 已提交
84
        _filename_prefix = config["filename_prefix"].as<std::string>("");
R
rensilin 已提交
85 86 87 88 89 90 91 92 93 94

        if (config["file_system"] && config["file_system"]["class"]) {
            _file_system.reset(
                    CREATE_CLASS(FileSystem, config["file_system"]["class"].as<std::string>()));
            if (_file_system == nullptr ||
                _file_system->initialize(config["file_system"], context) != 0) {
                VLOG(2) << "fail to create class: "
                        << config["file_system"]["class"].as<std::string>();
                return -1;
            }
X
xiexionghang 已提交
95 96
        } else if (context->file_system != nullptr) { 
            _file_system = context->file_system;
R
rensilin 已提交
97 98 99 100 101 102 103
        } else {
            _file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
            if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
                VLOG(2) << "fail to init file system";
                return -1;
            }
        }
R
rensilin 已提交
104 105 106 107 108
        return 0;
    }

    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) {
R
rensilin 已提交
109 110
        auto done_file_path = _file_system->path_join(data_dir, _done_file_name);
        if (_file_system->exists(done_file_path)) {
R
rensilin 已提交
111 112 113 114 115
            return true;
        }
        return false;
    }

R
rensilin 已提交
116 117
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
        std::vector<std::string> data_files;
R
rensilin 已提交
118 119
        for (auto& filepath : _file_system->list(data_dir)) {
            auto filename = _file_system->path_split(filepath).second;
R
rensilin 已提交
120 121
            if (filename != _done_file_name &&
                string::begin_with(filename, _filename_prefix)) {
R
rensilin 已提交
122 123 124 125 126 127
                data_files.push_back(std::move(filepath));
            }
        }
        return data_files;
    }

R
rensilin 已提交
128
    //读取数据样本流中
R
rensilin 已提交
129
    virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
X
xiexionghang 已提交
130 131 132 133
        auto file_list = data_file_list(data_dir);
        return read_all(file_list, data_channel);
    }
    virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) {
R
rensilin 已提交
134 135 136 137 138 139 140 141
        auto deleter = [](framework::ChannelWriter<DataItem> *writer) {
            if (writer) {
                writer->Flush();
                VLOG(3) << "writer auto flush";
            }
            delete writer;
        };
        std::unique_ptr<framework::ChannelWriter<DataItem>, decltype(deleter)> writer(new framework::ChannelWriter<DataItem>(data_channel.get()), deleter);
R
rensilin 已提交
142
        DataItem data_item;
R
rensilin 已提交
143 144

        int file_list_size = file_list.size();
R
openmp  
rensilin 已提交
145
        std::atomic<bool> is_failed(false);
R
rensilin 已提交
146 147 148 149

        #pragma omp parallel for
        for (int i = 0; i < file_list_size; ++i) {
            const auto& filepath = file_list[i];
R
openmp  
rensilin 已提交
150
            if (!is_failed) {
R
rensilin 已提交
151 152 153
                std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
                if (fin == nullptr) {
                    VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
R
openmp  
rensilin 已提交
154 155
                    is_failed = true;
                    continue;
R
rensilin 已提交
156
                }
R
rensilin 已提交
157 158 159 160 161 162 163 164
                char *buffer = nullptr;
                size_t buffer_size = 0;
                ssize_t line_len = 0;
                while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
                    if (line_len > 0 && buffer[line_len - 1] == '\n') {
                        buffer[--line_len] = '\0';
                    }
                    if (line_len <= 0) {
R
rensilin 已提交
165 166
                        continue;
                    }
R
rensilin 已提交
167 168
                    if (_parser->parse(buffer, line_len, data_item) == 0) {
                        (*writer) << std::move(data_item);
R
rensilin 已提交
169
                    }
R
rensilin 已提交
170 171 172 173 174
                }
                if (buffer != nullptr) {
                    free(buffer);
                    buffer = nullptr;
                    buffer_size = 0;
R
rensilin 已提交
175 176 177
                }
                if (ferror(fin.get()) != 0) {
                    VLOG(2) << "fail to read file: " << filepath;
R
openmp  
rensilin 已提交
178 179
                    is_failed = true;
                    continue;
R
rensilin 已提交
180 181
                }
            }
R
fix  
rensilin 已提交
182
            if (_file_system->err_no() != 0) {
R
rensilin 已提交
183
                _file_system->reset_err_no();
R
openmp  
rensilin 已提交
184 185
                is_failed = true;
                continue;
R
rensilin 已提交
186 187
            }
        }
R
rensilin 已提交
188 189
        writer->Flush();
        if (!(*writer)) {
R
rensilin 已提交
190
            VLOG(2) << "fail when write to channel";
R
openmp  
rensilin 已提交
191
            is_failed = true;
R
rensilin 已提交
192
        }
R
rensilin 已提交
193
        data_channel->Close();
R
openmp  
rensilin 已提交
194
        return is_failed ? -1 : 0;
R
rensilin 已提交
195 196 197 198 199
    }

    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
200

R
rensilin 已提交
201
private:
R
openmp  
rensilin 已提交
202
    std::string _done_file_name;  // without data_dir
R
rensilin 已提交
203
    std::string _filename_prefix;
X
xiexionghang 已提交
204
    std::shared_ptr<FileSystem> _file_system;
R
rensilin 已提交
205 206 207
};
REGISTER_CLASS(DataReader, LineDataReader);

R
rensilin 已提交
208 209 210
}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle