data_feed.h 37.4 KB
Newer Older
W
Wang Guibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

J
jiaqi 已提交
17 18 19 20 21
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif

W
Wang Guibao 已提交
22
#include <fstream>
23
#include <future>  // NOLINT
W
Wang Guibao 已提交
24 25
#include <memory>
#include <mutex>  // NOLINT
26
#include <sstream>
W
Wang Guibao 已提交
27 28
#include <string>
#include <thread>  // NOLINT
29
#include <unordered_map>
30
#include <unordered_set>
31
#include <utility>
32
#include <vector>
W
Wang Guibao 已提交
33

J
jiaqi 已提交
34
#include "paddle/fluid/framework/archive.h"
35
#include "paddle/fluid/framework/blocking_queue.h"
J
jiaqi 已提交
36
#include "paddle/fluid/framework/channel.h"
W
Wang Guibao 已提交
37
#include "paddle/fluid/framework/data_feed.pb.h"
38
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
W
Wang Guibao 已提交
39 40 41
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
Y
yaoxuefeng 已提交
42
#include "paddle/fluid/platform/timer.h"
43
#include "paddle/fluid/string/string_helper.h"
W
Wang Guibao 已提交
44

Y
yaoxuefeng 已提交
45 46 47 48 49
DECLARE_int32(record_pool_max_size);
DECLARE_int32(slotpool_thread_num);
DECLARE_bool(enable_slotpool_wait_release);
DECLARE_bool(enable_slotrecord_reset_shrink);

W
wanghuancoder 已提交
50 51 52 53 54 55 56 57
namespace paddle {
namespace framework {
class DataFeedDesc;
class Scope;
class Variable;
}  // namespace framework
}  // namespace paddle

58
namespace phi {
59
class DenseTensor;
60
}  // namespace phi
61

W
Wang Guibao 已提交
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
namespace paddle {
namespace framework {

// DataFeed is the base virtual class for all other DataFeeds.
// It is used to read files and parse the data for subsequent trainer.
// Example:
//   DataFeed* reader =
//   paddle::framework::DataFeedFactory::CreateDataFeed(data_feed_name);
//   reader->Init(data_feed_desc); // data_feed_desc is a protobuf object
//   reader->SetFileList(filelist);
//   const std::vector<std::string> & use_slot_alias =
//   reader->GetUseSlotAlias();
//   for (auto name: use_slot_alias){ // for binding memory
//     reader->AddFeedVar(scope->Var(name), name);
//   }
//   reader->Start();
//   while (reader->Next()) {
//      // trainer do something
//   }
Y
yaoxuefeng 已提交
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124

// SlotValues stores the values of all slots of one instance in one flat
// buffer:
//   slot_values  - the values of every slot, concatenated in slot order;
//   slot_offsets - slot i occupies [slot_offsets[i], slot_offsets[i + 1]) in
//                  slot_values, so a filled object holds (slots + 1) offsets.
template <typename T>
struct SlotValues {
  std::vector<T> slot_values;
  std::vector<uint32_t> slot_offsets;

  // Appends one slot holding `num` values copied from `values`. The leading
  // 0 offset is inserted lazily on the first call; `values` may be null when
  // `num` is 0.
  void add_values(const T* values, uint32_t num) {
    if (slot_offsets.empty()) {
      slot_offsets.push_back(0);
    }
    if (num > 0) {
      slot_values.insert(slot_values.end(), values, values + num);
    }
    slot_offsets.push_back(static_cast<uint32_t>(slot_values.size()));
  }
  // Returns a pointer to the first value of slot `idx` and writes the number
  // of values in that slot to *size. `idx` must be a valid slot index.
  T* get_values(int idx, size_t* size) {
    const uint32_t offset = slot_offsets[idx];
    (*size) = slot_offsets[idx + 1] - offset;
    // data() + offset rather than &slot_values[offset]: for an empty slot at
    // the end of the buffer offset == slot_values.size(), where operator[]
    // would be out of range.
    return slot_values.data() + offset;
  }
  // Read-only overload of get_values().
  const T* get_values(int idx, size_t* size) const {
    const uint32_t offset = slot_offsets[idx];
    (*size) = slot_offsets[idx + 1] - offset;
    return slot_values.data() + offset;
  }
  // Fills the buffer from one vector per slot. Assumes the object is empty
  // (e.g. right after clear()); `fea_num` is only a capacity hint for the
  // total number of values.
  void add_slot_feasigns(const std::vector<std::vector<T>>& slot_feasigns,
                         uint32_t fea_num) {
    slot_values.reserve(fea_num);
    int slot_num = static_cast<int>(slot_feasigns.size());
    slot_offsets.resize(slot_num + 1);
    for (int i = 0; i < slot_num; ++i) {
      const auto& slot_val = slot_feasigns[i];
      slot_offsets[i] = static_cast<uint32_t>(slot_values.size());
      if (!slot_val.empty()) {
        slot_values.insert(slot_values.end(), slot_val.begin(), slot_val.end());
      }
    }
    slot_offsets[slot_num] = static_cast<uint32_t>(slot_values.size());
  }
  // Removes all values and offsets; with `shrink`, also releases capacity.
  void clear(bool shrink) {
    slot_offsets.clear();
    slot_values.clear();
    if (shrink) {
      slot_values.shrink_to_fit();
      slot_offsets.shrink_to_fit();
    }
  }
};
T
Thunderbrook 已提交
125
// One feasign value: an integer id for uint64 slots or a float for float
// slots. Which member is meaningful depends on the slot type.
union FeatureFeasign {
  uint64_t uint64_feasign_;
  float float_feasign_;
};

// A (feasign, slot) pair. The feasign bytes live in a raw char buffer rather
// than in a FeatureFeasign member, so the struct only needs the alignment of
// uint16_t and carries no padding between the two fields.
// NOTE(review): as a consequence sign() may reinterpret a buffer that is not
// aligned for uint64_t (e.g. elements of a std::vector<FeatureItem>); this
// relies on the target tolerating unaligned access — confirm when porting.
struct FeatureItem {
  // Deliberately leaves both fields uninitialized.
  FeatureItem() {}
  FeatureItem(FeatureFeasign sign, uint16_t slot) : slot_(slot) {
    this->sign() = sign;
  }

  // Writable view of the stored feasign.
  FeatureFeasign& sign() { return *reinterpret_cast<FeatureFeasign*>(sign_); }
  // Read-only view of the stored feasign.
  const FeatureFeasign& sign() const {
    return *reinterpret_cast<const FeatureFeasign*>(sign_);
  }

  // Writable / read-only access to the slot index.
  uint16_t& slot() { return slot_; }
  const uint16_t& slot() const { return slot_; }

 private:
  char sign_[sizeof(FeatureFeasign)];
  uint16_t slot_;
};

Y
yaoxuefeng 已提交
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
// Per-slot descriptor covering every slot of the data feed config, used or
// not. It is handed to CustomParser::Init(const std::vector<AllSlotInfo>&)
// (via DLManager::Load). Members have no default initializers, so whoever
// creates one must set every field.
struct AllSlotInfo {
  std::string slot;
  std::string type;
  // NOTE(review): presumably this slot's position among the used slots
  // (negative when unused) — confirm against the code that fills it.
  int used_idx;
  // NOTE(review): presumably an index into the per-type value storage of a
  // SlotRecord — confirm.
  int slot_value_idx;
};
// Per-slot descriptor for slots that are actually used.
// NOTE(review): dense / local_shape / total_dims_without_inductive /
// inductive_shape_index look like per-slot copies of DataFeed's
// use_slots_is_dense_, use_slots_shape_, total_dims_without_inductive_ and
// inductive_shape_index_ — confirm. Members have no default initializers.
struct UsedSlotInfo {
  int idx;
  int slot_value_idx;
  std::string slot;
  std::string type;
  bool dense;
  std::vector<int> local_shape;
  int total_dims_without_inductive;
  int inductive_shape_index;
};
// One instance in the SlotRecord format: identifying fields plus all uint64
// and float feasigns, each kept as a flat SlotValues buffer.
// NOTE: while an object sits in SlotObjAllocator's free list its leading bytes
// (search_id) are overwritten by the list link, so keep search_id first and do
// not rely on its value for recycled objects.
struct SlotRecordObject {
  uint64_t search_id;
  uint32_t rank;
  uint32_t cmatch;
  std::string ins_id_;
  SlotValues<uint64_t> slot_uint64_feasigns_;
  SlotValues<float> slot_float_feasigns_;

  // Empties (and shrinks) both feasign buffers before the members go away.
  ~SlotRecordObject() { clear(true); }

  // Prepares the object for reuse: empties both feasign buffers, releasing
  // their capacity only when FLAGS_enable_slotrecord_reset_shrink is set.
  // ins_id_, search_id, rank and cmatch are left as they are.
  void reset() {
    const bool shrink_buffers = FLAGS_enable_slotrecord_reset_shrink;
    clear(shrink_buffers);
  }

  // Empties both feasign buffers; with `shrink`, also releases capacity.
  void clear(bool shrink) {
    slot_float_feasigns_.clear(shrink);
    slot_uint64_feasigns_.clear(shrink);
  }
};
// SlotRecord objects are handled by raw pointer; see make_slotrecord() /
// free_slotrecord() and SlotRecordPool() for their lifetime.
using SlotRecord = SlotRecordObject*;
185 186 187 188 189 190 191 192 193
// One instance in the Record format. sizeof(Record) is much smaller than a
// std::vector<MultiSlotType> holding the same instance, because every feasign
// is a compact FeatureItem tagged with its slot index.
struct Record {
  // Feasigns of uint64 slots, each tagged with its slot index.
  std::vector<FeatureItem> uint64_feasigns_;
  // Feasigns of float slots, each tagged with its slot index.
  std::vector<FeatureItem> float_feasigns_;
  // Instance id and raw content.
  std::string ins_id_;
  std::string content_;
  uint64_t search_id;
  uint32_t rank;
  uint32_t cmatch;
  std::string uid_;
};

Y
yaoxuefeng 已提交
197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369
// Allocates raw storage with malloc and constructs a SlotRecordObject in it.
// The result must be destroyed with free_slotrecord(), never with delete.
inline SlotRecord make_slotrecord() {
  void* p = malloc(sizeof(SlotRecordObject));
  // malloc signals failure with nullptr; constructing into it would be
  // undefined behaviour, so fail loudly instead.
  CHECK(p != nullptr) << "malloc failed for SlotRecordObject ("
                      << sizeof(SlotRecordObject) << " bytes)";
  // Return the pointer produced by placement new rather than casting the raw
  // buffer, so the result formally points at the constructed object.
  return new (p) SlotRecordObject;
}

// Destroys an object created by make_slotrecord() and returns its storage to
// malloc. `p` must be non-null.
inline void free_slotrecord(SlotRecordObject* p) {
  p->~SlotRecordObject();
  free(p);
}

// Intrusive free list that caches objects of type T for reuse.
//
// release() keeps the object's memory but overwrites its first
// sizeof(void*) bytes with the list link, so T must be at least pointer-sized
// and the leading bytes of a cached object are garbage until it is handed out
// again. clear() (and the destructor) pass every cached object to `deleter`.
// Not synchronized: callers must serialize access (SlotObjPool does so under
// its mutex).
template <class T>
class SlotObjAllocator {
 public:
  explicit SlotObjAllocator(std::function<void(T*)> deleter)
      : free_nodes_(nullptr), capacity_(0), deleter_(std::move(deleter)) {}
  ~SlotObjAllocator() { clear(); }

  // The free list owns the cached objects; a copy would delete them twice.
  SlotObjAllocator(const SlotObjAllocator&) = delete;
  SlotObjAllocator& operator=(const SlotObjAllocator&) = delete;

  // Hands every cached object to the deleter and empties the list.
  void clear() {
    while (free_nodes_ != nullptr) {
      T* obj = reinterpret_cast<T*>(reinterpret_cast<void*>(free_nodes_));
      free_nodes_ = free_nodes_->next;
      deleter_(obj);
      --capacity_;
    }
    CHECK_EQ(capacity_, static_cast<size_t>(0));
  }
  // Pops one cached object. The list must be non-empty (see capacity()).
  T* acquire(void) {
    CHECK(free_nodes_ != nullptr)
        << "SlotObjAllocator::acquire called on an empty free list";
    T* x = reinterpret_cast<T*>(reinterpret_cast<void*>(free_nodes_));
    free_nodes_ = free_nodes_->next;
    --capacity_;
    return x;
  }
  // Pushes `x` onto the free list; it must later be destroyable by deleter_.
  void release(T* x) {
    Node* node = reinterpret_cast<Node*>(reinterpret_cast<void*>(x));
    node->next = free_nodes_;
    free_nodes_ = node;
    ++capacity_;
  }
  // Number of objects currently cached.
  size_t capacity(void) { return capacity_; }

 private:
  struct alignas(T) Node {
    union {
      Node* next;
      char data[sizeof(T)];
    };
  };
  // release() writes the link into the object's own storage.
  static_assert(sizeof(T) >= sizeof(Node*),
                "SlotObjAllocator requires T to be at least pointer-sized");

  Node* free_nodes_;  // a list
  size_t capacity_;
  std::function<void(T*)> deleter_ = nullptr;
};
static const int OBJPOOL_BLOCK_SIZE = 10000;
class SlotObjPool {
 public:
  SlotObjPool()
      : max_capacity_(FLAGS_record_pool_max_size), alloc_(free_slotrecord) {
    ins_chan_ = MakeChannel<SlotRecord>();
    ins_chan_->SetBlockSize(OBJPOOL_BLOCK_SIZE);
    for (int i = 0; i < FLAGS_slotpool_thread_num; ++i) {
      threads_.push_back(std::thread([this]() { run(); }));
    }
    disable_pool_ = false;
    count_ = 0;
  }
  ~SlotObjPool() {
    ins_chan_->Close();
    for (auto& t : threads_) {
      t.join();
    }
  }
  void disable_pool(bool disable) { disable_pool_ = disable; }
  void set_max_capacity(size_t max_capacity) { max_capacity_ = max_capacity; }
  void get(std::vector<SlotRecord>* output, int n) {
    output->resize(n);
    return get(&(*output)[0], n);
  }
  void get(SlotRecord* output, int n) {
    int size = 0;
    mutex_.lock();
    int left = static_cast<int>(alloc_.capacity());
    if (left > 0) {
      size = (left >= n) ? n : left;
      for (int i = 0; i < size; ++i) {
        output[i] = alloc_.acquire();
      }
    }
    mutex_.unlock();
    count_ += n;
    if (size == n) {
      return;
    }
    for (int i = size; i < n; ++i) {
      output[i] = make_slotrecord();
    }
  }
  void put(std::vector<SlotRecord>* input) {
    size_t size = input->size();
    if (size == 0) {
      return;
    }
    put(&(*input)[0], size);
    input->clear();
  }
  void put(SlotRecord* input, size_t size) {
    CHECK(ins_chan_->WriteMove(size, input) == size);
  }
  void run(void) {
    std::vector<SlotRecord> input;
    while (ins_chan_->ReadOnce(input, OBJPOOL_BLOCK_SIZE)) {
      if (input.empty()) {
        continue;
      }
      // over max capacity
      size_t n = input.size();
      count_ -= n;
      if (disable_pool_ || n + capacity() > max_capacity_) {
        for (auto& t : input) {
          free_slotrecord(t);
        }
      } else {
        for (auto& t : input) {
          t->reset();
        }
        mutex_.lock();
        for (auto& t : input) {
          alloc_.release(t);
        }
        mutex_.unlock();
      }
      input.clear();
    }
  }
  void clear(void) {
    platform::Timer timeline;
    timeline.Start();
    mutex_.lock();
    alloc_.clear();
    mutex_.unlock();
    // wait release channel data
    if (FLAGS_enable_slotpool_wait_release) {
      while (!ins_chan_->Empty()) {
        sleep(1);
      }
    }
    timeline.Pause();
    VLOG(3) << "clear slot pool data size=" << count_.load()
            << ", span=" << timeline.ElapsedSec();
  }
  size_t capacity(void) {
    mutex_.lock();
    size_t total = alloc_.capacity();
    mutex_.unlock();
    return total;
  }

 private:
  size_t max_capacity_;
  Channel<SlotRecord> ins_chan_;
  std::vector<std::thread> threads_;
  std::mutex mutex_;
  SlotObjAllocator<SlotRecordObject> alloc_;
  bool disable_pool_;
  std::atomic<long> count_;  // NOLINT
};

// Returns the process-wide SlotObjPool. It is created lazily on first call
// (function-local static) and lives until program exit.
inline SlotObjPool& SlotRecordPool() {
  static SlotObjPool instance;
  return instance;
}
370 371 372 373 374 375 376 377 378
struct PvInstanceObject {
  std::vector<Record*> ads;
  void merge_instance(Record* ins) { ads.push_back(ins); }
};

using PvInstance = PvInstanceObject*;

inline PvInstance make_pv_instance() { return new PvInstanceObject(); }

T
Thunderbrook 已提交
379 380 381 382 383 384 385 386 387 388 389 390
// Per-slot configuration handed to CustomParser::Init(const
// std::vector<SlotConf>&), e.g. through DLManager::Load(). No field has a
// default initializer.
struct SlotConf {
  std::string name;
  std::string type;
  // NOTE(review): presumably mirrors DataFeed::use_slots_index_ (-1: not
  // used, >= 0: index into the used slots) — confirm.
  int use_slots_index;
  // NOTE(review): presumably a 0/1 dense flag stored as int — confirm.
  int use_slots_is_dense;
};

class CustomParser {
 public:
  CustomParser() {}
  virtual ~CustomParser() {}
  virtual void Init(const std::vector<SlotConf>& slots) = 0;
Y
yaoxuefeng 已提交
391
  virtual bool Init(const std::vector<AllSlotInfo>& slots);
T
Thunderbrook 已提交
392
  virtual void ParseOneInstance(const char* str, Record* instance) = 0;
Y
yaoxuefeng 已提交
393 394 395 396 397 398 399 400 401 402 403 404 405
  virtual bool ParseOneInstance(
      const std::string& line,
      std::function<void(std::vector<SlotRecord>&, int)>
          GetInsFunc) {  // NOLINT
    return true;
  }
  virtual bool ParseFileInstance(
      std::function<int(char* buf, int len)> ReadBuffFunc,
      std::function<void(std::vector<SlotRecord>&, int, int)>
          PullRecordsFunc,  // NOLINT
      int& lines) {         // NOLINT
    return false;
  }
T
Thunderbrook 已提交
406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442
};

typedef paddle::framework::CustomParser* (*CreateParserObjectFunc)();

// Loads CustomParser implementations from shared libraries and caches one
// parser per library path. All accesses to handle_map_ happen under mutex_.
// Only implemented when _LINUX is defined; otherwise methods log and fail.
class DLManager {
  struct DLHandle {
    void* module;
    paddle::framework::CustomParser* parser;
  };

 public:
  DLManager() {}

  // Deletes every cached parser and closes its library.
  ~DLManager() {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    for (auto& entry : handle_map_) {
      delete entry.second.parser;
      dlclose(entry.second.module);
    }
#endif
  }

  // Deletes the parser loaded from `name`, closes its library and forgets the
  // entry, so a later Load(name, ...) opens it afresh. Returns true when
  // nothing was loaded under `name` or it was closed; false when not
  // implemented on this platform.
  bool Close(const std::string& name) {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handle_map_.find(name);
    if (it == handle_map_.end()) {
      return true;
    }
    delete it->second.parser;
    dlclose(it->second.module);
    // Without erasing, Load() would hand out the deleted parser and the
    // destructor would delete/dlclose it a second time.
    handle_map_.erase(it);
    return true;
#else
    (void)name;
    VLOG(0) << "Not implement in windows";
    return false;
#endif
  }

  // Returns the cached parser for `name`, or loads the library, creates the
  // parser and calls its Init(conf). Returns nullptr on any failure.
  paddle::framework::CustomParser* Load(const std::string& name,
                                        const std::vector<SlotConf>& conf) {
    return LoadImpl(name, conf, /*exit_on_open_failure=*/false);
  }

  // SlotRecord variant of Load(). Unlike the overload above, a dlopen()
  // failure terminates the process with exit(-1).
  paddle::framework::CustomParser* Load(const std::string& name,
                                        const std::vector<AllSlotInfo>& conf) {
    return LoadImpl(name, conf, /*exit_on_open_failure=*/true);
  }

  // Closes any parser loaded from `name`, then loads it again with `conf`.
  paddle::framework::CustomParser* ReLoad(const std::string& name,
                                          const std::vector<SlotConf>& conf) {
    Close(name);
    return Load(name, conf);
  }

 private:
  // Shared body of both Load() overloads; ConfT selects the Init() overload.
  template <typename ConfT>
  paddle::framework::CustomParser* LoadImpl(const std::string& name,
                                            const ConfT& conf,
                                            bool exit_on_open_failure) {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handle_map_.find(name);
    if (it != handle_map_.end()) {
      return it->second.parser;
    }

    DLHandle handle;
    handle.module = dlopen(name.c_str(), RTLD_NOW);
    if (handle.module == nullptr) {
      VLOG(0) << "Create so of " << name << " fail";
      if (exit_on_open_failure) {
        exit(-1);
      }
      return nullptr;
    }

    // dlsym returns nullptr when the symbol is missing; calling through it
    // would crash, and the module would otherwise leak.
    auto create_parser_func = reinterpret_cast<CreateParserObjectFunc>(
        dlsym(handle.module, "CreateParserObject"));
    if (create_parser_func == nullptr) {
      VLOG(0) << "CreateParserObject not found in " << name;
      dlclose(handle.module);
      return nullptr;
    }
    handle.parser = create_parser_func();
    if (handle.parser == nullptr) {
      VLOG(0) << "CreateParserObject returned nullptr in " << name;
      dlclose(handle.module);
      return nullptr;
    }
    handle.parser->Init(conf);
    handle_map_.insert({name, handle});
    return handle.parser;
#else
    (void)name;
    (void)conf;
    (void)exit_on_open_failure;
    VLOG(0) << "Not implement in windows";
    return nullptr;
#endif
  }

  std::mutex mutex_;
  std::map<std::string, DLHandle> handle_map_;
};

W
Wang Guibao 已提交
509 510
class DataFeed {
 public:
511 512 513
  DataFeed() {
    mutex_for_pick_file_ = nullptr;
    file_idx_ = nullptr;
H
hutuxian 已提交
514 515
    mutex_for_fea_num_ = nullptr;
    total_fea_num_ = nullptr;
516
  }
W
Wang Guibao 已提交
517
  virtual ~DataFeed() {}
H
hutuxian 已提交
518
  virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
W
Wang Guibao 已提交
519
  virtual bool CheckFile(const char* filename) {
520 521
    PADDLE_THROW(platform::errors::Unimplemented(
        "This function(CheckFile) is not implemented."));
W
Wang Guibao 已提交
522 523 524 525 526 527
  }
  // Set filelist for DataFeed.
  // Pay attention that it must init all readers before call this function.
  // Otherwise, Init() function will init finish_set_filelist_ flag.
  virtual bool SetFileList(const std::vector<std::string>& files);
  virtual bool Start() = 0;
D
dongdaxiang 已提交
528

W
Wang Guibao 已提交
529 530 531 532 533 534 535 536 537 538 539 540 541 542 543
  // The trainer calls the Next() function, and the DataFeed will load a new
  // batch to the feed_vec. The return value of this function is the batch
  // size of the current batch.
  virtual int Next() = 0;
  // Get all slots' alias which defined in protofile
  virtual const std::vector<std::string>& GetAllSlotAlias() {
    return all_slots_;
  }
  // Get used slots' alias which defined in protofile
  virtual const std::vector<std::string>& GetUseSlotAlias() {
    return use_slots_;
  }
  // This function is used for binding feed_vec memory
  virtual void AddFeedVar(Variable* var, const std::string& name);

H
hutuxian 已提交
544 545 546
  // This function is used for binding feed_vec memory in a given scope
  virtual void AssignFeedVar(const Scope& scope);

547 548 549 550 551 552 553
  // This function will do nothing at default
  virtual void SetInputPvChannel(void* channel) {}
  // This function will do nothing at default
  virtual void SetOutputPvChannel(void* channel) {}
  // This function will do nothing at default
  virtual void SetConsumePvChannel(void* channel) {}

554
  // This function will do nothing at default
J
jiaqi 已提交
555 556 557
  virtual void SetInputChannel(void* channel) {}
  // This function will do nothing at default
  virtual void SetOutputChannel(void* channel) {}
558
  // This function will do nothing at default
J
jiaqi 已提交
559
  virtual void SetConsumeChannel(void* channel) {}
560
  // This function will do nothing at default
561
  virtual void SetThreadId(int thread_id) {}
562
  // This function will do nothing at default
563
  virtual void SetThreadNum(int thread_num) {}
564 565
  // This function will do nothing at default
  virtual void SetParseInsId(bool parse_ins_id) {}
566
  virtual void SetParseUid(bool parse_uid) {}
567
  virtual void SetParseContent(bool parse_content) {}
568 569 570
  virtual void SetParseLogKey(bool parse_logkey) {}
  virtual void SetEnablePvMerge(bool enable_pv_merge) {}
  virtual void SetCurrentPhase(int current_phase) {}
571 572 573
  virtual void SetFileListMutex(std::mutex* mutex) {
    mutex_for_pick_file_ = mutex;
  }
H
hutuxian 已提交
574
  virtual void SetFeaNumMutex(std::mutex* mutex) { mutex_for_fea_num_ = mutex; }
575
  virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
H
hutuxian 已提交
576
  virtual void SetFeaNum(uint64_t* fea_num) { total_fea_num_ = fea_num; }
577 578 579 580 581 582 583
  virtual const std::vector<std::string>& GetInsIdVec() const {
    return ins_id_vec_;
  }
  virtual const std::vector<std::string>& GetInsContentVec() const {
    return ins_content_vec_;
  }
  virtual int GetCurBatchSize() { return batch_size_; }
584
  virtual void LoadIntoMemory() {
585 586
    PADDLE_THROW(platform::errors::Unimplemented(
        "This function(LoadIntoMemory) is not implemented."));
587
  }
588 589 590 591
  virtual void SetPlace(const paddle::platform::Place& place) {
    place_ = place;
  }
  virtual const paddle::platform::Place& GetPlace() const { return place_; }
592

W
Wang Guibao 已提交
593 594 595 596 597 598 599 600 601 602 603 604
 protected:
  // The following three functions are used to check if it is executed in this
  // order:
  //   Init() -> SetFileList() -> Start() -> Next()
  virtual void CheckInit();
  virtual void CheckSetFileList();
  virtual void CheckStart();
  virtual void SetBatchSize(
      int batch);  // batch size will be set in Init() function
  // This function is used to pick one file from the global filelist(thread
  // safe).
  virtual bool PickOneFile(std::string* filename);
605
  virtual void CopyToFeedTensor(void* dst, const void* src, size_t size);
W
Wang Guibao 已提交
606

607 608 609
  std::vector<std::string> filelist_;
  size_t* file_idx_;
  std::mutex* mutex_for_pick_file_;
H
hutuxian 已提交
610 611 612
  std::mutex* mutex_for_fea_num_ = nullptr;
  uint64_t* total_fea_num_ = nullptr;
  uint64_t fea_num_ = 0;
W
Wang Guibao 已提交
613 614 615 616 617 618 619 620 621 622

  // the alias of used slots, and its order is determined by
  // data_feed_desc(proto object)
  std::vector<std::string> use_slots_;
  std::vector<bool> use_slots_is_dense_;

  // the alias of all slots, and its order is determined by data_feed_desc(proto
  // object)
  std::vector<std::string> all_slots_;
  std::vector<std::string> all_slots_type_;
623
  std::vector<std::vector<int>> use_slots_shape_;
624 625
  std::vector<int> inductive_shape_index_;
  std::vector<int> total_dims_without_inductive_;
H
hutuxian 已提交
626 627
  // For the inductive shape passed within data
  std::vector<std::vector<int>> multi_inductive_shape_index_;
W
Wang Guibao 已提交
628 629 630 631
  std::vector<int>
      use_slots_index_;  // -1: not used; >=0: the index of use_slots_

  // The data read by DataFeed will be stored here
632
  std::vector<LoDTensor*> feed_vec_;
W
Wang Guibao 已提交
633

634 635
  LoDTensor* rank_offset_;

W
Wang Guibao 已提交
636 637 638 639 640 641
  // the batch size defined by user
  int default_batch_size_;
  // current batch size
  int batch_size_;

  bool finish_init_;
642
  bool finish_set_filelist_;
W
Wang Guibao 已提交
643
  bool finish_start_;
644
  std::string pipe_command_;
T
Thunderbrook 已提交
645 646
  std::string so_parser_name_;
  std::vector<SlotConf> slot_conf_;
647 648
  std::vector<std::string> ins_id_vec_;
  std::vector<std::string> ins_content_vec_;
649
  platform::Place place_;
650
  std::string uid_slot_;
651 652 653

  // The input type of pipe reader, 0 for one sample, 1 for one batch
  int input_type_;
W
Wang Guibao 已提交
654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674
};

// PrivateQueueDataFeed is the base virtual class for ohther DataFeeds.
// It use a read-thread to read file and parse data to a private-queue
// (thread level), and get data from this queue when trainer call Next().
template <typename T>
class PrivateQueueDataFeed : public DataFeed {
 public:
  PrivateQueueDataFeed() {}
  virtual ~PrivateQueueDataFeed() {}
  virtual bool Start();
  virtual int Next();

 protected:
  // The thread implementation function for reading file and parse.
  virtual void ReadThread();
  // This function is used to set private-queue size, and the most
  // efficient when the queue size is close to the batch size.
  virtual void SetQueueSize(int queue_size);
  // The reading and parsing method called in the ReadThread.
  virtual bool ParseOneInstance(T* instance) = 0;
D
dongdaxiang 已提交
675
  virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
W
Wang Guibao 已提交
676 677 678 679 680 681 682 683 684 685 686 687 688 689
  // This function is used to put instance to vec_ins
  virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
                                   int index) = 0;
  // This function is used to put ins_vec to feed_vec
  virtual void PutToFeedVec(const T& ins_vec) = 0;

  // The thread for read files
  std::thread read_thread_;
  // using ifstream one line and one line parse is faster
  // than using fread one buffer and one buffer parse.
  //   for a 601M real data:
  //     ifstream one line and one line parse: 6034 ms
  //     fread one buffer and one buffer parse: 7097 ms
  std::ifstream file_;
D
dongdaxiang 已提交
690
  std::shared_ptr<FILE> fp_;
W
Wang Guibao 已提交
691
  size_t queue_size_;
692
  string::LineFileReader reader_;
W
Wang Guibao 已提交
693
  // The queue for store parsed data
694
  std::shared_ptr<paddle::framework::ChannelObject<T>> queue_;
W
Wang Guibao 已提交
695 696
};

697
template <typename T>
J
jiaqi 已提交
698
class InMemoryDataFeed : public DataFeed {
699 700 701
 public:
  InMemoryDataFeed();
  virtual ~InMemoryDataFeed() {}
H
hutuxian 已提交
702
  virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
703 704
  virtual bool Start();
  virtual int Next();
705 706 707 708
  virtual void SetInputPvChannel(void* channel);
  virtual void SetOutputPvChannel(void* channel);
  virtual void SetConsumePvChannel(void* channel);

J
jiaqi 已提交
709 710 711
  virtual void SetInputChannel(void* channel);
  virtual void SetOutputChannel(void* channel);
  virtual void SetConsumeChannel(void* channel);
712 713
  virtual void SetThreadId(int thread_id);
  virtual void SetThreadNum(int thread_num);
714
  virtual void SetParseInsId(bool parse_ins_id);
715
  virtual void SetParseUid(bool parse_uid);
716
  virtual void SetParseContent(bool parse_content);
717 718 719
  virtual void SetParseLogKey(bool parse_logkey);
  virtual void SetEnablePvMerge(bool enable_pv_merge);
  virtual void SetCurrentPhase(int current_phase);
720
  virtual void LoadIntoMemory();
T
Thunderbrook 已提交
721
  virtual void LoadIntoMemoryFromSo();
Y
yaoxuefeng 已提交
722 723 724 725 726
  virtual void SetRecord(T* records) { records_ = records; }
  int GetDefaultBatchSize() { return default_batch_size_; }
  void AddBatchOffset(const std::pair<int, int>& offset) {
    batch_offsets_.push_back(offset);
  }
X
xujiaqi01 已提交
727

728 729 730
 protected:
  virtual bool ParseOneInstance(T* instance) = 0;
  virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
T
Thunderbrook 已提交
731 732
  virtual void ParseOneInstanceFromSo(const char* str, T* instance,
                                      CustomParser* parser) {}
J
jiaqi 已提交
733
  virtual void PutToFeedVec(const std::vector<T>& ins_vec) = 0;
Y
yaoxuefeng 已提交
734
  virtual void PutToFeedVec(const T* ins_vec, int num) = 0;
735

Y
yaoxuefeng 已提交
736 737 738 739 740
  std::vector<std::vector<float>> batch_float_feasigns_;
  std::vector<std::vector<uint64_t>> batch_uint64_feasigns_;
  std::vector<std::vector<size_t>> offset_;
  std::vector<bool> visit_;

741 742
  int thread_id_;
  int thread_num_;
743
  bool parse_ins_id_;
744
  bool parse_uid_;
745
  bool parse_content_;
746 747 748
  bool parse_logkey_;
  bool enable_pv_merge_;
  int current_phase_{-1};  // only for untest
J
jiaqi 已提交
749 750 751 752 753
  std::ifstream file_;
  std::shared_ptr<FILE> fp_;
  paddle::framework::ChannelObject<T>* input_channel_;
  paddle::framework::ChannelObject<T>* output_channel_;
  paddle::framework::ChannelObject<T>* consume_channel_;
754 755 756 757

  paddle::framework::ChannelObject<PvInstance>* input_pv_channel_;
  paddle::framework::ChannelObject<PvInstance>* output_pv_channel_;
  paddle::framework::ChannelObject<PvInstance>* consume_pv_channel_;
Y
yaoxuefeng 已提交
758 759 760 761 762

  std::vector<std::pair<int, int>> batch_offsets_;
  uint64_t offset_index_ = 0;
  bool enable_heterps_ = false;
  T* records_ = nullptr;
763 764
};

W
Wang Guibao 已提交
765 766 767 768 769
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
class MultiSlotType {
 public:
  MultiSlotType() {}
  ~MultiSlotType() {}
H
hutuxian 已提交
770
  void Init(const std::string& type, size_t reserved_size = 0) {
W
Wang Guibao 已提交
771 772 773
    CheckType(type);
    if (type_[0] == 'f') {
      float_feasign_.clear();
H
hutuxian 已提交
774 775 776
      if (reserved_size) {
        float_feasign_.reserve(reserved_size);
      }
W
Wang Guibao 已提交
777 778
    } else if (type_[0] == 'u') {
      uint64_feasign_.clear();
H
hutuxian 已提交
779 780 781
      if (reserved_size) {
        uint64_feasign_.reserve(reserved_size);
      }
W
Wang Guibao 已提交
782 783 784
    }
    type_ = type;
  }
H
hutuxian 已提交
785 786 787 788
  void InitOffset(size_t max_batch_size = 0) {
    if (max_batch_size > 0) {
      offset_.reserve(max_batch_size + 1);
    }
W
Wang Guibao 已提交
789 790 791 792 793 794
    offset_.resize(1);
    // LoDTensor' lod is counted from 0, the size of lod
    // is one size larger than the size of data.
    offset_[0] = 0;
  }
  const std::vector<size_t>& GetOffset() const { return offset_; }
795
  std::vector<size_t>& MutableOffset() { return offset_; }
W
Wang Guibao 已提交
796 797 798 799 800 801 802 803
  void AddValue(const float v) {
    CheckFloat();
    float_feasign_.push_back(v);
  }
  void AddValue(const uint64_t v) {
    CheckUint64();
    uint64_feasign_.push_back(v);
  }
H
hutuxian 已提交
804 805 806 807 808 809 810 811 812 813
  void CopyValues(const float* input, size_t size) {
    CheckFloat();
    float_feasign_.resize(size);
    memcpy(float_feasign_.data(), input, size * sizeof(float));
  }
  void CopyValues(const uint64_t* input, size_t size) {
    CheckUint64();
    uint64_feasign_.resize(size);
    memcpy(uint64_feasign_.data(), input, size * sizeof(uint64_t));
  }
W
Wang Guibao 已提交
814 815 816 817 818 819 820 821 822 823 824 825 826
  void AddIns(const MultiSlotType& ins) {
    if (ins.GetType()[0] == 'f') {  // float
      CheckFloat();
      auto& vec = ins.GetFloatData();
      offset_.push_back(offset_.back() + vec.size());
      float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end());
    } else if (ins.GetType()[0] == 'u') {  // uint64
      CheckUint64();
      auto& vec = ins.GetUint64Data();
      offset_.push_back(offset_.back() + vec.size());
      uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
    }
  }
H
hutuxian 已提交
827 828 829 830 831 832 833 834
  void AppendValues(const uint64_t* input, size_t size) {
    CheckUint64();
    offset_.push_back(offset_.back() + size);
    uint64_feasign_.insert(uint64_feasign_.end(), input, input + size);
  }
  void AppendValues(const float* input, size_t size) {
    CheckFloat();
    offset_.push_back(offset_.back() + size);
835

H
hutuxian 已提交
836 837
    float_feasign_.insert(float_feasign_.end(), input, input + size);
  }
W
Wang Guibao 已提交
838
  const std::vector<float>& GetFloatData() const { return float_feasign_; }
839
  std::vector<float>& MutableFloatData() { return float_feasign_; }
W
Wang Guibao 已提交
840
  const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
841
  std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
W
Wang Guibao 已提交
842
  const std::string& GetType() const { return type_; }
H
hutuxian 已提交
843
  size_t GetBatchSize() { return offset_.size() - 1; }
844
  std::string& MutableType() { return type_; }
W
Wang Guibao 已提交
845

X
xujiaqi01 已提交
846 847
  std::string DebugString() {
    std::stringstream ss;
W
wanghuancoder 已提交
848

849 850
    ss << "\ntype: " << type_ << "\n";
    ss << "offset: ";
X
xujiaqi01 已提交
851 852 853 854
    ss << "[";
    for (const size_t& i : offset_) {
      ss << offset_[i] << ",";
    }
855
    ss << "]\ndata: [";
X
xujiaqi01 已提交
856 857 858 859 860 861 862 863 864 865 866 867 868
    if (type_[0] == 'f') {
      for (const float& i : float_feasign_) {
        ss << i << ",";
      }
    } else {
      for (const uint64_t& i : uint64_feasign_) {
        ss << i << ",";
      }
    }
    ss << "]\n";
    return ss.str();
  }

W
Wang Guibao 已提交
869 870
 private:
  void CheckType(const std::string& type) const {
871 872 873 874 875
    PADDLE_ENFORCE_EQ((type == "uint64" || type == "float"), true,
                      platform::errors::InvalidArgument(
                          "MultiSlotType error, expect type is uint64 or "
                          "float, but received type is %s.",
                          type));
W
Wang Guibao 已提交
876 877
  }
  void CheckFloat() const {
878 879 880 881
    PADDLE_ENFORCE_EQ(
        type_[0], 'f',
        platform::errors::InvalidArgument(
            "MultiSlotType error, add %s value to float slot.", type_));
W
Wang Guibao 已提交
882 883
  }
  void CheckUint64() const {
884 885 886 887
    PADDLE_ENFORCE_EQ(
        type_[0], 'u',
        platform::errors::InvalidArgument(
            "MultiSlotType error, add %s value to uint64 slot.", type_));
W
Wang Guibao 已提交
888 889 890 891 892 893 894
  }
  std::vector<float> float_feasign_;
  std::vector<uint64_t> uint64_feasign_;
  std::string type_;
  std::vector<size_t> offset_;
};

J
jiaqi 已提交
895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924
// Serializes a MultiSlotType into an Archive. Field order is: type name,
// offsets, float feasigns, uint64 feasigns. The matching operator>> must read
// them back in exactly this order.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const MultiSlotType& ins) {
  ar << ins.GetType();
#ifdef _LINUX
  ar << ins.GetOffset();
#else
  // Off Linux the offsets are written explicitly as a uint64_t count followed
  // by uint64_t values -- presumably to keep a fixed-width layout where the
  // Archive has no direct std::vector<size_t> support (TODO confirm).
  const auto& offset = ins.GetOffset();
  ar << static_cast<uint64_t>(offset.size());
  for (const size_t& x : offset) {
    ar << static_cast<uint64_t>(x);
  }
#endif
  ar << ins.GetFloatData();
  ar << ins.GetUint64Data();
  return ar;
}

// Deserializes a MultiSlotType from an Archive, reading the fields in the
// order the matching operator<< writes them: type name, offsets, float
// feasigns, uint64 feasigns.
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           MultiSlotType& ins) {
  ar >> ins.MutableType();
#ifdef _LINUX
  ar >> ins.MutableOffset();
#else
  // Non-Linux layout: a uint64_t element count, then each offset as uint64_t.
  auto& offsets = ins.MutableOffset();
  const uint64_t count = ar.template Get<uint64_t>();
  offsets.resize(count);
  for (size_t pos = 0; pos < offsets.size(); ++pos) {
    uint64_t raw = 0;
    ar >> raw;
    offsets[pos] = static_cast<size_t>(raw);
  }
#endif
  ar >> ins.MutableFloatData();
  ar >> ins.MutableUint64Data();
  return ar;
}

933 934
struct RecordCandidate {
  std::string ins_id_;
T
Thunderbrook 已提交
935
  std::unordered_multimap<uint16_t, FeatureFeasign> feas_;
936 937 938 939 940 941 942 943 944 945 946 947
  size_t shadow_index_ = -1;  // Optimization for Reservoir Sample

  RecordCandidate() {}
  RecordCandidate(const Record& rec,
                  const std::unordered_set<uint16_t>& slot_index_to_replace) {
    for (const auto& fea : rec.uint64_feasigns_) {
      if (slot_index_to_replace.find(fea.slot()) !=
          slot_index_to_replace.end()) {
        feas_.insert({fea.slot(), fea.sign()});
      }
    }
  }
948 949

  RecordCandidate& operator=(const Record& rec) {
950
    feas_.clear();
951 952
    ins_id_ = rec.ins_id_;
    for (auto& fea : rec.uint64_feasigns_) {
953
      feas_.insert({fea.slot(), fea.sign()});
954 955 956 957 958 959 960 961
    }
    return *this;
  }
};

class RecordCandidateList {
 public:
  RecordCandidateList() = default;
962
  RecordCandidateList(const RecordCandidateList&) {}
963

964
  size_t Size() { return cur_size_; }
965 966 967
  void ReSize(size_t length);

  void ReInit();
968 969 970 971 972 973 974 975 976 977 978 979
  void ReInitPass() {
    for (size_t i = 0; i < cur_size_; ++i) {
      if (candidate_list_[i].shadow_index_ != i) {
        candidate_list_[i].ins_id_ =
            candidate_list_[candidate_list_[i].shadow_index_].ins_id_;
        candidate_list_[i].feas_.swap(
            candidate_list_[candidate_list_[i].shadow_index_].feas_);
        candidate_list_[i].shadow_index_ = i;
      }
    }
    candidate_list_.resize(cur_size_);
  }
980 981

  void AddAndGet(const Record& record, RecordCandidate* result);
982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013
  void AddAndGet(const Record& record, size_t& index_result) {  // NOLINT
    // std::unique_lock<std::mutex> lock(mutex_);
    size_t index = 0;
    ++total_size_;
    auto fleet_ptr = FleetWrapper::GetInstance();
    if (!full_) {
      candidate_list_.emplace_back(record, slot_index_to_replace_);
      candidate_list_.back().shadow_index_ = cur_size_;
      ++cur_size_;
      full_ = (cur_size_ == capacity_);
    } else {
      index = fleet_ptr->LocalRandomEngine()() % total_size_;
      if (index < capacity_) {
        candidate_list_.emplace_back(record, slot_index_to_replace_);
        candidate_list_[index].shadow_index_ = candidate_list_.size() - 1;
      }
    }
    index = fleet_ptr->LocalRandomEngine()() % cur_size_;
    index_result = candidate_list_[index].shadow_index_;
  }
  const RecordCandidate& Get(size_t index) const {
    PADDLE_ENFORCE_LT(
        index, candidate_list_.size(),
        platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
                                     "elements in candidate_list[%lu].",
                                     index, candidate_list_.size()));
    return candidate_list_[index];
  }
  void SetSlotIndexToReplace(
      const std::unordered_set<uint16_t>& slot_index_to_replace) {
    slot_index_to_replace_ = slot_index_to_replace;
  }
1014 1015

 private:
1016 1017 1018 1019 1020 1021 1022
  size_t capacity_ = 0;
  std::mutex mutex_;
  bool full_ = false;
  size_t cur_size_ = 0;
  size_t total_size_ = 0;
  std::vector<RecordCandidate> candidate_list_;
  std::unordered_set<uint16_t> slot_index_to_replace_;
1023 1024
};

J
jiaqi 已提交
1025 1026
// Writes a FeatureFeasign as uint64_feasign_ followed by float_feasign_.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const FeatureFeasign& fk) {
  return ar << fk.uint64_feasign_ << fk.float_feasign_;
}

// Reads a FeatureFeasign in the order operator<< writes it: uint64_feasign_,
// then float_feasign_.
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           FeatureFeasign& fk) {
  return ar >> fk.uint64_feasign_ >> fk.float_feasign_;
}

// Writes a FeatureItem as its feasign followed by its slot index.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const FeatureItem& fi) {
  return ar << fi.sign() << fi.slot();
}

// Reads a FeatureItem in the order operator<< writes it: feasign, then slot.
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           FeatureItem& fi) {
  return ar >> fi.sign() >> fi.slot();
}

// Writes a Record as: uint64 feasigns, float feasigns, instance id.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const Record& r) {
  return ar << r.uint64_feasigns_ << r.float_feasigns_ << r.ins_id_;
}

// Reads a Record in the order operator<< writes it: uint64 feasigns, float
// feasigns, instance id.
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           Record& r) {
  return ar >> r.uint64_feasigns_ >> r.float_feasigns_ >> r.ins_id_;
}

W
Wang Guibao 已提交
1075 1076 1077 1078 1079 1080 1081 1082
// DataFeed for multi-slot data, built on PrivateQueueDataFeed with one
// std::vector<MultiSlotType> per instance. Each slot is written as a count n
// followed by exactly n feasigns:
//   [n feasign_0 feasign_1 ... feasign_{n-1}]*
class MultiSlotDataFeed
    : public PrivateQueueDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotDataFeed() {}
  virtual ~MultiSlotDataFeed() {}

  // Configures the feed from its DataFeedDesc (defined out of line).
  virtual void Init(const DataFeedDesc& data_feed_desc);
  // Validates the given file (defined out of line).
  virtual bool CheckFile(const char* filename);

 protected:
  virtual void ReadThread();
  virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
                                   const std::vector<MultiSlotType>& instance,
                                   int index);
  // Instance parsers: from the regular input source and from a pipe.
  virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
  virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
  virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
};
1095

J
jiaqi 已提交
1096
// In-memory DataFeed whose instances are Record objects.
class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
 public:
  MultiSlotInMemoryDataFeed() {}
  virtual ~MultiSlotInMemoryDataFeed() {}

  // Configures the feed from its DataFeedDesc (defined out of line).
  virtual void Init(const DataFeedDesc& data_feed_desc);

 protected:
  // Instance parsers: from the regular input source, from a pipe, and from a
  // string handed to a CustomParser (presumably loaded from a shared object).
  virtual bool ParseOneInstance(Record* instance);
  virtual bool ParseOneInstanceFromPipe(Record* instance);
  virtual void ParseOneInstanceFromSo(const char* str, Record* instance,
                                      CustomParser* parser);

  virtual void PutToFeedVec(const std::vector<Record>& ins_vec);

  // Extracts search_id / cmatch / rank from log_key (decoding is defined out
  // of line).
  virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id,
                                uint32_t* cmatch, uint32_t* rank);

  // Pointer-and-count variant of PutToFeedVec.
  virtual void PutToFeedVec(const Record* ins_vec, int num);
};

Y
yaoxuefeng 已提交
1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149
// In-memory DataFeed whose instances are SlotRecord objects. The generic
// per-instance parse hooks are stubbed out. Lines are handled by the
// ParseOneInstance(const std::string&, SlotRecord*) overload declared below.
class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
 public:
  SlotRecordInMemoryDataFeed() {}
  virtual ~SlotRecordInMemoryDataFeed() {}
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual void LoadIntoMemory();
  void ExpandSlotRecord(SlotRecord* ins);

 protected:
  virtual bool Start();
  virtual int Next();

  // Stubs: these return false without parsing anything.
  virtual bool ParseOneInstance(SlotRecord* instance) { return false; }
  virtual bool ParseOneInstanceFromPipe(SlotRecord* instance) { return false; }
  // ParseOneInstanceFromSo is intentionally not overridden for SlotRecord.
  // Stub: does nothing.
  virtual void PutToFeedVec(const std::vector<SlotRecord>& ins_vec) {}

  // Alternative loading strategies (presumably selected inside
  // LoadIntoMemory -- confirm in the .cc).
  virtual void LoadIntoMemoryByCommand();
  virtual void LoadIntoMemoryByLib();
  virtual void LoadIntoMemoryByLine();
  virtual void LoadIntoMemoryByFile();

  // `channel` must point to a ChannelObject<SlotRecord>; it is cast
  // unchecked.
  virtual void SetInputChannel(void* channel) {
    input_channel_ = static_cast<ChannelObject<SlotRecord>*>(channel);
  }

  bool ParseOneInstance(const std::string& line, SlotRecord* rec);
  virtual void PutToFeedVec(const SlotRecord* ins_vec, int num);

  float sample_rate_ = 1.0f;
  int use_slot_size_ = 0;
  int float_use_slot_size_ = 0;
  int uint64_use_slot_size_ = 0;
  std::vector<AllSlotInfo> all_slots_info_;
  std::vector<UsedSlotInfo> used_slots_info_;
  size_t float_total_dims_size_ = 0;
  std::vector<int> float_total_dims_without_inductives_;
};

1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166
// MultiSlotInMemoryDataFeed variant that additionally handles PvInstance
// groups (see the PvInstance overloads of PutToFeedVec and GetRankOffset).
// NOTE(review): Init is public in MultiSlotInMemoryDataFeed but declared
// protected here. Calls through a PaddleBoxDataFeed* therefore won't compile,
// while calls through a base pointer/reference still dispatch here.
// NOTE(review): declaring PutToFeedVec here hides the base-class PutToFeedVec
// overloads for unqualified calls made through PaddleBoxDataFeed.
class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
 public:
  PaddleBoxDataFeed() {}
  virtual ~PaddleBoxDataFeed() {}

 protected:
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual bool Start();
  virtual int Next();
  virtual void AssignFeedVar(const Scope& scope);
  virtual void PutToFeedVec(const std::vector<PvInstance>& pv_vec);
  virtual void PutToFeedVec(const std::vector<Record*>& ins_vec);
  virtual int GetCurrentPhase();
  virtual void GetRankOffset(const std::vector<PvInstance>& pv_vec,
                             int ins_number);
  std::string rank_offset_name_;
  // Default-initialized so the value is defined even before Init() runs
  // (it was previously left uninitialized).
  int pv_batch_size_ = 0;
};

1169
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
H
hutuxian 已提交
1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217
// Base class for feeds that parse mini-batches straight from a file via the
// Preprocess / ParseOneMiniBatch / Postprocess hooks below.
// NOTE: the template parameter T is not referenced in this declaration.
template <typename T>
class PrivateInstantDataFeed : public DataFeed {
 public:
  PrivateInstantDataFeed() {}
  virtual ~PrivateInstantDataFeed() {}
  void Init(const DataFeedDesc& data_feed_desc) override;
  // Nothing to start up; always returns true.
  bool Start() override { return true; }
  int Next() override;

 protected:
  // Buffer holding the current mini-batch (presumably one entry per slot).
  std::vector<MultiSlotType> ins_vec_;

  // Prepares `filename` for reading, e.g. by opening or mmap-ing it.
  virtual bool Preprocess(const std::string& filename) = 0;

  // Releases resources acquired by Preprocess, such as closing the file.
  // Implementations must be safe to call even if Preprocess never ran.
  virtual bool Postprocess() = 0;

  // Reads and parses the next mini-batch.
  virtual bool ParseOneMiniBatch() = 0;

  // Moves the contents of ins_vec_ into the feed variables.
  virtual void PutToFeedVec();
};

// PrivateInstantDataFeed that parses multi-slot data directly from a file.
class MultiSlotFileInstantDataFeed
    : public PrivateInstantDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotFileInstantDataFeed() {}
  virtual ~MultiSlotFileInstantDataFeed() {}

 protected:
  // File state; presumably fd_ is an open descriptor (-1 when none) and
  // buffer_/end_/offset_ describe an in-memory (mmap-ed?) view of the file
  // and the current parse position -- confirm in the .cc.
  int fd_{-1};
  char* buffer_{nullptr};
  size_t end_{0};
  size_t offset_{0};

  bool Preprocess(const std::string& filename) override;
  bool Postprocess() override;
  bool ParseOneMiniBatch() override;
};
#endif

W
Wang Guibao 已提交
1218 1219
}  // namespace framework
}  // namespace paddle