data_reader.cc 7.2 KB
Newer Older
R
rensilin 已提交
1 2 3 4 5
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"

#include <cstdio>

#include <glog/logging.h>
R
rensilin 已提交
6
#include <omp.h>
R
rensilin 已提交
7

R
rensilin 已提交
8
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
R
rensilin 已提交
9 10 11 12 13

namespace paddle {
namespace custom_trainer {
namespace feed {

R
rensilin 已提交
14
class LineDataParser : public DataParser {
R
rensilin 已提交
15 16 17 18 19 20 21 22 23 24 25
public:
    LineDataParser() {}

    virtual ~LineDataParser() {}

    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }

    virtual int parse(const char* str, size_t len, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
26
        while (pos < len && str[pos] != ' ') {
R
rensilin 已提交
27 28
            ++pos;
        }
R
rensilin 已提交
29
        if (pos >= len) {
R
rensilin 已提交
30
            VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
R
rensilin 已提交
31 32
            return -1;
        }
R
rensilin 已提交
33
        VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
R
rensilin 已提交
34 35 36 37 38 39 40
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1, len - pos - 1);
        return 0;
    }

    virtual int parse(const char* str, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
41
        while (str[pos] != '\0' && str[pos] != ' ') {
R
rensilin 已提交
42 43
            ++pos;
        }
R
rensilin 已提交
44
        if (str[pos] == '\0') {
R
rensilin 已提交
45
            VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos;
R
rensilin 已提交
46 47
            return -1;
        }
R
rensilin 已提交
48
        VLOG(5) << "getline: " << str << " , pos: " << pos;
R
rensilin 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1);
        return 0;
    }

    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};
REGISTER_CLASS(DataParser, LineDataParser);

// Creates the parser named by config["parser"]["class"], initializes it with
// the config["parser"] section, and stores config["pipeline_cmd"].
// Returns 0 on success, -1 if the parser cannot be created or initialized.
// NOTE: `as<std::string>()` is called without a fallback, so missing keys are
// reported by yaml-cpp (an exception) rather than by returning -1.
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    const std::string parser_class = config["parser"]["class"].as<std::string>();
    _parser.reset(CREATE_CLASS(DataParser, parser_class));
    if (_parser == nullptr) {
        VLOG(2) << "fail to get parser: " << parser_class;
        return -1;
    }
    if (_parser->initialize(config["parser"], context) != 0) {
        VLOG(2) << "fail to initialize parser: " << parser_class;
        return -1;
    }
    _pipeline_cmd = config["pipeline_cmd"].as<std::string>();
    return 0;
}

class LineDataReader : public DataReader {
public:
    LineDataReader() {}
    virtual ~LineDataReader() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        if (DataReader::initialize(config, context) != 0) {
            return -1;
        }
        _done_file_name = config["done_file"].as<std::string>();
R
rensilin 已提交
83
        _filename_prefix = config["filename_prefix"].as<std::string>("");
R
rensilin 已提交
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100

        if (config["file_system"] && config["file_system"]["class"]) {
            _file_system.reset(
                    CREATE_CLASS(FileSystem, config["file_system"]["class"].as<std::string>()));
            if (_file_system == nullptr ||
                _file_system->initialize(config["file_system"], context) != 0) {
                VLOG(2) << "fail to create class: "
                        << config["file_system"]["class"].as<std::string>();
                return -1;
            }
        } else {
            _file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
            if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
                VLOG(2) << "fail to init file system";
                return -1;
            }
        }
R
rensilin 已提交
101 102 103 104 105
        return 0;
    }

    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) {
R
rensilin 已提交
106 107
        auto done_file_path = _file_system->path_join(data_dir, _done_file_name);
        if (_file_system->exists(done_file_path)) {
R
rensilin 已提交
108 109 110 111 112
            return true;
        }
        return false;
    }

R
rensilin 已提交
113 114
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
        std::vector<std::string> data_files;
R
rensilin 已提交
115 116
        for (auto& filepath : _file_system->list(data_dir)) {
            auto filename = _file_system->path_split(filepath).second;
R
rensilin 已提交
117 118
            if (filename != _done_file_name &&
                string::begin_with(filename, _filename_prefix)) {
R
rensilin 已提交
119 120 121 122 123 124
                data_files.push_back(std::move(filepath));
            }
        }
        return data_files;
    }

R
rensilin 已提交
125
    //读取数据样本流中
R
rensilin 已提交
126
    virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
R
rensilin 已提交
127 128 129 130 131 132 133 134
        auto deleter = [](framework::ChannelWriter<DataItem> *writer) {
            if (writer) {
                writer->Flush();
                VLOG(3) << "writer auto flush";
            }
            delete writer;
        };
        std::unique_ptr<framework::ChannelWriter<DataItem>, decltype(deleter)> writer(new framework::ChannelWriter<DataItem>(data_channel.get()), deleter);
R
rensilin 已提交
135
        DataItem data_item;
R
rensilin 已提交
136 137 138 139 140 141 142 143

        auto file_list = data_file_list(data_dir);
        int file_list_size = file_list.size();

        VLOG(5) << "omg max_threads: " << omp_get_max_threads();
        #pragma omp parallel for
        for (int i = 0; i < file_list_size; ++i) {
            VLOG(5) << "omg num_threads: " << omp_get_num_threads() << ", start read: " << i << std::endl;
R
rensilin 已提交
144
        }
R
rensilin 已提交
145 146 147
        for (int i = 0; i < file_list_size; ++i) {
            //VLOG(5) << "omg num_threads: " << omp_get_num_threads() << ", start read: " << i;
            const auto& filepath = file_list[i];
R
rensilin 已提交
148 149 150 151 152 153
            {
                std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
                if (fin == nullptr) {
                    VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
                    return -1;
                }
R
rensilin 已提交
154 155 156 157 158 159 160 161
                char *buffer = nullptr;
                size_t buffer_size = 0;
                ssize_t line_len = 0;
                while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
                    if (line_len > 0 && buffer[line_len - 1] == '\n') {
                        buffer[--line_len] = '\0';
                    }
                    if (line_len <= 0) {
R
rensilin 已提交
162 163
                        continue;
                    }
R
rensilin 已提交
164 165
                    if (_parser->parse(buffer, line_len, data_item) == 0) {
                        (*writer) << std::move(data_item);
R
rensilin 已提交
166
                    }
R
rensilin 已提交
167 168 169 170 171
                }
                if (buffer != nullptr) {
                    free(buffer);
                    buffer = nullptr;
                    buffer_size = 0;
R
rensilin 已提交
172 173 174
                }
                if (ferror(fin.get()) != 0) {
                    VLOG(2) << "fail to read file: " << filepath;
R
rensilin 已提交
175 176 177
                    return -1;
                }
            }
R
fix  
rensilin 已提交
178
            if (_file_system->err_no() != 0) {
R
rensilin 已提交
179
                _file_system->reset_err_no();
R
rensilin 已提交
180 181 182
                return -1;
            }
        }
R
rensilin 已提交
183 184
        writer->Flush();
        if (!(*writer)) {
R
rensilin 已提交
185 186 187
            VLOG(2) << "fail when write to channel";
            return -1;
        }
R
rensilin 已提交
188
        data_channel->Close();
R
rensilin 已提交
189 190 191 192 193 194
        return 0;
    }

    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
195

R
rensilin 已提交
196
private:
R
rensilin 已提交
197
    std::string _done_file_name;  // without data_dirq
R
rensilin 已提交
198
    std::string _filename_prefix;
R
rensilin 已提交
199
    std::unique_ptr<FileSystem> _file_system;
R
rensilin 已提交
200 201 202
};
REGISTER_CLASS(DataReader, LineDataReader);

R
rensilin 已提交
203 204 205
}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle