// data_feed.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_

#include <memory>
#include <set>
#include <map>
#include <string>
#include <thread>               // NOLINT
#include <vector>
#include <queue>
#include <mutex>                // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <condition_variable>   // NOLINT
#include <fstream>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
// Plain record bundling two integer counters, a 64-bit feature value and a
// line identifier. Instances are stored in Instance::gauc_vec.
// NOTE(review): the name suggests these are inputs to a grouped-AUC (GAUC)
// metric -- confirm against the code that fills and consumes gauc_vec.
struct Gauc {
  int show;            // presumably an impression/show count -- TODO confirm
  int click;           // presumably a click count -- TODO confirm
  uint64_t fea;        // 64-bit feature value; exact meaning not visible here
  std::string lineid;  // identifier of the source line (per the field name)
};

// Aggregate accepted by DataFeed::GetFeedVec(const Instance&).
// NOTE(review): nothing in this header reads or writes these fields, so the
// per-field notes below are inferred from names and types only -- confirm
// them against the reader implementation before relying on them.
struct Instance {
  // Nested vectors of 64-bit unsigned values (presumably one inner vector
  // per feed variable -- TODO confirm).
  std::vector<std::vector<uint64_t>> feed_vec_buffer;
  // Nested vectors of ints (presumably the LoD offsets that accompany
  // feed_vec_buffer -- TODO confirm).
  std::vector<std::vector<int>> feed_vec_lod;
  // Float values kept alongside the feed data (presumably extra labels).
  std::vector<float> other_label;
  // Gauc records attached to this instance.
  std::vector<Gauc> gauc_vec;
};

50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
class DataFeed {
  DataFeed() {}
  virtual ~DataFeed() {}
};

class BlockingQueueDataFeed : DataFeed {
  BlockingQueueDataFeed() {}
  virtual ~BlockingQueueDataFeed() {}
};

class ThreadedDataFeed : DataFeed {
  ThreadedDataFeed() {}
  virtual ~ThreadedDataFeed() {}
};

W
wangguibao 已提交
65 66 67 68
class DataFeed {
 public:
  DataFeed() {}
  virtual ~DataFeed() {}
W
wangguibao 已提交
69
  virtual void Init() = 0;
W
wangguibao 已提交
70 71 72 73 74
  /*
  * This function will be used to check file format.
  * Considering that this function may be used alone,
  * it does not check anything.
  * */
W
wangguibao 已提交
75 76 77 78 79
  virtual bool CheckFile(const char* filename) = 0;
  virtual bool SetFile(const char* filename) = 0;
  virtual bool ReadBatch() = 0;
  virtual const std::vector<uint16_t>& GetAllSlotIds() {
    return all_slot_ids_;
W
wangguibao 已提交
80 81
  }

W
wangguibao 已提交
82 83
  virtual const std::vector<uint16_t>& GetUseSlotIds() {
    return use_slot_ids_;
W
wangguibao 已提交
84 85
  }

W
wangguibao 已提交
86 87
  virtual const std::vector<std::string>& GetUseSlotAlias() {
    return use_slot_alias_;
W
wangguibao 已提交
88 89
  }

W
wangguibao 已提交
90
  virtual void AddFeedVar(Variable* var,
W
wangguibao 已提交
91
                            const std::string& name) = 0;
W
wangguibao 已提交
92 93 94 95
  virtual void BindScope(Scope* scope) = 0;
  virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
  virtual int GetBatchSize() { return batch_size_; }
  virtual void SetBufferSize(int buffer_size) {}
W
wangguibao 已提交
96

W
wangguibao 已提交
97 98
  std::vector<LoDTensor*>& GetFeedVec() {
    return feed_vec_;
W
wangguibao 已提交
99 100
  }

W
wangguibao 已提交
101
  virtual std::vector<LoDTensor*>& GetFeedVec(const Instance& ins) {
W
wangguibao 已提交
102
    LOG(ERROR) << "use defalut get_feed_vec";
W
wangguibao 已提交
103
    return feed_vec_;
W
wangguibao 已提交
104 105 106
  }

 protected:
W
wangguibao 已提交
107 108 109 110 111 112
  std::vector<uint16_t> all_slot_ids_;
  std::vector<uint16_t> use_slot_ids_;
  std::vector<std::string> use_slot_alias_;
  std::vector<LoDTensor*> feed_vec_;
  int default_batch_size_;
  int batch_size_;
W
wangguibao 已提交
113 114 115 116 117
};

class TextClassDataFeed : public DataFeed {
 public:
  virtual ~TextClassDataFeed() {}
W
wangguibao 已提交
118
  virtual void Init();
W
wangguibao 已提交
119 120 121 122
  virtual bool ReadBatch();
  virtual void AddFeedVar(Variable* feed, const std::string& name);
  virtual void BindScope(Scope* scope) {}
  virtual bool SetFile(const char* filename);
W
wangguibao 已提交
123

W
wangguibao 已提交
124
  virtual bool CheckFile(const char* filename) {
W
wangguibao 已提交
125 126 127 128
    // TODO(xxx)
    return false;
  }

W
wangguibao 已提交
129
  void SetBatchSize(int batch) {batch_size_ = batch;}
W
wangguibao 已提交
130 131

 private:
W
wangguibao 已提交
132 133 134 135 136 137 138 139 140 141
  int ReadWholeFile(const std::string& filename, char* buffer);
  char* file_content_buffer_;
  char* file_content_buffer_ptr_;
  int* batch_id_buffer_;
  int* label_ptr_;
  int file_size_;
  std::vector<std::string> names_;
  std::shared_ptr<char> file_content_buffer_host_;
  std::shared_ptr<int> batch_id_host_;
  std::shared_ptr<int> label_host_;
W
wangguibao 已提交
142 143 144 145 146 147 148
};

}   // namespace framework
}   // namespace paddle

#endif  // PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */