// data_feed_refactor.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_

#include <memory>
#include <set>
#include <map>
#include <string>
#include <thread>               // NOLINT
#include <vector>
#include <queue>
#include <mutex>                // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <condition_variable>   // NOLINT
#include <fstream>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
// Plain record holding two integer fields (show, click), a 64-bit feature
// value and a line identifier string.
// NOTE(review): the name suggests per-line group-AUC bookkeeping — confirm
// against the code that fills and consumes these records.
//
// Every scalar field has a default member initializer so that a
// default-initialized Gauc (e.g. `Gauc g;`) never exposes indeterminate
// values.
struct Gauc {
  int show = 0;
  int click = 0;
  uint64_t fea = 0;
  std::string lineid;
};

// One parsed input record as handed to DataFeed::GetFeedVec(const Instance&).
// All members are standard containers, so a default-constructed Instance is
// simply empty.
struct Instance {
  // Nested buffers of 64-bit ids.
  // NOTE(review): presumably one inner vector per feed slot — confirm.
  std::vector<std::vector<uint64_t>> feed_vec_buffer;

  // Nested integer vectors stored alongside feed_vec_buffer.
  // NOTE(review): presumably LoD offsets for the matching buffer — confirm.
  std::vector<std::vector<int>> feed_vec_lod;

  // Additional floating-point label values attached to the record.
  std::vector<float> other_label;

  // Gauc records attached to the record.
  std::vector<Gauc> gauc_vec;
};

// Abstract base class for data feeds.
//
// Subclasses implement initialization, file handling, iteration (Next) and
// feed-variable registration; this base only stores the slot metadata, the
// feed tensor pointers and the batch-size fields, and exposes accessors for
// them.
class DataFeed {
 public:
  DataFeed() {}
  virtual ~DataFeed() {}

  // Subclass-defined initialization.
  virtual void Init() = 0;

  /*
  * Checks the format of the file named by `filename`.
  * NOTE(review): the original comment stated that, because this function may
  * be used on its own, "it does not check anything" — presumably meaning no
  * other checks are implied by calling it; confirm with the implementations.
  * */
  virtual bool CheckFile(const char* filename) = 0;

  // Supplies the list of files to read; the meaning of the bool result is
  // defined by the subclass.
  virtual bool SetFileList(const std::vector<std::string>& files) = 0;

  // Advances the feed; the meaning of the bool result is defined by the
  // subclass.
  virtual bool Next() = 0;

  // Returns the stored list of all slot ids.
  virtual const std::vector<uint16_t>& GetAllSlotIds() {
    return all_slot_ids_;
  }

  // Returns the stored list of slot ids in use.
  virtual const std::vector<uint16_t>& GetUseSlotIds() {
    return use_slot_ids_;
  }

  // Returns the stored aliases of the slots in use.
  virtual const std::vector<std::string>& GetUseSlotAlias() {
    return use_slot_alias_;
  }

  // Registers `var` under `name`; behavior is defined by the subclass.
  virtual void AddFeedVar(Variable* var,
                          const std::string& name) = 0;

  // Stores the requested batch size in default_batch_size_. Note that this
  // does NOT touch batch_size_, which is what GetBatchSize() reports.
  virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }

  // Returns batch_size_. In this base class nothing writes that field, so
  // the result is 0 until a subclass assigns it.
  virtual int GetBatchSize() { return batch_size_; }

  // No-op hooks; subclasses that use a queue or buffer may override them.
  virtual void SetQueueSize(int queue_size) {}
  virtual void SetBufferSize(int buffer_size) {}

  // Returns a mutable reference to the stored feed tensor pointers.
  std::vector<LoDTensor*>& GetFeedVec() {
    return feed_vec_;
  }

  // Default per-instance variant: ignores `ins`, logs an error to flag that
  // the default implementation was reached, and returns the stored vector.
  virtual std::vector<LoDTensor*>& GetFeedVec(const Instance& ins) {
    LOG(ERROR) << "use defalut get_feed_vec";
    return feed_vec_;
  }

 protected:
  std::vector<uint16_t> all_slot_ids_;
  std::vector<uint16_t> use_slot_ids_;
  std::vector<std::string> use_slot_alias_;
  // Non-owning as far as this class is concerned: nothing here deletes them.
  std::vector<LoDTensor*> feed_vec_;
  // Both batch-size fields start at 0 so that GetBatchSize() never reads an
  // indeterminate value before a subclass has set batch_size_.
  int default_batch_size_ = 0;
  int batch_size_ = 0;
};

class TextClassDataFeed : public DataFeed {
 public:
  virtual ~TextClassDataFeed() {}
  virtual void Init();
  virtual bool ReadBatch();
  virtual void AddFeedVar(Variable* feed, const std::string& name);
  virtual void BindScope(Scope* scope) {}
  virtual bool SetFile(const char* filename);

  virtual bool CheckFile(const char* filename) {
    // TODO(xxx)
    return false;
  }

  void SetBatchSize(int batch) {batch_size_ = batch;}

 private:
  int ReadWholeFile(const std::string& filename, char* buffer);
  char* file_content_buffer_;
  char* file_content_buffer_ptr_;
  int* batch_id_buffer_;
  int* label_ptr_;
  int file_size_;
  std::vector<std::string> names_;
  std::shared_ptr<char> file_content_buffer_host_;
  std::shared_ptr<int> batch_id_host_;
  std::shared_ptr<int> label_host_;
};

}   // namespace framework
}   // namespace paddle

#endif  // PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */