// data_feed.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_

#include <memory>
#include <set>
#include <map>
#include <string>
#include <thread>               // NOLINT
#include <vector>
#include <queue>
#include <mutex>                // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <condition_variable>   // NOLINT
#include <fstream>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
// Plain record attached to an Instance (see Instance::gauc_vec).
// NOTE(review): the name suggests per-sample GAUC statistics; the exact
// meaning of each field is not visible in this header -- confirm with the
// code that fills it in.
struct Gauc {
  int show;            // presumably a show/impression count -- TODO confirm
  int click;           // presumably a click count -- TODO confirm
  uint64_t fea;        // presumably a feature id/signature -- TODO confirm
  std::string lineid;  // presumably the id of the source line -- TODO confirm
};

// Plain aggregate handed to DataFeed::GetFeedVec(const Instance&).
// It only groups the four containers below; how they are populated and
// consumed is not visible in this header.
struct Instance {
  // Nested vectors of 64-bit values.
  // NOTE(review): presumably one inner vector per feed slot -- confirm.
  std::vector<std::vector<uint64_t>> feed_vec_buffer;
  // Nested vectors of ints.
  // NOTE(review): presumably LoD offsets matching feed_vec_buffer -- confirm.
  std::vector<std::vector<int>> feed_vec_lod;
  // Additional float labels.
  std::vector<float> other_label;
  // Gauc records associated with this instance.
  std::vector<Gauc> gauc_vec;
};

class DataFeed {
 public:
52
  DataFeed() : default_batch_size_(1), batch_size_(0), thread_id_(0) {}
W
wangguibao 已提交
53
  virtual ~DataFeed() {}
W
wangguibao 已提交
54
  virtual void Init() = 0;
W
wangguibao 已提交
55 56 57 58 59
  /*
  * This function will be used to check file format.
  * Considering that this function may be used alone,
  * it does not check anything.
  * */
W
wangguibao 已提交
60 61 62 63 64
  virtual bool CheckFile(const char* filename) = 0;
  virtual bool SetFile(const char* filename) = 0;
  virtual bool ReadBatch() = 0;
  virtual const std::vector<uint16_t>& GetAllSlotIds() {
    return all_slot_ids_;
W
wangguibao 已提交
65 66
  }

W
wangguibao 已提交
67 68
  virtual const std::vector<uint16_t>& GetUseSlotIds() {
    return use_slot_ids_;
W
wangguibao 已提交
69 70
  }

W
wangguibao 已提交
71 72
  virtual const std::vector<std::string>& GetUseSlotAlias() {
    return use_slot_alias_;
W
wangguibao 已提交
73 74
  }

W
wangguibao 已提交
75
  virtual void AddFeedVar(Variable* var,
W
wangguibao 已提交
76
                            const std::string& name) = 0;
W
wangguibao 已提交
77 78 79 80
  virtual void BindScope(Scope* scope) = 0;
  virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
  virtual int GetBatchSize() { return batch_size_; }
  virtual void SetBufferSize(int buffer_size) {}
81 82 83 84 85
  virtual unsigned int GetCurrentEpoch() = 0;
  virtual const char *PickOneFile() = 0;
  virtual void UpdateEpochNum() = 0;
  virtual void StartOneEpoch() = 0;
  virtual void WaitNextEpoch() = 0;
W
wangguibao 已提交
86

W
wangguibao 已提交
87 88
  std::vector<LoDTensor*>& GetFeedVec() {
    return feed_vec_;
W
wangguibao 已提交
89 90
  }

W
wangguibao 已提交
91
  virtual std::vector<LoDTensor*>& GetFeedVec(const Instance& ins) {
W
wangguibao 已提交
92
    LOG(ERROR) << "use defalut get_feed_vec";
W
wangguibao 已提交
93
    return feed_vec_;
W
wangguibao 已提交
94 95
  }

96 97 98
  int GetThreadId() {return thread_id_;}
  void SetThreadId(int thread_id) {thread_id_ = thread_id;}

W
wangguibao 已提交
99
 protected:
W
wangguibao 已提交
100 101 102 103 104 105
  std::vector<uint16_t> all_slot_ids_;
  std::vector<uint16_t> use_slot_ids_;
  std::vector<std::string> use_slot_alias_;
  std::vector<LoDTensor*> feed_vec_;
  int default_batch_size_;
  int batch_size_;
106
  int thread_id_;
W
wangguibao 已提交
107 108 109
};

class TextClassDataFeed : public DataFeed {
110 111 112 113
 public:
  TextClassDataFeed();
  TextClassDataFeed(const TextClassDataFeed& data_feed);

W
wangguibao 已提交
114 115
 public:
  virtual ~TextClassDataFeed() {}
W
wangguibao 已提交
116
  virtual void Init();
W
wangguibao 已提交
117 118 119 120 121
  virtual bool ReadBatch();
  virtual void AddFeedVar(Variable* feed, const std::string& name);
  virtual void BindScope(Scope* scope) {}
  virtual bool SetFile(const char* filename);
  virtual bool CheckFile(const char* filename) {
W
wangguibao 已提交
122 123 124
    // TODO(xxx)
    return false;
  }
W
wangguibao 已提交
125
  void SetBatchSize(int batch) {batch_size_ = batch;}
126 127 128 129 130 131 132 133 134 135 136 137 138
  unsigned int GetCurrentEpoch() {return s_current_epoch_;}
  void UpdateEpochNum();
  void StartOneEpoch();
  void WaitNextEpoch();

 public:
  void SetFieldNames(const std::vector<std::string>& field_names);

 public:
  static void SetFileList(const char* filelist);

 private:
  const char* PickOneFile();
W
wangguibao 已提交
139 140

 private:
W
wangguibao 已提交
141 142 143 144 145
  char* file_content_buffer_;
  char* file_content_buffer_ptr_;
  int* batch_id_buffer_;
  int* label_ptr_;
  int file_size_;
146
  std::vector<std::string> field_names_;
W
wangguibao 已提交
147 148 149
  std::shared_ptr<char> file_content_buffer_host_;
  std::shared_ptr<int> batch_id_host_;
  std::shared_ptr<int> label_host_;
150 151 152 153 154 155 156 157 158 159

  static std::vector<std::string> s_filelist_;
  static std::mutex s_locker_for_pick_file_;
  static unsigned int s_current_file_idx_;
  static size_t s_current_finished_file_cnt_;
  static unsigned int s_current_epoch_;
  static int s_current_save_epoch_;
  static std::mutex s_locker_epoch_start_;
  static std::condition_variable s_condition_epoch_start_;
  static bool s_epoch_start_flag_;
W
wangguibao 已提交
160 161 162 163 164 165 166
};

}   // namespace framework
}   // namespace paddle

#endif  // PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */