data_reader.cc 7.1 KB
Newer Older
R
rensilin 已提交
1 2 3 4 5 6
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"

#include <cstdio>

#include <glog/logging.h>

R
rensilin 已提交
7
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
R
rensilin 已提交
8 9 10 11 12

namespace paddle {
namespace custom_trainer {
namespace feed {

R
rensilin 已提交
13
class LineDataParser : public DataParser {
R
rensilin 已提交
14 15 16 17 18 19 20 21 22 23 24
public:
    LineDataParser() {}

    virtual ~LineDataParser() {}

    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }

    virtual int parse(const char* str, size_t len, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
25
        while (pos < len && str[pos] != ' ') {
R
rensilin 已提交
26 27
            ++pos;
        }
R
rensilin 已提交
28
        if (pos >= len) {
R
rensilin 已提交
29
            VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
R
rensilin 已提交
30 31
            return -1;
        }
R
rensilin 已提交
32
        VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
R
rensilin 已提交
33 34
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1, len - pos - 1);
R
rensilin 已提交
35 36 37
        if (!data.data.empty() && data.data.back() == '\n') {
            data.data.pop_back();
        }
R
rensilin 已提交
38 39 40 41 42
        return 0;
    }

    virtual int parse(const char* str, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
43
        while (str[pos] != '\0' && str[pos] != ' ') {
R
rensilin 已提交
44 45
            ++pos;
        }
R
rensilin 已提交
46
        if (str[pos] == '\0') {
R
rensilin 已提交
47
            VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos;
R
rensilin 已提交
48 49
            return -1;
        }
R
rensilin 已提交
50
        VLOG(5) << "getline: " << str << " , pos: " << pos;
R
rensilin 已提交
51 52
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1);
R
rensilin 已提交
53 54 55
        if (!data.data.empty() && data.data.back() == '\n') {
            data.data.pop_back();
        }
R
rensilin 已提交
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
        return 0;
    }

    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};
REGISTER_CLASS(DataParser, LineDataParser);

// Builds the parser named by config["parser"]["class"], initializes it with
// the config["parser"] subtree, then stores config["pipeline_cmd"].
// Returns 0 on success and -1 when the parser cannot be created or fails to
// initialize.
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    const YAML::Node parser_config = config["parser"];
    const std::string parser_class = parser_config["class"].as<std::string>();

    _parser.reset(CREATE_CLASS(DataParser, parser_class));
    if (!_parser) {
        VLOG(2) << "fail to get parser: " << parser_class;
        return -1;
    }

    const int init_status = _parser->initialize(parser_config, context);
    if (init_status != 0) {
        VLOG(2) << "fail to initialize parser" << parser_class;
        return -1;
    }

    _pipeline_cmd = config["pipeline_cmd"].as<std::string>();
    return 0;
}

class LineDataReader : public DataReader {
public:
    LineDataReader() {}
    virtual ~LineDataReader() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        if (DataReader::initialize(config, context) != 0) {
            return -1;
        }
        _done_file_name = config["done_file"].as<std::string>();
        _buffer_size = config["buffer_size"].as<int>(1024);
R
rensilin 已提交
89
        _filename_prefix = config["filename_prefix"].as<std::string>("");
R
rensilin 已提交
90
        _buffer.reset(new char[_buffer_size]);
R
rensilin 已提交
91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107

        if (config["file_system"] && config["file_system"]["class"]) {
            _file_system.reset(
                    CREATE_CLASS(FileSystem, config["file_system"]["class"].as<std::string>()));
            if (_file_system == nullptr ||
                _file_system->initialize(config["file_system"], context) != 0) {
                VLOG(2) << "fail to create class: "
                        << config["file_system"]["class"].as<std::string>();
                return -1;
            }
        } else {
            _file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
            if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
                VLOG(2) << "fail to init file system";
                return -1;
            }
        }
R
rensilin 已提交
108 109 110 111 112
        return 0;
    }

    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) {
R
rensilin 已提交
113 114
        auto done_file_path = _file_system->path_join(data_dir, _done_file_name);
        if (_file_system->exists(done_file_path)) {
R
rensilin 已提交
115 116 117 118 119
            return true;
        }
        return false;
    }

R
rensilin 已提交
120 121
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
        if (_filename_prefix.empty()) {
R
rensilin 已提交
122
            return _file_system->list(data_dir);
R
rensilin 已提交
123 124
        }
        std::vector<std::string> data_files;
R
rensilin 已提交
125 126 127 128
        for (auto& filepath : _file_system->list(data_dir)) {
            auto filename = _file_system->path_split(filepath).second;
            if (filename.size() >= _filename_prefix.size() &&
                filename.substr(0, _filename_prefix.size()) == _filename_prefix) {
R
rensilin 已提交
129 130 131 132 133 134
                data_files.push_back(std::move(filepath));
            }
        }
        return data_files;
    }

R
rensilin 已提交
135
    //读取数据样本流中
R
rensilin 已提交
136
    virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
R
rensilin 已提交
137 138 139 140 141 142 143 144
        auto deleter = [](framework::ChannelWriter<DataItem> *writer) {
            if (writer) {
                writer->Flush();
                VLOG(3) << "writer auto flush";
            }
            delete writer;
        };
        std::unique_ptr<framework::ChannelWriter<DataItem>, decltype(deleter)> writer(new framework::ChannelWriter<DataItem>(data_channel.get()), deleter);
R
rensilin 已提交
145 146 147 148 149
        DataItem data_item;
        if (_buffer_size <= 0 || _buffer == nullptr) {
            VLOG(2) << "no buffer";
            return -1;
        }
R
rensilin 已提交
150
        for (const auto& filepath : data_file_list(data_dir)) {
R
rensilin 已提交
151
            if (_file_system->path_split(filepath).second == _done_file_name) {
R
rensilin 已提交
152 153
                continue;
            }
R
rensilin 已提交
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
            {
                std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
                if (fin == nullptr) {
                    VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
                    return -1;
                }
                while (fgets(_buffer.get(), _buffer_size, fin.get())) {
                    if (_buffer[0] == '\n') {
                        continue;
                    }
                    if (_parser->parse(_buffer.get(), data_item) != 0) {
                        return -1;
                    }
                    (*writer) << std::move(data_item);
                }
                if (ferror(fin.get()) != 0) {
                    VLOG(2) << "fail to read file: " << filepath;
R
rensilin 已提交
171 172 173
                    return -1;
                }
            }
R
fix  
rensilin 已提交
174
            if (_file_system->err_no() != 0) {
R
rensilin 已提交
175
                _file_system->reset_err_no();
R
rensilin 已提交
176 177 178
                return -1;
            }
        }
R
rensilin 已提交
179 180
        writer->Flush();
        if (!(*writer)) {
R
rensilin 已提交
181 182 183
            VLOG(2) << "fail when write to channel";
            return -1;
        }
R
rensilin 已提交
184
        data_channel->Close();
R
rensilin 已提交
185 186 187 188 189 190
        return 0;
    }

    virtual const DataParser* get_parser() {
        return _parser.get();
    }
R
rensilin 已提交
191

R
rensilin 已提交
192
private:
R
rensilin 已提交
193
    std::string _done_file_name;  // without data_dir
R
rensilin 已提交
194 195
    int _buffer_size = 0;
    std::unique_ptr<char[]> _buffer;
R
rensilin 已提交
196
    std::string _filename_prefix;
R
rensilin 已提交
197
    std::unique_ptr<FileSystem> _file_system;
R
rensilin 已提交
198 199 200
};
REGISTER_CLASS(DataReader, LineDataReader);

R
rensilin 已提交
201 202 203
}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle