data_feed.h 3.7 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_

#include <memory>
#include <set>
#include <map>
#include <string>
#include <thread>               // NOLINT
#include <vector>
#include <queue>
#include <mutex>                // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <condition_variable>   // NOLINT
#include <fstream>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
// A single GAUC-related record; Instance keeps a list of these in gauc_vec.
// NOTE(review): the field meanings below are inferred from the names only
// (show / click counters, a feature id and a line id) -- confirm against the
// reader that fills them.
struct Gauc {
  int show;            // presumably impression count
  int click;           // presumably click count
  uint64_t fea;        // presumably a feature id
  std::string lineid;  // presumably the id of the source line / sample
};

// Buffered data for one parsed input instance.
// NOTE(review): the per-member descriptions below are inferred from member
// names and types only -- verify against the DataFeed implementations that
// populate and consume an Instance.
struct Instance {
  // Feature values, one inner vector per feed (presumably per slot).
  std::vector< std::vector<uint64_t> > feed_vec_buffer;
  // LoD information paired with feed_vec_buffer (presumably per slot).
  std::vector< std::vector<int> > feed_vec_lod;
  // Extra float-valued labels.
  std::vector<float> other_label;
  // GAUC records associated with this instance.
  std::vector<Gauc> gauc_vec;
};

class DataFeed {
 public:
  DataFeed() {}
  virtual ~DataFeed() {}
W
wangguibao 已提交
54
  virtual void Init() = 0;
W
wangguibao 已提交
55 56 57 58 59
  /*
  * This function will be used to check file format.
  * Considering that this function may be used alone,
  * it does not check anything.
  * */
W
wangguibao 已提交
60 61 62 63 64
  virtual bool CheckFile(const char* filename) = 0;
  virtual bool SetFile(const char* filename) = 0;
  virtual bool ReadBatch() = 0;
  virtual const std::vector<uint16_t>& GetAllSlotIds() {
    return all_slot_ids_;
W
wangguibao 已提交
65 66
  }

W
wangguibao 已提交
67 68
  virtual const std::vector<uint16_t>& GetUseSlotIds() {
    return use_slot_ids_;
W
wangguibao 已提交
69 70
  }

W
wangguibao 已提交
71 72
  virtual const std::vector<std::string>& GetUseSlotAlias() {
    return use_slot_alias_;
W
wangguibao 已提交
73 74
  }

W
wangguibao 已提交
75
  virtual void AddFeedVar(Variable* var,
W
wangguibao 已提交
76
                            const std::string& name) = 0;
W
wangguibao 已提交
77 78 79 80
  virtual void BindScope(Scope* scope) = 0;
  virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
  virtual int GetBatchSize() { return batch_size_; }
  virtual void SetBufferSize(int buffer_size) {}
W
wangguibao 已提交
81

W
wangguibao 已提交
82 83
  std::vector<LoDTensor*>& GetFeedVec() {
    return feed_vec_;
W
wangguibao 已提交
84 85
  }

W
wangguibao 已提交
86
  virtual std::vector<LoDTensor*>& GetFeedVec(const Instance& ins) {
W
wangguibao 已提交
87
    LOG(ERROR) << "use defalut get_feed_vec";
W
wangguibao 已提交
88
    return feed_vec_;
W
wangguibao 已提交
89 90 91
  }

 protected:
W
wangguibao 已提交
92 93 94 95 96 97
  std::vector<uint16_t> all_slot_ids_;
  std::vector<uint16_t> use_slot_ids_;
  std::vector<std::string> use_slot_alias_;
  std::vector<LoDTensor*> feed_vec_;
  int default_batch_size_;
  int batch_size_;
W
wangguibao 已提交
98 99 100 101 102
};

class TextClassDataFeed : public DataFeed {
 public:
  virtual ~TextClassDataFeed() {}
W
wangguibao 已提交
103
  virtual void Init();
W
wangguibao 已提交
104 105 106 107
  virtual bool ReadBatch();
  virtual void AddFeedVar(Variable* feed, const std::string& name);
  virtual void BindScope(Scope* scope) {}
  virtual bool SetFile(const char* filename);
W
wangguibao 已提交
108

W
wangguibao 已提交
109
  virtual bool CheckFile(const char* filename) {
W
wangguibao 已提交
110 111 112 113
    // TODO(xxx)
    return false;
  }

W
wangguibao 已提交
114
  void SetBatchSize(int batch) {batch_size_ = batch;}
W
wangguibao 已提交
115 116

 private:
W
wangguibao 已提交
117 118 119 120 121 122 123 124 125 126
  int ReadWholeFile(const std::string& filename, char* buffer);
  char* file_content_buffer_;
  char* file_content_buffer_ptr_;
  int* batch_id_buffer_;
  int* label_ptr_;
  int file_size_;
  std::vector<std::string> names_;
  std::shared_ptr<char> file_content_buffer_host_;
  std::shared_ptr<int> batch_id_host_;
  std::shared_ptr<int> label_host_;
W
wangguibao 已提交
127 128 129 130 131 132 133
};

}   // namespace framework
}   // namespace paddle

#endif  // PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */