// data_reader.cc (5.0 KB)
// Extracted from a git blame view; lines were last committed by rensilin.
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"

#include <cstdio>
#include <cstring>
#include <string>

#include <glog/logging.h>

#include "paddle/fluid/framework/io/fs.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

class LineDataParser : public DataParser{
public:
    LineDataParser() {}

    virtual ~LineDataParser() {}

    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }

    virtual int parse(const char* str, size_t len, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
25
        while (pos < len && str[pos] != ' ') {
R
rensilin 已提交
26 27
            ++pos;
        }
R
rensilin 已提交
28
        if (pos >= len) {
R
rensilin 已提交
29
            VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
R
rensilin 已提交
30 31 32
            return -1;
        }
        VLOG(5) << "getline: "  << str << " , pos: " << pos << ", len: " << len;
R
rensilin 已提交
33 34
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1, len - pos - 1);
R
rensilin 已提交
35 36 37
        if (!data.data.empty() && data.data.back() == '\n') {
            data.data.pop_back();
        }
R
rensilin 已提交
38 39 40 41 42
        return 0;
    }

    virtual int parse(const char* str, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
43
        while (str[pos] != '\0' && str[pos] != ' ') {
R
rensilin 已提交
44 45
            ++pos;
        }
R
rensilin 已提交
46
        if (str[pos] == '\0') {
R
rensilin 已提交
47
            VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos;
R
rensilin 已提交
48 49 50
            return -1;
        }
        VLOG(5) << "getline: "  << str << " , pos: " << pos;
R
rensilin 已提交
51 52
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1);
R
rensilin 已提交
53 54 55
        if (!data.data.empty() && data.data.back() == '\n') {
            data.data.pop_back();
        }
R
rensilin 已提交
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
        return 0;
    }

    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};
REGISTER_CLASS(DataParser, LineDataParser);

/**
 * Creates the parser named by config["parser"]["class"] through the class
 * registry, initializes it with config["parser"], and stores
 * config["pipeline_cmd"] for use when opening input files.
 * NOTE(review): yaml-cpp's as<std::string>() presumably throws if a key is
 * missing or not convertible — confirm callers expect that.
 * @return 0 on success, -1 if the parser cannot be created or initialized.
 */
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    // Look up and convert the class name once instead of in every branch.
    const std::string parser_class = config["parser"]["class"].as<std::string>();
    _parser.reset(CREATE_CLASS(DataParser, parser_class));
    if (_parser == nullptr) {
        VLOG(2) << "fail to get parser: " << parser_class;
        return -1;
    }
    if (_parser->initialize(config["parser"], context) != 0) {
        VLOG(2) << "fail to initialize parser: " << parser_class;
        return -1;
    }
    _pipeline_cmd = config["pipeline_cmd"].as<std::string>();
    return 0;
}

class LineDataReader : public DataReader {
public:
    LineDataReader() {}
    virtual ~LineDataReader() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        if (DataReader::initialize(config, context) != 0) {
            return -1;
        }
        _done_file_name = config["done_file"].as<std::string>();
        _buffer_size = config["buffer_size"].as<int>(1024);
        _buffer.reset(new char[_buffer_size]);
        return 0;
    }

    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) {
        auto done_file_path = ::paddle::framework::fs_path_join(data_dir, _done_file_name);
        if (::paddle::framework::fs_exists(done_file_path)) {
            return true;
        }
        return false;
    }

    //读取数据样本流中
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) {
        ::paddle::framework::ChannelWriter<DataItem> writer(data_channel.get());
        DataItem data_item;
        if (_buffer_size <= 0 || _buffer == nullptr) {
            VLOG(2) << "no buffer";
            return -1;
        }
        for (const auto& filename : ::paddle::framework::fs_list(data_dir)) {
            if (::paddle::framework::fs_path_split(filename).second == _done_file_name) {
                continue;
            }
R
rensilin 已提交
114
            int err_no = 0;
R
rensilin 已提交
115 116
            std::shared_ptr<FILE> fin = ::paddle::framework::fs_open_read(filename, &err_no, _pipeline_cmd);
            if (err_no != 0) {
R
rensilin 已提交
117
                VLOG(2) << "fail to open file: " << filename << ", with cmd: " << _pipeline_cmd;
R
rensilin 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
                return -1;
            }
            while (fgets(_buffer.get(), _buffer_size, fin.get())) {
                if (_parser->parse(_buffer.get(), data_item) != 0) {
                    return -1;
                }
                writer << std::move(data_item);
            }
            if (ferror(fin.get()) != 0) {
                VLOG(2) << "fail to read file: " << filename;
                return -1;
            }
        }
        writer.Flush();
        if (!writer) {
            VLOG(2) << "fail when write to channel";
            return -1;
        }
R
rensilin 已提交
136
        data_channel->Close();
R
rensilin 已提交
137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
        return 0;
    }

    virtual const DataParser* get_parser() {
        return _parser.get();
    }
private:
    std::string _done_file_name; // without data_dir
    int _buffer_size = 0;
    std::unique_ptr<char[]> _buffer;
};
REGISTER_CLASS(DataReader, LineDataReader);

}//namespace feed
}//namespace custom_trainer
}//namespace paddle