data_reader.cc 7.2 KB
Newer Older
R
rensilin 已提交
1 2 3
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"

#include <cstdio>
R
openmp  
rensilin 已提交
4
#include <atomic>
R
rensilin 已提交
5 6

#include <glog/logging.h>
R
rensilin 已提交
7
#include <omp.h>
R
rensilin 已提交
8

R
rensilin 已提交
9
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
R
rensilin 已提交
10 11 12 13 14

namespace paddle {
namespace custom_trainer {
namespace feed {

X
xiexionghang 已提交
15 16 17 18
int LineDataParser::parse(const char* str, size_t len, DataItem& data) const {
    size_t pos = 0;
    while (pos < len && str[pos] != ' ') {
        ++pos;
R
rensilin 已提交
19
    }
X
xiexionghang 已提交
20 21 22
    if (pos >= len) {
        VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
        return -1;
R
rensilin 已提交
23
    }
X
xiexionghang 已提交
24 25 26 27 28 29
    VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
    data.id.assign(str, pos);
    data.data.assign(str + pos + 1, len - pos - 1);
    return 0;
}
REGIST_CLASS(DataParser, LineDataParser);
R
rensilin 已提交
30 31

// Creates and initializes the parser named by config["parser"]["class"], then stores
// config["pipeline_cmd"] for later use when opening data files.
//
// @param config   reader configuration; reads "parser" (with its "class" entry) and
//                 "pipeline_cmd"
// @param context  trainer context, forwarded to the parser's initialize()
// @return 0 on success, -1 when the parser cannot be created or fails to initialize
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    // Look the class name up once; it is needed for creation and for both error logs.
    const std::string parser_class = config["parser"]["class"].as<std::string>();
    _parser.reset(CREATE_INSTANCE(DataParser, parser_class));
    if (_parser == nullptr) {
        VLOG(2) << "fail to get parser: " << parser_class;
        return -1;
    }
    if (_parser->initialize(config["parser"], context) != 0) {
        VLOG(2) << "fail to initialize parser: " << parser_class;
        return -1;
    }
    _pipeline_cmd = config["pipeline_cmd"].as<std::string>();
    return 0;
}

class LineDataReader : public DataReader {
public:
    LineDataReader() {}
    virtual ~LineDataReader() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        if (DataReader::initialize(config, context) != 0) {
            return -1;
        }
        _done_file_name = config["done_file"].as<std::string>();
R
rensilin 已提交
54
        _filename_prefix = config["filename_prefix"].as<std::string>("");
R
rensilin 已提交
55 56 57

        if (config["file_system"] && config["file_system"]["class"]) {
            _file_system.reset(
X
xiexionghang 已提交
58
                    CREATE_INSTANCE(FileSystem, config["file_system"]["class"].as<std::string>()));
R
rensilin 已提交
59 60 61 62 63 64
            if (_file_system == nullptr ||
                _file_system->initialize(config["file_system"], context) != 0) {
                VLOG(2) << "fail to create class: "
                        << config["file_system"]["class"].as<std::string>();
                return -1;
            }
X
xiexionghang 已提交
65 66
        } else if (context->file_system != nullptr) { 
            _file_system = context->file_system;
R
rensilin 已提交
67
        } else {
X
xiexionghang 已提交
68
            _file_system.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
R
rensilin 已提交
69 70 71 72 73
            if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
                VLOG(2) << "fail to init file system";
                return -1;
            }
        }
R
rensilin 已提交
74 75 76 77 78
        return 0;
    }

    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) {
R
rensilin 已提交
79 80
        auto done_file_path = _file_system->path_join(data_dir, _done_file_name);
        if (_file_system->exists(done_file_path)) {
R
rensilin 已提交
81 82 83 84 85
            return true;
        }
        return false;
    }

R
rensilin 已提交
86 87
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
        std::vector<std::string> data_files;
R
rensilin 已提交
88 89
        for (auto& filepath : _file_system->list(data_dir)) {
            auto filename = _file_system->path_split(filepath).second;
R
rensilin 已提交
90 91
            if (filename != _done_file_name &&
                string::begin_with(filename, _filename_prefix)) {
R
rensilin 已提交
92 93 94 95 96 97
                data_files.push_back(std::move(filepath));
            }
        }
        return data_files;
    }

R
rensilin 已提交
98
    //读取数据样本流中
R
rensilin 已提交
99
    virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
X
xiexionghang 已提交
100 101 102 103
        auto file_list = data_file_list(data_dir);
        return read_all(file_list, data_channel);
    }
    virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) {
104
        data_channel->Open();
R
fs_bug  
rensilin 已提交
105
        const int file_list_size = file_list.size();
R
openmp  
rensilin 已提交
106
        std::atomic<bool> is_failed(false);
R
rensilin 已提交
107

R
fs_bug  
rensilin 已提交
108 109 110 111 112 113 114
        const int max_threads = omp_get_max_threads();
        std::vector<framework::ChannelWriter<DataItem>> writers; // writer is not thread safe
        writers.reserve(max_threads);
        for (int i = 0; i < max_threads; ++i) {
            writers.emplace_back(data_channel.get());
        }
        VLOG(5) << "file_list: " << string::join_strings(file_list, ' ');
R
rensilin 已提交
115 116
        #pragma omp parallel for
        for (int i = 0; i < file_list_size; ++i) {
R
fs_bug  
rensilin 已提交
117 118 119 120 121 122 123 124
            if (is_failed) {
                continue;
            }
            const int thread_num = omp_get_thread_num();
            framework::ChannelWriter<DataItem> *writer = nullptr;
            if (thread_num < max_threads) {
                writer = &writers[thread_num];
            }
R
rensilin 已提交
125
            const auto& filepath = file_list[i];
R
fs_bug  
rensilin 已提交
126 127 128 129 130 131 132 133 134 135
            std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
            if (fin == nullptr) {
                VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
                is_failed = true;
                continue;
            }
            char *buffer = nullptr;
            size_t buffer_size = 0;
            ssize_t line_len = 0;
            while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
R
rensilin 已提交
136
                // 去掉行尾回车
R
fs_bug  
rensilin 已提交
137 138 139 140 141
                if (line_len > 0 && buffer[line_len - 1] == '\n') {
                    buffer[--line_len] = '\0';
                }
                // 忽略空行
                if (line_len <= 0) {
R
openmp  
rensilin 已提交
142
                    continue;
R
rensilin 已提交
143
                }
R
fix  
rensilin 已提交
144
                DataItem data_item;
R
fs_bug  
rensilin 已提交
145 146 147 148
                if (_parser->parse(buffer, line_len, data_item) == 0) {
                    VLOG(5) << "parse data: " << data_item.id << " " << data_item.data << ", filename: " << filepath << ", thread_num: " << thread_num << ", max_threads: " << max_threads;
                    if (writer == nullptr) {
                        if (!data_channel->Put(std::move(data_item))) {
R
rensilin 已提交
149 150
                            LOG(WARNING) << "fail to put data, thread_num: " << thread_num;
                            is_failed = true;
R
fs_bug  
rensilin 已提交
151 152
                        }
                    } else {
R
rensilin 已提交
153
                        (*writer) << std::move(data_item);
R
rensilin 已提交
154
                    }
R
rensilin 已提交
155
                }
R
fs_bug  
rensilin 已提交
156 157 158 159 160 161 162 163 164 165
            }
            if (buffer != nullptr) {
                free(buffer);
                buffer = nullptr;
                buffer_size = 0;
            }
            if (ferror(fin.get()) != 0) {
                VLOG(2) << "fail to read file: " << filepath;
                is_failed = true;
                continue;
R
rensilin 已提交
166 167
            }
        }
R
fs_bug  
rensilin 已提交
168 169 170 171 172 173 174 175
        // omp end

        for (int i = 0; i < max_threads; ++i) {
            writers[i].Flush();
            if (!writers[i]) {
                VLOG(2) << "writer " << i << " is failed";
                is_failed = true;
            }
R
rensilin 已提交
176
        }
R
rensilin 已提交
177
        data_channel->Close();
R
openmp  
rensilin 已提交
178
        return is_failed ? -1 : 0;
R
rensilin 已提交
179 180 181 182 183
    }

    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
184

R
rensilin 已提交
185
private:
R
openmp  
rensilin 已提交
186
    std::string _done_file_name;  // without data_dir
R
rensilin 已提交
187
    std::string _filename_prefix;
X
xiexionghang 已提交
188
    std::shared_ptr<FileSystem> _file_system;
R
rensilin 已提交
189
};
X
xiexionghang 已提交
190
REGIST_CLASS(DataReader, LineDataReader);
R
rensilin 已提交
191

R
rensilin 已提交
192 193 194
}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle