data_feed.h 3.8 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_

#include <memory>
#include <set>
#include <map>
#include <string>
#include <thread>               // NOLINT
#include <vector>
#include <queue>
#include <mutex>                // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <condition_variable>   // NOLINT
#include <fstream>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "proto/FeedDataParameter.pb.h"

namespace paddle {
namespace framework {
// Plain aggregate holding one per-feature record attached to an Instance
// (see Instance::gauc_vec).
// NOTE(review): the name suggests these records feed a grouped-AUC metric;
// confirm against the code that fills and consumes them.
struct Gauc {
  // Presumably the impression ("show") count for this record - TODO confirm.
  int show;
  // Presumably the click count for this record - TODO confirm.
  int click;
  // 64-bit feature value/id; exact meaning is defined by the producer.
  uint64_t fea;
  // String identifier of the input line this record is associated with
  // (assumed from the name - verify against the parser).
  std::string lineid;
};

// Plain aggregate bundling the buffers that describe one parsed input
// instance. Member order is part of the aggregate's layout and must not be
// changed.
struct Instance {
  // One inner vector of 64-bit values per outer entry.
  // NOTE(review): outer index is presumably the slot / feed variable - confirm.
  std::vector<std::vector<uint64_t>> feed_vec_buffer;

  // One inner vector of ints per outer entry.
  // NOTE(review): presumably LoD (level-of-detail) offsets matching
  // feed_vec_buffer - confirm against the code that fills it.
  std::vector<std::vector<int>> feed_vec_lod;

  // Additional float labels carried with the instance.
  std::vector<float> other_label;

  // Per-feature Gauc records carried with the instance.
  std::vector<Gauc> gauc_vec;
};

class DataFeed {
 public:
  DataFeed() {}
  virtual ~DataFeed() {}
W
wangguibao 已提交
55
  virtual void Init(const datafeed::DataFeedParameter& feed_param) = 0;
W
wangguibao 已提交
56 57 58 59 60
  /*
  * This function will be used to check file format.
  * Considering that this function may be used alone,
  * it does not check anything.
  * */
W
wangguibao 已提交
61 62 63 64 65
  virtual bool CheckFile(const char* filename) = 0;
  virtual bool SetFile(const char* filename) = 0;
  virtual bool ReadBatch() = 0;
  virtual const std::vector<uint16_t>& GetAllSlotIds() {
    return all_slot_ids_;
W
wangguibao 已提交
66 67
  }

W
wangguibao 已提交
68 69
  virtual const std::vector<uint16_t>& GetUseSlotIds() {
    return use_slot_ids_;
W
wangguibao 已提交
70 71
  }

W
wangguibao 已提交
72 73
  virtual const std::vector<std::string>& GetUseSlotAlias() {
    return use_slot_alias_;
W
wangguibao 已提交
74 75
  }

W
wangguibao 已提交
76
  virtual void AddFeedVar(Variable* var,
W
wangguibao 已提交
77
                            const std::string& name) = 0;
W
wangguibao 已提交
78 79 80 81
  virtual void BindScope(Scope* scope) = 0;
  virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
  virtual int GetBatchSize() { return batch_size_; }
  virtual void SetBufferSize(int buffer_size) {}
W
wangguibao 已提交
82

W
wangguibao 已提交
83 84
  std::vector<LoDTensor*>& GetFeedVec() {
    return feed_vec_;
W
wangguibao 已提交
85 86
  }

W
wangguibao 已提交
87
  virtual std::vector<LoDTensor*>& GetFeedVec(const Instance& ins) {
W
wangguibao 已提交
88
    LOG(ERROR) << "use defalut get_feed_vec";
W
wangguibao 已提交
89
    return feed_vec_;
W
wangguibao 已提交
90 91 92
  }

 protected:
W
wangguibao 已提交
93 94 95 96 97 98
  std::vector<uint16_t> all_slot_ids_;
  std::vector<uint16_t> use_slot_ids_;
  std::vector<std::string> use_slot_alias_;
  std::vector<LoDTensor*> feed_vec_;
  int default_batch_size_;
  int batch_size_;
W
wangguibao 已提交
99 100 101 102 103
};

class TextClassDataFeed : public DataFeed {
 public:
  virtual ~TextClassDataFeed() {}
W
wangguibao 已提交
104 105 106 107 108
  virtual void Init(const datafeed::DataFeedParameter& feed_param);
  virtual bool ReadBatch();
  virtual void AddFeedVar(Variable* feed, const std::string& name);
  virtual void BindScope(Scope* scope) {}
  virtual bool SetFile(const char* filename);
W
wangguibao 已提交
109

W
wangguibao 已提交
110
  virtual bool CheckFile(const char* filename) {
W
wangguibao 已提交
111 112 113 114
    // TODO(xxx)
    return false;
  }

W
wangguibao 已提交
115
  void SetBatchSize(int batch) {batch_size_ = batch;}
W
wangguibao 已提交
116 117

 private:
W
wangguibao 已提交
118 119 120 121 122 123 124 125 126 127
  int ReadWholeFile(const std::string& filename, char* buffer);
  char* file_content_buffer_;
  char* file_content_buffer_ptr_;
  int* batch_id_buffer_;
  int* label_ptr_;
  int file_size_;
  std::vector<std::string> names_;
  std::shared_ptr<char> file_content_buffer_host_;
  std::shared_ptr<int> batch_id_host_;
  std::shared_ptr<int> label_host_;
W
wangguibao 已提交
128 129 130 131 132 133 134
};

}   // namespace framework
}   // namespace paddle

#endif  // PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */