data_reader.cc 4.4 KB
Newer Older
R
rensilin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"

#include <cstdio>

#include <glog/logging.h>

#include "paddle/fluid/framework/io/fs.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

// Parser for text records of the form "<id> <payload>": everything before the
// first space is stored in DataItem::id and everything after it in
// DataItem::data.
class LineDataParser : public DataParser{
public:
    LineDataParser() {}

    virtual ~LineDataParser() {}

    // This parser reads nothing from config or context; always returns 0.
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }

    // Splits the first `len` bytes of `str` at the first space.
    // Returns 0 on success, or -1 when those bytes contain no space.
    virtual int parse(const char* str, size_t len, DataItem& data) const {
        size_t pos = 0;
        // Test the bound before dereferencing so that str[len] is never read.
        while (pos < len && str[pos] != ' ') {
            ++pos;
        }
        if (pos >= len) {
            VLOG(2) << "fail to parse line, strlen: " << len;
            return -1;
        }
        data.id.assign(str, pos);
        // pos < len here, so len - pos - 1 cannot underflow.
        data.data.assign(str + pos + 1, len - pos - 1);
        return 0;
    }

    // Splits the NUL-terminated string `str` at the first space.
    // Returns 0 on success, or -1 when the terminator is reached before a space.
    virtual int parse(const char* str, DataItem& data) const {
        size_t pos = 0;
        while (str[pos] != ' ') {
            if (str[pos] == '\0') {
                VLOG(2) << "fail to parse line, get '\\0' at pos: " << pos;
                return -1;
            }
            ++pos;
        }
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1);
        return 0;
    }

    // Placeholder: leaves `instance` untouched and returns 0.
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};
REGISTER_CLASS(DataParser, LineDataParser);

// Creates the parser named by config["parser"]["class"], initializes it with
// config["parser"], and stores config["pipeline_cmd"].
// Returns 0 on success, or -1 when the parser cannot be created or initialized.
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    // Look the class name up once; it is needed for creation and for both logs.
    const std::string parser_class = config["parser"]["class"].as<std::string>();
    _parser.reset(CREATE_CLASS(DataParser, parser_class));
    if (_parser == nullptr) {
        VLOG(2) << "fail to get parser: " << parser_class;
        return -1;
    }
    if (_parser->initialize(config["parser"], context) != 0) {
        VLOG(2) << "fail to initialize parser: " << parser_class;
        return -1;
    }
    _pipeline_cmd = config["pipeline_cmd"].as<std::string>();
    return 0;
}

class LineDataReader : public DataReader {
public:
    LineDataReader() {}
    virtual ~LineDataReader() {}
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        if (DataReader::initialize(config, context) != 0) {
            return -1;
        }
        _done_file_name = config["done_file"].as<std::string>();
        _buffer_size = config["buffer_size"].as<int>(1024);
        _buffer.reset(new char[_buffer_size]);
        return 0;
    }

    //判断样本数据是否已就绪,就绪表明可以开始download
    virtual bool is_data_ready(const std::string& data_dir) {
        auto done_file_path = ::paddle::framework::fs_path_join(data_dir, _done_file_name);
        if (::paddle::framework::fs_exists(done_file_path)) {
            return true;
        }
        return false;
    }

    //读取数据样本流中
    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) {
        ::paddle::framework::ChannelWriter<DataItem> writer(data_channel.get());
        DataItem data_item;
        if (_buffer_size <= 0 || _buffer == nullptr) {
            VLOG(2) << "no buffer";
            return -1;
        }
        for (const auto& filename : ::paddle::framework::fs_list(data_dir)) {
            if (::paddle::framework::fs_path_split(filename).second == _done_file_name) {
                continue;
            }
            int err_no;
            std::shared_ptr<FILE> fin = ::paddle::framework::fs_open_read(filename, &err_no, _pipeline_cmd);
            if (err_no != 0) {
                return -1;
            }
            while (fgets(_buffer.get(), _buffer_size, fin.get())) {
                if (_parser->parse(_buffer.get(), data_item) != 0) {
                    return -1;
                }
                writer << std::move(data_item);
            }
            if (ferror(fin.get()) != 0) {
                VLOG(2) << "fail to read file: " << filename;
                return -1;
            }
        }
        writer.Flush();
        if (!writer) {
            VLOG(2) << "fail when write to channel";
            return -1;
        }
        return 0;
    }

    virtual const DataParser* get_parser() {
        return _parser.get();
    }
private:
    std::string _done_file_name; // without data_dir
    int _buffer_size = 0;
    std::unique_ptr<char[]> _buffer;
};
REGISTER_CLASS(DataReader, LineDataReader);

}//namespace feed
}//namespace custom_trainer
}//namespace paddle