data_reader.cc 4.7 KB
Newer Older
R
rensilin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"

#include <cstdio>

#include <glog/logging.h>

#include "paddle/fluid/framework/io/fs.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

class LineDataParser : public DataParser{
public:
    LineDataParser() {}

    virtual ~LineDataParser() {}

    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }

    virtual int parse(const char* str, size_t len, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
25
        while (pos < len && str[pos] != ' ') {
R
rensilin 已提交
26 27
            ++pos;
        }
R
rensilin 已提交
28 29 30 31 32
        if (pos >= len) {
            VLOG(2) << "fail to parse line" << std::string(str, len) << ", strlen: " << len;
            return -1;
        }
        VLOG(5) << "getline: "  << str << " , pos: " << pos << ", len: " << len;
R
rensilin 已提交
33 34 35 36 37 38 39
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1, len - pos - 1);
        return 0;
    }

    virtual int parse(const char* str, DataItem& data) const {
        size_t pos = 0;
R
rensilin 已提交
40
        while (str[pos] != '\0' && str[pos] != ' ') {
R
rensilin 已提交
41 42
            ++pos;
        }
R
rensilin 已提交
43 44 45 46 47
        if (str[pos] == '\0') {
            VLOG(2) << "fail to parse line" << str << ", get '\\0' at pos: " << pos;
            return -1;
        }
        VLOG(5) << "getline: "  << str << " , pos: " << pos;
R
rensilin 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1);
        return 0;
    }

    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};
REGISTER_CLASS(DataParser, LineDataParser);

// Creates the parser named by config["parser"]["class"], initializes it with
// config["parser"], and stores config["pipeline_cmd"] in _pipeline_cmd.
// Returns 0 on success, -1 if the parser cannot be created or initialized.
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    const std::string parser_class = config["parser"]["class"].as<std::string>();

    _parser.reset(CREATE_CLASS(DataParser, parser_class));
    if (_parser == nullptr) {
        VLOG(2) << "fail to get parser: " << parser_class;
        return -1;
    }

    const int parser_status = _parser->initialize(config["parser"], context);
    if (parser_status != 0) {
        VLOG(2) << "fail to initialize parser" << parser_class;
        return -1;
    }

    _pipeline_cmd = config["pipeline_cmd"].as<std::string>();
    return 0;
}

// DataReader that walks every file listed under a data directory, reads it
// line by line through a fixed-size buffer, parses each line with _parser and
// writes the resulting DataItem into a channel.
class LineDataReader : public DataReader {
public:
    LineDataReader() {}
    virtual ~LineDataReader() {}

    // Runs DataReader::initialize, then reads "done_file" and the optional
    // "buffer_size" (default 1024) from `config` and allocates the line buffer.
    // Returns 0 on success, -1 if the base initialization fails or the
    // configured buffer size is not positive.
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        if (DataReader::initialize(config, context) != 0) {
            return -1;
        }
        _done_file_name = config["done_file"].as<std::string>();
        // Validate before allocating: a zero or negative size must never reach
        // new char[].
        const int buffer_size = config["buffer_size"].as<int>(1024);
        if (buffer_size <= 0) {
            VLOG(2) << "invalid buffer_size: " << buffer_size;
            return -1;
        }
        _buffer_size = buffer_size;
        _buffer.reset(new char[_buffer_size]);
        return 0;
    }

    // Checks whether the sample data is ready, i.e. whether the done file
    // exists under `data_dir`; ready means downloading can start.
    virtual bool is_data_ready(const std::string& data_dir) {
        auto done_file_path = ::paddle::framework::fs_path_join(data_dir, _done_file_name);
        return ::paddle::framework::fs_exists(done_file_path);
    }

    // Reads all data samples under `data_dir` into `data_channel`.
    // The done file is skipped. Returns 0 on success, -1 when the buffer is
    // missing, a file cannot be opened or read, a line fails to parse, or the
    // channel writer reports failure after the final flush.
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) {
        ::paddle::framework::ChannelWriter<DataItem> writer(data_channel.get());
        DataItem data_item;
        if (_buffer_size <= 0 || _buffer == nullptr) {
            VLOG(2) << "no buffer";
            return -1;
        }
        for (const auto& filename : ::paddle::framework::fs_list(data_dir)) {
            if (::paddle::framework::fs_path_split(filename).second == _done_file_name) {
                continue;
            }
            // Start from 0 so the check below never reads an indeterminate value
            // should fs_open_read leave err_no unwritten.
            int err_no = 0;
            std::shared_ptr<FILE> fin = ::paddle::framework::fs_open_read(filename, &err_no, _pipeline_cmd);
            // A null handle is treated as an open failure as well; passing it
            // to fgets/ferror would be invalid.
            if (err_no != 0 || fin == nullptr) {
                VLOG(2) << "fail to open file: " << filename << ", with cmd: " << _pipeline_cmd;
                return -1;
            }
            // NOTE(review): fgets keeps the trailing '\n' and stores at most
            // _buffer_size - 1 characters, so the newline ends up in
            // DataItem::data and a longer line is parsed in several pieces.
            // Confirm whether downstream consumers rely on this.
            while (fgets(_buffer.get(), _buffer_size, fin.get())) {
                if (_parser->parse(_buffer.get(), data_item) != 0) {
                    return -1;
                }
                writer << std::move(data_item);
            }
            // fgets returns null on both end-of-file and error; tell them apart.
            if (ferror(fin.get()) != 0) {
                VLOG(2) << "fail to read file: " << filename;
                return -1;
            }
        }
        writer.Flush();
        if (!writer) {
            VLOG(2) << "fail when write to channel";
            return -1;
        }
        return 0;
    }

    // Returns the parser held by this reader as a non-owning pointer.
    virtual const DataParser* get_parser() {
        return _parser.get();
    }
private:
    std::string _done_file_name; // without data_dir
    int _buffer_size = 0;        // size in bytes of _buffer; 0 until initialize succeeds
    std::unique_ptr<char[]> _buffer; // line buffer handed to fgets
};
REGISTER_CLASS(DataReader, LineDataReader);

}//namespace feed
}//namespace custom_trainer
}//namespace paddle