data_feed.h 58.7 KB
Newer Older
W
Wang Guibao 已提交
1 2 3 4 5 6
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

L
lxsbupt 已提交
7
http://www.apache.org/licenses/LICENSE-2.0
W
Wang Guibao 已提交
8 9 10 11 12 13 14 15 16

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

J
jiaqi 已提交
17 18 19 20 21
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif

W
Wang Guibao 已提交
22
#include <fstream>
23
#include <future>  // NOLINT
W
Wang Guibao 已提交
24 25
#include <memory>
#include <mutex>  // NOLINT
D
danleifeng 已提交
26
#include <random>
27
#include <sstream>
W
Wang Guibao 已提交
28 29
#include <string>
#include <thread>  // NOLINT
30
#include <unordered_map>
31
#include <unordered_set>
32
#include <utility>
33
#include <vector>
W
Wang Guibao 已提交
34

J
jiaqi 已提交
35
#include "paddle/fluid/framework/archive.h"
36
#include "paddle/fluid/framework/blocking_queue.h"
J
jiaqi 已提交
37
#include "paddle/fluid/framework/channel.h"
W
Wang Guibao 已提交
38
#include "paddle/fluid/framework/data_feed.pb.h"
39
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
W
Wang Guibao 已提交
40 41 42
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
Y
yaoxuefeng 已提交
43
#include "paddle/fluid/platform/timer.h"
44
#include "paddle/fluid/string/string_helper.h"
45
#if defined(PADDLE_WITH_CUDA)
D
danleifeng 已提交
46
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h"
47 48 49
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#endif
W
Wang Guibao 已提交
50

Y
yaoxuefeng 已提交
51 52 53 54 55
DECLARE_int32(record_pool_max_size);
DECLARE_int32(slotpool_thread_num);
DECLARE_bool(enable_slotpool_wait_release);
DECLARE_bool(enable_slotrecord_reset_shrink);

W
wanghuancoder 已提交
56 57 58 59 60
namespace paddle {
namespace framework {
class DataFeedDesc;
class Scope;
class Variable;
D
danleifeng 已提交
61 62
class NeighborSampleResult;
class NodeQueryResult;
L
lxsbupt 已提交
63 64
template <typename KeyType, typename ValType>
class HashTable;
W
wanghuancoder 已提交
65 66 67
}  // namespace framework
}  // namespace paddle

68
namespace phi {
69
class DenseTensor;
70
}  // namespace phi
71

W
Wang Guibao 已提交
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
namespace paddle {
namespace framework {

// DataFeed is the base virtual class for all other DataFeeds.
// It is used to read files and parse the data for subsequent trainer.
// Example:
//   DataFeed* reader =
//   paddle::framework::DataFeedFactory::CreateDataFeed(data_feed_name);
//   reader->Init(data_feed_desc); // data_feed_desc is a protobuf object
//   reader->SetFileList(filelist);
//   const std::vector<std::string> & use_slot_alias =
//   reader->GetUseSlotAlias();
//   for (auto name: use_slot_alias){ // for binding memory
//     reader->AddFeedVar(scope->Var(name), name);
//   }
//   reader->Start();
//   while (reader->Next()) {
//      // trainer do something
//   }
Y
yaoxuefeng 已提交
91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134

// Flat storage for the values of several slots.
// slot_values holds the concatenated values of every slot, and slot_offsets
// holds the start offset of each slot plus one trailing end offset. Slot `i`
// therefore occupies [slot_offsets[i], slot_offsets[i + 1]) in slot_values.
template <typename T>
struct SlotValues {
  std::vector<T> slot_values;
  std::vector<uint32_t> slot_offsets;

  // Appends one slot holding `num` values copied from `values`.
  // `values` may be null when `num` is 0 (an empty slot).
  void add_values(const T* values, uint32_t num) {
    if (slot_offsets.empty()) {
      slot_offsets.push_back(0);
    }
    if (num > 0) {
      slot_values.insert(slot_values.end(), values, values + num);
    }
    slot_offsets.push_back(static_cast<uint32_t>(slot_values.size()));
  }
  // Returns a pointer to the first value of slot `idx` and stores that slot's
  // value count in *size. Requires 0 <= idx < number of slots. For an empty
  // slot the pointer must not be dereferenced.
  T* get_values(int idx, size_t* size) {
    const uint32_t offset = slot_offsets[idx];
    (*size) = slot_offsets[idx + 1] - offset;
    // data() + offset rather than &slot_values[offset]: indexing at
    // offset == slot_values.size() (an empty trailing slot, or no values at
    // all) is undefined behavior, while the pointer arithmetic is not.
    return slot_values.data() + offset;
  }
  // Replaces the slot layout with one slot per entry of `slot_feasigns`,
  // appending their values to slot_values. `fea_num` is used as a capacity
  // hint for slot_values.
  void add_slot_feasigns(const std::vector<std::vector<T>>& slot_feasigns,
                         uint32_t fea_num) {
    slot_values.reserve(fea_num);
    const int slot_num = static_cast<int>(slot_feasigns.size());
    slot_offsets.resize(slot_num + 1);
    for (int i = 0; i < slot_num; ++i) {
      const auto& slot_val = slot_feasigns[i];
      slot_offsets[i] = static_cast<uint32_t>(slot_values.size());
      if (!slot_val.empty()) {
        slot_values.insert(slot_values.end(), slot_val.begin(), slot_val.end());
      }
    }
    slot_offsets[slot_num] = static_cast<uint32_t>(slot_values.size());
  }
  // Removes all slots. When `shrink` is true, the buffers' capacity is also
  // released.
  void clear(bool shrink) {
    slot_offsets.clear();
    slot_values.clear();
    if (shrink) {
      slot_values.shrink_to_fit();
      slot_offsets.shrink_to_fit();
    }
  }
};
T
Thunderbrook 已提交
135
// One feature sign. The same 8 bytes hold either an integer feasign or a float
// value.
union FeatureFeasign {
  uint64_t uint64_feasign_;
  float float_feasign_;
};

// A (feature sign, slot id) pair.
// The sign lives in a raw char buffer rather than a FeatureFeasign member, so
// the struct's alignment is that of uint16_t, not FeatureFeasign.
// NOTE(review): inside arrays of FeatureItem, sign() may therefore access a
// misaligned FeatureFeasign. Confirm this is acceptable on all targets.
struct FeatureItem {
  FeatureItem() {}
  FeatureItem(FeatureFeasign sign, uint16_t slot) {
    this->slot() = slot;
    this->sign() = sign;
  }

  // Mutable view of the stored sign.
  FeatureFeasign& sign() {
    return *(reinterpret_cast<FeatureFeasign*>(sign_buffer()));
  }
  // Read-only view of the stored sign.
  const FeatureFeasign& sign() const {
    return *reinterpret_cast<const FeatureFeasign*>(sign_buffer());
  }
  uint16_t& slot() { return slot_; }
  const uint16_t& slot() const { return slot_; }

 private:
  // Gives both sign() overloads a writable pointer to the raw sign bytes.
  char* sign_buffer() const { return const_cast<char*>(sign_); }

  char sign_[sizeof(FeatureFeasign)];
  uint16_t slot_;
};

Y
yaoxuefeng 已提交
163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
// Describes one configured slot; passed to CustomParser::Init(const
// std::vector<AllSlotInfo>&).
// NOTE(review): the per-field meanings below are inferred from the names.
// Confirm them against the code that fills this struct.
struct AllSlotInfo {
  std::string slot;    // slot name (assumed)
  std::string type;    // value type string (assumed)
  int used_idx;        // index among the used slots (assumed)
  int slot_value_idx;  // index into the per-type value arrays (assumed)
};
// Describes one slot that is actually fed to the model; MiniBatchGpuPack is
// constructed from a vector of these.
// NOTE(review): the per-field meanings below are inferred from the names.
// Confirm them against the code that fills this struct.
struct UsedSlotInfo {
  int idx;             // position of the slot (assumed)
  int slot_value_idx;  // index into the per-type value arrays (assumed)
  std::string slot;    // slot name (assumed)
  std::string type;    // value type string (assumed)
  bool dense;          // dense vs. sparse slot (assumed)
  std::vector<int> local_shape;      // configured shape (assumed)
  int total_dims_without_inductive;  // presumably product of known dims
  int inductive_shape_index;         // presumably index of the -1 dim
};
// A parsed instance in slot-record form: instance metadata plus the uint64 and
// float feasigns of its slots.
// Instances are created by make_slotrecord(), destroyed by free_slotrecord(),
// and recycled through SlotObjPool (which calls reset() before pooling).
// NOTE(review): SlotObjAllocator::release() writes its free-list pointer over
// the first bytes of a pooled object. Keep a trivially-destructible member
// (currently search_id) first.
struct SlotRecordObject {
  // Zero-initialized so records default-constructed by make_slotrecord() do
  // not start with indeterminate metadata.
  uint64_t search_id = 0;
  uint32_t rank = 0;
  uint32_t cmatch = 0;
  std::string ins_id_;
  SlotValues<uint64_t> slot_uint64_feasigns_;
  SlotValues<float> slot_float_feasigns_;

  ~SlotRecordObject() { clear(true); }
  // Clears the feasign buffers for reuse. Whether their capacity is released
  // is controlled by FLAGS_enable_slotrecord_reset_shrink. ins_id_ and the
  // metadata fields are left untouched.
  void reset(void) { clear(FLAGS_enable_slotrecord_reset_shrink); }
  // Clears both feasign buffers; `shrink` also releases their capacity.
  void clear(bool shrink) {
    slot_uint64_feasigns_.clear(shrink);
    slot_float_feasigns_.clear(shrink);
  }
};
// Records are passed around by raw pointer. Lifetime is managed by
// make_slotrecord()/free_slotrecord() and SlotObjPool.
using SlotRecord = SlotRecordObject*;
195 196 197 198 199 200 201 202 203
// One parsed instance holding its uint64 and float feasigns as FeatureItems,
// plus instance metadata.
// sizeof Record is much less than std::vector<MultiSlotType>
struct Record {
  std::vector<FeatureItem> uint64_feasigns_;
  std::vector<FeatureItem> float_feasigns_;
  std::string ins_id_;
  std::string content_;
  // Zero-initialized so a default-initialized Record (e.g. `Record r;`)
  // never exposes indeterminate metadata.
  uint64_t search_id = 0;
  uint32_t rank = 0;
  uint32_t cmatch = 0;
  std::string uid_;
};

Y
yaoxuefeng 已提交
207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379
// Allocates storage with malloc() and default-constructs a SlotRecordObject in
// it. Release the result with free_slotrecord(), never with delete.
// Throws std::bad_alloc if the allocation fails.
inline SlotRecord make_slotrecord() {
  void* p = malloc(sizeof(SlotRecordObject));
  if (p == nullptr) {
    // Placement new into a null pointer would be undefined behavior.
    throw std::bad_alloc();
  }
  // Return the pointer produced by placement new instead of casting `p`.
  return new (p) SlotRecordObject;
}

// Destroys a record created by make_slotrecord() and returns its storage with
// free(). Passing nullptr is a no-op, mirroring free().
inline void free_slotrecord(SlotRecordObject* p) {
  if (p == nullptr) {
    return;
  }
  p->~SlotRecordObject();
  free(p);
}

// Intrusive LIFO free list of already-constructed T objects, used to recycle
// them. release() stores the list link in the first bytes of the released
// object itself (through Node), so those bytes are overwritten while the
// object is in the list. clear() and the destructor hand every pooled object
// to `deleter`.
// Not synchronized: callers serialize access (SlotObjPool guards it with a
// mutex).
template <class T>
class SlotObjAllocator {
 public:
  explicit SlotObjAllocator(std::function<void(T*)> deleter)
      : free_nodes_(nullptr), capacity_(0), deleter_(std::move(deleter)) {}
  ~SlotObjAllocator() { clear(); }

  // The allocator owns the pooled objects; a copy would hand the same objects
  // to the deleter twice.
  SlotObjAllocator(const SlotObjAllocator&) = delete;
  SlotObjAllocator& operator=(const SlotObjAllocator&) = delete;

  // Passes every pooled object to the deleter and empties the list.
  void clear() {
    while (free_nodes_ != nullptr) {
      T* obj = reinterpret_cast<T*>(reinterpret_cast<void*>(free_nodes_));
      // Read the link before the deleter destroys the object holding it.
      free_nodes_ = free_nodes_->next;
      deleter_(obj);
      --capacity_;
    }
    CHECK_EQ(capacity_, static_cast<size_t>(0));
  }
  // Pops the most recently released object. Returns nullptr when the list is
  // empty instead of dereferencing a null head.
  T* acquire(void) {
    if (free_nodes_ == nullptr) {
      return nullptr;
    }
    T* x = reinterpret_cast<T*>(reinterpret_cast<void*>(free_nodes_));
    free_nodes_ = free_nodes_->next;
    --capacity_;
    return x;
  }
  // Pushes `x` onto the list; the allocator now owns it.
  void release(T* x) {
    Node* node = reinterpret_cast<Node*>(reinterpret_cast<void*>(x));
    node->next = free_nodes_;
    free_nodes_ = node;
    ++capacity_;
  }
  // Number of objects currently in the list.
  size_t capacity(void) { return capacity_; }

 private:
  // Overlay used to thread the list through the pooled objects' storage.
  struct alignas(T) Node {
    union {
      Node* next;
      char data[sizeof(T)];
    };
  };
  Node* free_nodes_;  // a list
  size_t capacity_;
  std::function<void(T*)> deleter_ = nullptr;
};
static const int OBJPOOL_BLOCK_SIZE = 10000;
class SlotObjPool {
 public:
  SlotObjPool()
      : max_capacity_(FLAGS_record_pool_max_size), alloc_(free_slotrecord) {
    ins_chan_ = MakeChannel<SlotRecord>();
    ins_chan_->SetBlockSize(OBJPOOL_BLOCK_SIZE);
    for (int i = 0; i < FLAGS_slotpool_thread_num; ++i) {
      threads_.push_back(std::thread([this]() { run(); }));
    }
    disable_pool_ = false;
    count_ = 0;
  }
  ~SlotObjPool() {
    ins_chan_->Close();
    for (auto& t : threads_) {
      t.join();
    }
  }
  void disable_pool(bool disable) { disable_pool_ = disable; }
  void set_max_capacity(size_t max_capacity) { max_capacity_ = max_capacity; }
  void get(std::vector<SlotRecord>* output, int n) {
    output->resize(n);
    return get(&(*output)[0], n);
  }
  void get(SlotRecord* output, int n) {
    int size = 0;
    mutex_.lock();
    int left = static_cast<int>(alloc_.capacity());
    if (left > 0) {
      size = (left >= n) ? n : left;
      for (int i = 0; i < size; ++i) {
        output[i] = alloc_.acquire();
      }
    }
    mutex_.unlock();
    count_ += n;
    if (size == n) {
      return;
    }
    for (int i = size; i < n; ++i) {
      output[i] = make_slotrecord();
    }
  }
  void put(std::vector<SlotRecord>* input) {
    size_t size = input->size();
    if (size == 0) {
      return;
    }
    put(&(*input)[0], size);
    input->clear();
  }
  void put(SlotRecord* input, size_t size) {
    CHECK(ins_chan_->WriteMove(size, input) == size);
  }
  void run(void) {
    std::vector<SlotRecord> input;
    while (ins_chan_->ReadOnce(input, OBJPOOL_BLOCK_SIZE)) {
      if (input.empty()) {
        continue;
      }
      // over max capacity
      size_t n = input.size();
      count_ -= n;
      if (disable_pool_ || n + capacity() > max_capacity_) {
        for (auto& t : input) {
          free_slotrecord(t);
        }
      } else {
        for (auto& t : input) {
          t->reset();
        }
        mutex_.lock();
        for (auto& t : input) {
          alloc_.release(t);
        }
        mutex_.unlock();
      }
      input.clear();
    }
  }
  void clear(void) {
    platform::Timer timeline;
    timeline.Start();
    mutex_.lock();
    alloc_.clear();
    mutex_.unlock();
    // wait release channel data
    if (FLAGS_enable_slotpool_wait_release) {
      while (!ins_chan_->Empty()) {
        sleep(1);
      }
    }
    timeline.Pause();
    VLOG(3) << "clear slot pool data size=" << count_.load()
            << ", span=" << timeline.ElapsedSec();
  }
  size_t capacity(void) {
    mutex_.lock();
    size_t total = alloc_.capacity();
    mutex_.unlock();
    return total;
  }

 private:
  size_t max_capacity_;
  Channel<SlotRecord> ins_chan_;
  std::vector<std::thread> threads_;
  std::mutex mutex_;
  SlotObjAllocator<SlotRecordObject> alloc_;
  bool disable_pool_;
  std::atomic<long> count_;  // NOLINT
};

inline SlotObjPool& SlotRecordPool() {
  static SlotObjPool pool;
  return pool;
}
380 381 382 383 384 385 386 387 388
struct PvInstanceObject {
  std::vector<Record*> ads;
  void merge_instance(Record* ins) { ads.push_back(ins); }
};

using PvInstance = PvInstanceObject*;

inline PvInstance make_pv_instance() { return new PvInstanceObject(); }

T
Thunderbrook 已提交
389 390 391 392 393 394 395 396 397 398 399 400
// Slot configuration handed to CustomParser::Init(const std::vector<SlotConf>&)
// via DLManager::Load/ReLoad.
// NOTE(review): field meanings are inferred from their names; confirm.
struct SlotConf {
  std::string name;
  std::string type;
  int use_slots_index;
  int use_slots_is_dense;  // presumably used as a boolean flag
};

class CustomParser {
 public:
  CustomParser() {}
  virtual ~CustomParser() {}
  virtual void Init(const std::vector<SlotConf>& slots) = 0;
T
Thunderbrook 已提交
401
  virtual bool Init(const std::vector<AllSlotInfo>& slots) = 0;
T
Thunderbrook 已提交
402
  virtual void ParseOneInstance(const char* str, Record* instance) = 0;
403 404
  virtual int ParseInstance(int len,
                            const char* str,
T
Thunderbrook 已提交
405 406
                            std::vector<Record>* instances) {
    return 0;
407
  }
Y
yaoxuefeng 已提交
408 409 410 411 412 413 414 415 416 417 418 419 420
  virtual bool ParseOneInstance(
      const std::string& line,
      std::function<void(std::vector<SlotRecord>&, int)>
          GetInsFunc) {  // NOLINT
    return true;
  }
  virtual bool ParseFileInstance(
      std::function<int(char* buf, int len)> ReadBuffFunc,
      std::function<void(std::vector<SlotRecord>&, int, int)>
          PullRecordsFunc,  // NOLINT
      int& lines) {         // NOLINT
    return false;
  }
T
Thunderbrook 已提交
421 422
};

423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484
// Per-slot descriptor; MiniBatchGpuPack keeps an array of these in a
// CudaBuffer (gpu_slots_).
struct UsedSlotGpuType {
  int is_uint64_value;  // presumably nonzero for uint64-valued slots
  int slot_value_idx;
};

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
// Owns a device allocation of `buf_size` elements of T obtained with
// cudaMalloc and released with cudaFree. Contents are not preserved when the
// buffer grows.
// NOTE(review): the type is implicitly copyable, and a copy would cudaFree the
// same pointer twice. Consider deleting the copy operations once call sites
// are checked.
template <typename T>
struct CudaBuffer {
  T* cu_buffer;
  uint64_t buf_size;

  // Spelled `CudaBuffer()` / `~CudaBuffer()`: the `CudaBuffer<T>()` form is
  // ill-formed as a constructor/destructor name in C++20.
  CudaBuffer() {
    cu_buffer = NULL;
    buf_size = 0;
  }
  ~CudaBuffer() { free(); }
  T* data() { return cu_buffer; }
  uint64_t size() { return buf_size; }
  // Allocates room for `size` elements. Any previous allocation is released
  // first instead of being leaked.
  void malloc(uint64_t size) {
    free();
    buf_size = size;
    CUDA_CHECK(
        cudaMalloc(reinterpret_cast<void**>(&cu_buffer), size * sizeof(T)));
  }
  // Releases the allocation (if any) and resets the size to 0.
  void free() {
    if (cu_buffer != NULL) {
      CUDA_CHECK(cudaFree(cu_buffer));
      cu_buffer = NULL;
    }
    buf_size = 0;
  }
  // Ensures room for at least `size` elements. Only reallocates when growing.
  void resize(uint64_t size) {
    if (size <= buf_size) {
      return;
    }
    malloc(size);
  }
};
// Owns page-locked host memory obtained with cudaHostAlloc.
// buf_size is the allocated capacity in elements. data_len is the number of
// elements in use, which is what size() reports.
// NOTE(review): the type is implicitly copyable, and a copy would
// cudaFreeHost the same pointer twice.
template <typename T>
struct HostBuffer {
  T* host_buffer;
  size_t buf_size;
  size_t data_len;

  // Spelled `HostBuffer()` / `~HostBuffer()`: the `HostBuffer<T>()` form is
  // ill-formed as a constructor/destructor name in C++20.
  HostBuffer() {
    host_buffer = NULL;
    buf_size = 0;
    data_len = 0;
  }
  ~HostBuffer() { free(); }

  T* data() { return host_buffer; }
  const T* data() const { return host_buffer; }
  size_t size() const { return data_len; }
  // Releases the storage; size() becomes 0.
  void clear() { free(); }
  // Last in-use element. Requires size() > 0.
  T& back() { return host_buffer[data_len - 1]; }

  T& operator[](size_t i) { return host_buffer[i]; }
  const T& operator[](size_t i) const { return host_buffer[i]; }
  // Allocates capacity for `len` elements. Any previous storage is released
  // first instead of being leaked. size() is not changed.
  void malloc(size_t len) {
    if (host_buffer != NULL) {
      CUDA_CHECK(cudaFreeHost(host_buffer));
      host_buffer = NULL;
    }
    buf_size = len;
    CUDA_CHECK(cudaHostAlloc(reinterpret_cast<void**>(&host_buffer),
                             buf_size * sizeof(T),
                             cudaHostAllocDefault));
    CHECK(host_buffer != NULL);
  }
  // Releases the storage and resets capacity and length. Resetting data_len
  // keeps size() from reporting elements behind a null data().
  void free() {
    if (host_buffer != NULL) {
      CUDA_CHECK(cudaFreeHost(host_buffer));
      host_buffer = NULL;
    }
    buf_size = 0;
    data_len = 0;
  }
  // Sets size() to `size`. Reallocates (without preserving contents) only
  // when the current capacity is too small.
  void resize(size_t size) {
    if (size > buf_size) {
      free();
      malloc(size);
    }
    data_len = size;
  }
};

// Host-side staging buffers for one packed batch. They are HostBuffers, i.e.
// cudaHostAlloc page-locked memory. The field names mirror BatchGPUValue
// (h_* here, d_* there).
struct BatchCPUValue {
  HostBuffer<int> h_uint64_lens;
  HostBuffer<uint64_t> h_uint64_keys;
  HostBuffer<int> h_uint64_offset;

  HostBuffer<int> h_float_lens;
  HostBuffer<float> h_float_keys;
  HostBuffer<int> h_float_offset;

  HostBuffer<int> h_rank;
  HostBuffer<int> h_cmatch;
  HostBuffer<int> h_ad_offset;
};

// Device-side buffers (cudaMalloc'd CudaBuffers) for one packed batch. The
// field names mirror BatchCPUValue (d_* here, h_* there).
struct BatchGPUValue {
  CudaBuffer<int> d_uint64_lens;
  CudaBuffer<uint64_t> d_uint64_keys;
  CudaBuffer<int> d_uint64_offset;

  CudaBuffer<int> d_float_lens;
  CudaBuffer<float> d_float_keys;
  CudaBuffer<int> d_float_offset;

  CudaBuffer<int> d_rank;
  CudaBuffer<int> d_cmatch;
  CudaBuffer<int> d_ad_offset;
};

// Packs a batch of SlotRecords for one GPU place. Data is staged in host
// buffers (buf_) and device buffers (value_); copy_host2device issues
// cudaMemcpyAsync on stream_. The non-inline members (constructor,
// pack_instance, transfer_to_gpu, ...) are implemented elsewhere.
class MiniBatchGpuPack {
 public:
  MiniBatchGpuPack(const paddle::platform::Place& place,
                   const std::vector<UsedSlotInfo>& infos);
  ~MiniBatchGpuPack();
  void reset(const paddle::platform::Place& place);
  void pack_instance(const SlotRecord* ins_vec, int num);
  int ins_num() { return ins_num_; }
  int pv_num() { return pv_num_; }
  BatchGPUValue& value() { return value_; }
  BatchCPUValue& cpu_value() { return buf_; }
  UsedSlotGpuType* get_gpu_slots(void) {
    return reinterpret_cast<UsedSlotGpuType*>(gpu_slots_.data());
  }
  // Pointer to the packed records. data() stays well-defined (possibly null)
  // when ins_vec_ is empty, where &ins_vec_[0] would be undefined.
  SlotRecord* get_records(void) { return ins_vec_.data(); }

  // tensor gpu memory reused
  // Sizes float_tensor_/uint64_tensor_ from the last entry of the packed
  // length buffers. A type whose length buffer is empty is skipped, because
  // back() on an empty HostBuffer reads out of bounds.
  void resize_tensor(void) {
    if (used_float_num_ > 0 && buf_.h_float_lens.size() > 0) {
      int float_total_len = buf_.h_float_lens.back();
      if (float_total_len > 0) {
        float_tensor_.mutable_data<float>({float_total_len, 1}, this->place_);
      }
    }
    if (used_uint64_num_ > 0 && buf_.h_uint64_lens.size() > 0) {
      int uint64_total_len = buf_.h_uint64_lens.back();
      if (uint64_total_len > 0) {
        uint64_tensor_.mutable_data<int64_t>({uint64_total_len, 1},
                                             this->place_);
      }
    }
  }
  phi::DenseTensor& float_tensor(void) { return float_tensor_; }
  phi::DenseTensor& uint64_tensor(void) { return uint64_tensor_; }

  HostBuffer<size_t>& offsets(void) { return offsets_; }
  HostBuffer<void*>& h_tensor_ptrs(void) { return h_tensor_ptrs_; }

  // Requires resize_gpu_slot_offsets() to have been called (the allocation
  // is dereferenced unconditionally).
  void* gpu_slot_offsets(void) { return gpu_slot_offsets_->ptr(); }

  // NOTE(review): dereferences slot_buf_ptr_ unconditionally; it starts as
  // nullptr here, so callers must ensure it was allocated.
  void* slot_buf_ptr(void) { return slot_buf_ptr_->ptr(); }

  // Ensures gpu_slot_offsets_ holds at least `slot_total_bytes` bytes. The
  // buffer only grows; the new allocation is made before the old one is
  // released.
  void resize_gpu_slot_offsets(const size_t slot_total_bytes) {
    if (gpu_slot_offsets_ == nullptr ||
        gpu_slot_offsets_->size() < slot_total_bytes) {
      gpu_slot_offsets_ = memory::AllocShared(place_, slot_total_bytes);
    }
  }
  // Instance id of record `idx`, read from ins_vec_ when pv merging is
  // enabled and from batch_ins_ otherwise.
  const std::string& get_lineid(int idx) {
    if (enable_pv_) {
      return ins_vec_[idx]->ins_id_;
    }
    return batch_ins_[idx]->ins_id_;
  }

 private:
  void transfer_to_gpu(void);
  void pack_all_data(const SlotRecord* ins_vec, int num);
  void pack_uint64_data(const SlotRecord* ins_vec, int num);
  void pack_float_data(const SlotRecord* ins_vec, int num);

 public:
  // Grows *buf to `size` elements if needed and enqueues an asynchronous
  // host-to-device copy on stream_. Does nothing when `size` is 0.
  template <typename T>
  void copy_host2device(CudaBuffer<T>* buf, const T* val, size_t size) {
    if (size == 0) {
      return;
    }
    buf->resize(size);
    CUDA_CHECK(cudaMemcpyAsync(
        buf->data(), val, size * sizeof(T), cudaMemcpyHostToDevice, stream_));
  }
  // Copies the in-use part (size() elements) of a HostBuffer.
  template <typename T>
  void copy_host2device(CudaBuffer<T>* buf, const HostBuffer<T>& val) {
    copy_host2device(buf, val.data(), val.size());
  }

 private:
  paddle::platform::Place place_;
  cudaStream_t stream_;
  BatchGPUValue value_;
  BatchCPUValue buf_;
  int ins_num_ = 0;
  int pv_num_ = 0;

  bool enable_pv_ = false;
  int used_float_num_ = 0;
  int used_uint64_num_ = 0;
  int used_slot_size_ = 0;

  CudaBuffer<UsedSlotGpuType> gpu_slots_;
  std::vector<UsedSlotGpuType> gpu_used_slots_;
  std::vector<SlotRecord> ins_vec_;
  const SlotRecord* batch_ins_ = nullptr;

  // uint64 tensor
  phi::DenseTensor uint64_tensor_;
  // float tensor
  phi::DenseTensor float_tensor_;
  // batch
  HostBuffer<size_t> offsets_;
  HostBuffer<void*> h_tensor_ptrs_;

  std::shared_ptr<phi::Allocation> gpu_slot_offsets_ = nullptr;
  std::shared_ptr<phi::Allocation> slot_buf_ptr_ = nullptr;
};
// Keeps one MiniBatchGpuPack per GPU device id, up to MAX_DEVICE_NUM devices.
// A pack is created on first request and reset() on later ones. Owns the
// packs and deletes them on destruction.
class MiniBatchGpuPackMgr {
  static const int MAX_DEVICE_NUM = 16;

 public:
  MiniBatchGpuPackMgr() {
    for (int i = 0; i < MAX_DEVICE_NUM; ++i) {
      pack_list_[i] = nullptr;
    }
  }
  ~MiniBatchGpuPackMgr() {
    for (int i = 0; i < MAX_DEVICE_NUM; ++i) {
      if (pack_list_[i] == nullptr) {
        continue;
      }
      delete pack_list_[i];
      pack_list_[i] = nullptr;
    }
  }
  // Owns raw pointers: a copy would delete every pack twice.
  MiniBatchGpuPackMgr(const MiniBatchGpuPackMgr&) = delete;
  MiniBatchGpuPackMgr& operator=(const MiniBatchGpuPackMgr&) = delete;

  // one device one thread
  // Returns the pack for `place`'s device. It is created from `infos` on
  // first use; on later calls it is reset for `place` and `infos` is ignored.
  MiniBatchGpuPack* get(const paddle::platform::Place& place,
                        const std::vector<UsedSlotInfo>& infos) {
    int device_id = place.GetDeviceId();
    // pack_list_ has a fixed number of entries; never index outside it.
    CHECK(device_id >= 0 && device_id < MAX_DEVICE_NUM)
        << "device id " << device_id << " out of range [0, " << MAX_DEVICE_NUM
        << ")";
    if (pack_list_[device_id] == nullptr) {
      pack_list_[device_id] = new MiniBatchGpuPack(place, infos);
    } else {
      pack_list_[device_id]->reset(place);
    }
    return pack_list_[device_id];
  }

 private:
  MiniBatchGpuPack* pack_list_[MAX_DEVICE_NUM];
};
// global mgr
// Returns the process-wide MiniBatchGpuPackMgr, constructed on first use.
inline MiniBatchGpuPackMgr& BatchGpuPackMgr() {
  static MiniBatchGpuPackMgr mgr;
  return mgr;
}
#endif

T
Thunderbrook 已提交
683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717
typedef paddle::framework::CustomParser* (*CreateParserObjectFunc)();

// Loads CustomParser implementations from shared libraries and owns both the
// parsers and the library handles. Each library must export a
// `CreateParserObject` function of type CreateParserObjectFunc.
// Only implemented when _LINUX is defined. Elsewhere, Load returns nullptr
// and Close returns false.
class DLManager {
  struct DLHandle {
    void* module;
    paddle::framework::CustomParser* parser;
  };

 public:
  DLManager() {}

  // Deletes every loaded parser and closes its library.
  ~DLManager() {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    for (auto& entry : handle_map_) {
      delete entry.second.parser;
      dlclose(entry.second.module);
    }
#endif
  }

  // Deletes the parser loaded from `name`, closes its library and forgets
  // it, so a later Load() opens the library again. Returns true when `name`
  // was closed or was not loaded, and false where dynamic loading is
  // unsupported.
  bool Close(const std::string& name) {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handle_map_.find(name);
    if (it == handle_map_.end()) {
      return true;
    }
    delete it->second.parser;
    dlclose(it->second.module);
    // Erase the entry. Keeping it would make Load()/ReLoad() hand out the
    // parser that was just deleted.
    handle_map_.erase(it);
    return true;
#else
    VLOG(0) << "Not implement in windows";
    return false;
#endif
  }

  // Returns the parser for library `name`, loading it and calling
  // Init(conf) on first use. Returns nullptr if the library, its factory
  // symbol, or the parser cannot be obtained.
  paddle::framework::CustomParser* Load(const std::string& name,
                                        const std::vector<SlotConf>& conf) {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handle_map_.find(name);
    if (it != handle_map_.end()) {
      return it->second.parser;
    }

    DLHandle handle;
    handle.module = dlopen(name.c_str(), RTLD_NOW);
    if (handle.module == nullptr) {
      VLOG(0) << "Create so of " << name << " fail, " << dlerror();
      return nullptr;
    }

    CreateParserObjectFunc create_parser_func =
        reinterpret_cast<CreateParserObjectFunc>(
            dlsym(handle.module, "CreateParserObject"));
    if (create_parser_func == nullptr) {
      const char* err = dlerror();
      VLOG(0) << "Cannot find CreateParserObject in " << name << ", "
              << (err != nullptr ? err : "unknown error");
      dlclose(handle.module);
      return nullptr;
    }
    handle.parser = create_parser_func();
    if (handle.parser == nullptr) {
      VLOG(0) << "CreateParserObject of " << name << " returned nullptr";
      dlclose(handle.module);
      return nullptr;
    }
    handle.parser->Init(conf);
    handle_map_.insert({name, handle});

    return handle.parser;
#else
    VLOG(0) << "Not implement in windows";
    return nullptr;
#endif
  }

  // Same as the SlotConf overload, but initializes the parser from
  // AllSlotInfo. Unlike that overload, a load failure terminates the process
  // with exit(-1).
  paddle::framework::CustomParser* Load(const std::string& name,
                                        const std::vector<AllSlotInfo>& conf) {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handle_map_.find(name);
    if (it != handle_map_.end()) {
      return it->second.parser;
    }

    DLHandle handle;
    handle.module = dlopen(name.c_str(), RTLD_NOW);
    if (handle.module == nullptr) {
      VLOG(0) << "Create so of " << name << " fail, " << dlerror();
      exit(-1);
    }

    CreateParserObjectFunc create_parser_func =
        reinterpret_cast<CreateParserObjectFunc>(
            dlsym(handle.module, "CreateParserObject"));
    if (create_parser_func == nullptr) {
      const char* err = dlerror();
      VLOG(0) << "Cannot find CreateParserObject in " << name << ", "
              << (err != nullptr ? err : "unknown error");
      exit(-1);
    }
    handle.parser = create_parser_func();
    if (handle.parser == nullptr) {
      VLOG(0) << "CreateParserObject of " << name << " returned nullptr";
      exit(-1);
    }
    handle.parser->Init(conf);
    handle_map_.insert({name, handle});

    return handle.parser;
#else
    VLOG(0) << "Not implement in windows";
    return nullptr;
#endif
  }

  // Closes `name` (if loaded) and loads it again with `conf`.
  paddle::framework::CustomParser* ReLoad(const std::string& name,
                                          const std::vector<SlotConf>& conf) {
    Close(name);
    return Load(name, conf);
  }

 private:
  std::mutex mutex_;
  std::map<std::string, DLHandle> handle_map_;
};

D
danleifeng 已提交
784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882
// Wraps a std::default_random_engine. On non-Windows builds the engine is
// seeded with three consecutive values from a process-wide atomic counter
// plus the current CLOCK_REALTIME time in milliseconds. Engines created at
// the same instant therefore still receive different seed inputs.
struct engine_wrapper_t {
  std::default_random_engine engine;
#if !defined(_WIN32)
  engine_wrapper_t() {
    static std::atomic<uint64_t> seed_counter(0);
    struct timespec now;
    clock_gettime(CLOCK_REALTIME, &now);
    const double now_sec = now.tv_sec + now.tv_nsec * 1e-9;
    const uint64_t first = seed_counter++;
    const uint64_t second = seed_counter++;
    const uint64_t third = seed_counter++;
    const uint64_t now_ms = static_cast<uint64_t>(now_sec * 1000);
    std::seed_seq seeds{first, second, third, now_ms};
    engine.seed(seeds);
  }
#endif
};

// Cursor state for walking `row_num` rows in batches of `batch_size`, and for
// iterating positions within a walk of length `walk_len`.
// For each position (central_word), a random sub-range window[left..right]
// of the offset table `window` is chosen. `step` then walks that range,
// keeping offsets that land inside [0, walk_len).
struct BufState {
  int left = 0;
  int right = 0;
  int central_word = -1;
  int step = -1;
  engine_wrapper_t random_engine_;

  int len = 0;
  int cursor = 0;
  int row_num = 0;

  int batch_size = 0;
  int walk_len = 0;
  std::vector<int>* window = nullptr;

  BufState() {}
  ~BufState() {}

  // Binds the batch size, walk length and offset table, and rewinds all
  // cursors. `graph_window` is stored by pointer and must outlive this state.
  void Init(int graph_batch_size,
            int graph_walk_len,
            std::vector<int>* graph_window) {
    batch_size = graph_batch_size;
    walk_len = graph_walk_len;
    window = graph_window;

    left = 0;
    right = static_cast<int>(window->size()) - 1;
    central_word = -1;
    step = -1;

    len = 0;
    cursor = 0;
    row_num = 0;
    for (size_t i = 0; i < graph_window->size(); i++) {
      VLOG(2) << "graph_window[" << i << "] = " << (*graph_window)[i];
    }
  }

  // Starts over on `total_rows` rows: the first batch covers up to
  // batch_size rows, and the first central word is selected.
  void Reset(int total_rows) {
    cursor = 0;
    row_num = total_rows;
    int tmp_len = cursor + batch_size > row_num ? row_num - cursor : batch_size;
    len = tmp_len;
    central_word = -1;
    step = -1;
    GetNextCentrolWord();
  }

  // Advances `step`. Returns 1 while it stays within [left, right] and the
  // offset keeps the position below walk_len; returns 0 otherwise.
  int GetNextStep() {
    step++;
    if (step <= right && central_word + (*window)[step] < walk_len) {
      return 1;
    }
    return 0;
  }

  void Debug() {
    VLOG(2) << "left: " << left << " right: " << right
            << " central_word: " << central_word << " step: " << step
            << " cursor: " << cursor << " len: " << len
            << " row_num: " << row_num;
  }

  // Moves to the next central word. Samples r in [1, window->size() / 2] and
  // sets [left, right] to the 2*r entries centred on index window->size()/2.
  // Then sets `step` to the first offset in range that lands at a
  // non-negative position. Returns 1 if one exists, and 0 when the walk is
  // exhausted or no offset qualifies.
  int GetNextCentrolWord() {
    if (++central_word >= walk_len) {
      return 0;
    }
    int window_size = window->size() / 2;
    if (window_size <= 0) {
      // Fewer than two window entries: nothing to sample, and
      // `% window_size` below would divide by zero. Leave an empty range
      // (left = 0, right = -1, step = 0) so GetNextStep() returns 0.
      left = 0;
      right = -1;
      step = 0;
      return 0;
    }
    int random_window = random_engine_.engine() % window_size + 1;
    left = window_size - random_window;
    right = window_size + random_window - 1;
    VLOG(2) << "random window: " << random_window << " window[" << left
            << "] = " << (*window)[left] << " window[" << right
            << "] = " << (*window)[right];

    for (step = left; step <= right; step++) {
      if (central_word + (*window)[step] >= 0) {
        return 1;
      }
    }
    return 0;
  }

  // Advances the cursor past the current batch. Returns 1 and selects the
  // first central word when rows remain; returns 0 once all rows are
  // consumed.
  int GetNextBatch() {
    cursor += len;
    if (row_num - cursor < 0) {
      return 0;
    }
    int tmp_len = cursor + batch_size > row_num ? row_num - cursor : batch_size;
    if (tmp_len == 0) {
      return 0;
    }
    len = tmp_len;
    central_word = -1;
    step = -1;
    GetNextCentrolWord();
    return 1;
  }
};

class GraphDataGenerator {
 public:
900 901
  GraphDataGenerator() {}
  virtual ~GraphDataGenerator() {}
D
danleifeng 已提交
902
  void SetConfig(const paddle::framework::DataFeedDesc& data_feed_desc);
L
lxsbupt 已提交
903 904 905
  void AllocResource(int thread_id, std::vector<phi::DenseTensor*> feed_vec);
  void AllocTrainResource(int thread_id);
  void SetFeedVec(std::vector<phi::DenseTensor*> feed_vec);
D
danleifeng 已提交
906 907
  int AcquireInstance(BufState* state);
  int GenerateBatch();
L
lxsbupt 已提交
908 909 910 911 912
  int FillWalkBuf();
  int FillWalkBufMultiPath();
  int FillInferBuf();
  void DoWalkandSage();
  int FillSlotFeature(uint64_t* d_walk);
D
danleifeng 已提交
913 914 915 916
  int FillFeatureBuf(uint64_t* d_walk, uint64_t* d_feature, size_t key_num);
  int FillFeatureBuf(std::shared_ptr<phi::Allocation> d_walk,
                     std::shared_ptr<phi::Allocation> d_feature);
  void FillOneStep(uint64_t* start_ids,
917
                   int etype_id,
D
danleifeng 已提交
918
                   uint64_t* walk,
919
                   uint8_t* walk_ntype,
D
danleifeng 已提交
920
                   int len,
921
                   NeighborSampleResult& sample_res,  // NOLINT
D
danleifeng 已提交
922 923 924
                   int cur_degree,
                   int step,
                   int* len_per_row);
L
lxsbupt 已提交
925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942
  int FillInsBuf(cudaStream_t stream);
  int FillIdShowClkTensor(int total_instance,
                          bool gpu_graph_training,
                          size_t cursor = 0);
  int FillGraphIdShowClkTensor(int uniq_instance,
                               int total_instance,
                               int index);
  int FillGraphSlotFeature(
      int total_instance,
      bool gpu_graph_training,
      std::shared_ptr<phi::Allocation> final_sage_nodes = nullptr);
  int FillSlotFeature(uint64_t* d_walk, size_t key_num);
  int MakeInsPair(cudaStream_t stream);
  uint64_t CopyUniqueNodes();
  int GetPathNum() { return total_row_; }
  void ResetPathNum() { total_row_ = 0; }
  void ResetEpochFinish() { epoch_finish_ = false; }
  void ClearSampleState();
943
  void DumpWalkPath(std::string dump_path, size_t dump_rate);
D
danleifeng 已提交
944
  void SetDeviceKeys(std::vector<uint64_t>* device_keys, int type) {
L
lxsbupt 已提交
945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971
    // type_to_index_[type] = h_device_keys_.size();
    // h_device_keys_.push_back(device_keys);
  }

  std::vector<std::shared_ptr<phi::Allocation>> SampleNeighbors(
      int64_t* uniq_nodes,
      int len,
      int sample_size,
      std::vector<int>& edges_split_num,  // NOLINT
      int64_t* neighbor_len);
  std::shared_ptr<phi::Allocation> FillReindexHashTable(int64_t* input,
                                                        int num_input,
                                                        int64_t len_hashtable,
                                                        int64_t* keys,
                                                        int* values,
                                                        int* key_index,
                                                        int* final_nodes_len);
  std::shared_ptr<phi::Allocation> GetReindexResult(int64_t* reindex_src_data,
                                                    int64_t* center_nodes,
                                                    int* final_nodes_len,
                                                    int node_len,
                                                    int64_t neighbor_len);
  std::shared_ptr<phi::Allocation> GenerateSampleGraph(
      uint64_t* node_ids,
      int len,
      int* uniq_len,
      std::shared_ptr<phi::Allocation>& inverse);  // NOLINT
972
  std::shared_ptr<phi::Allocation> GetNodeDegree(uint64_t* node_ids, int len);
L
lxsbupt 已提交
973 974 975 976 977 978
  int InsertTable(const uint64_t* d_keys,
                  uint64_t len,
                  std::shared_ptr<phi::Allocation> d_uniq_node_num);
  std::vector<uint64_t>& GetHostVec() { return host_vec_; }
  bool get_epoch_finish() { return epoch_finish_; }
  void clear_gpu_mem();
D
danleifeng 已提交
979 980

 protected:
L
lxsbupt 已提交
981
  HashTable<uint64_t, uint64_t>* table_;
D
danleifeng 已提交
982 983 984 985 986 987
  int walk_degree_;
  int walk_len_;
  int window_;
  int once_sample_startid_len_;
  int gpuid_;
  size_t cursor_;
L
lxsbupt 已提交
988
  int thread_id_;
D
danleifeng 已提交
989
  size_t jump_rows_;
L
lxsbupt 已提交
990
  int edge_to_id_len_;
D
danleifeng 已提交
991
  int64_t* id_tensor_ptr_;
L
lxsbupt 已提交
992
  int* index_tensor_ptr_;
D
danleifeng 已提交
993 994
  int64_t* show_tensor_ptr_;
  int64_t* clk_tensor_ptr_;
995
  int* degree_tensor_ptr_;
L
lxsbupt 已提交
996 997 998

  cudaStream_t train_stream_;
  cudaStream_t sample_stream_;
D
danleifeng 已提交
999
  paddle::platform::Place place_;
1000
  std::vector<phi::DenseTensor*> feed_vec_;
D
danleifeng 已提交
1001 1002 1003
  std::vector<size_t> offset_;
  std::shared_ptr<phi::Allocation> d_prefix_sum_;
  std::vector<std::shared_ptr<phi::Allocation>> d_device_keys_;
L
lxsbupt 已提交
1004
  std::shared_ptr<phi::Allocation> d_train_metapath_keys_;
D
danleifeng 已提交
1005 1006

  std::shared_ptr<phi::Allocation> d_walk_;
1007 1008
  std::shared_ptr<phi::Allocation> d_walk_ntype_;
  std::shared_ptr<phi::Allocation> d_excluded_train_pair_;
L
lxsbupt 已提交
1009
  std::shared_ptr<phi::Allocation> d_feature_list_;
D
danleifeng 已提交
1010 1011 1012
  std::shared_ptr<phi::Allocation> d_feature_;
  std::shared_ptr<phi::Allocation> d_len_per_row_;
  std::shared_ptr<phi::Allocation> d_random_row_;
L
lxsbupt 已提交
1013 1014 1015 1016 1017
  std::shared_ptr<phi::Allocation> d_uniq_node_num_;
  std::shared_ptr<phi::Allocation> d_slot_feature_num_map_;
  std::shared_ptr<phi::Allocation> d_actual_slot_id_map_;
  std::shared_ptr<phi::Allocation> d_fea_offset_map_;

D
danleifeng 已提交
1018 1019 1020 1021 1022 1023 1024
  std::vector<std::shared_ptr<phi::Allocation>> d_sampleidx2rows_;
  int cur_sampleidx2row_;
  // record the keys to call graph_neighbor_sample
  std::shared_ptr<phi::Allocation> d_sample_keys_;
  int sample_keys_len_;

  std::shared_ptr<phi::Allocation> d_ins_buf_;
L
lxsbupt 已提交
1025 1026
  std::shared_ptr<phi::Allocation> d_feature_size_list_buf_;
  std::shared_ptr<phi::Allocation> d_feature_size_prefixsum_buf_;
D
danleifeng 已提交
1027 1028 1029
  std::shared_ptr<phi::Allocation> d_pair_num_;
  std::shared_ptr<phi::Allocation> d_slot_tensor_ptr_;
  std::shared_ptr<phi::Allocation> d_slot_lod_tensor_ptr_;
L
lxsbupt 已提交
1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042
  std::shared_ptr<phi::Allocation> d_reindex_table_key_;
  std::shared_ptr<phi::Allocation> d_reindex_table_value_;
  std::shared_ptr<phi::Allocation> d_reindex_table_index_;
  std::vector<std::shared_ptr<phi::Allocation>> edge_type_graph_;
  std::shared_ptr<phi::Allocation> d_sorted_keys_;
  std::shared_ptr<phi::Allocation> d_sorted_idx_;
  std::shared_ptr<phi::Allocation> d_offset_;
  std::shared_ptr<phi::Allocation> d_merged_cnts_;
  std::shared_ptr<phi::Allocation> d_buf_;

  // sage mode batch data
  std::vector<std::shared_ptr<phi::Allocation>> inverse_vec_;
  std::vector<std::shared_ptr<phi::Allocation>> final_sage_nodes_vec_;
1043
  std::vector<std::shared_ptr<phi::Allocation>> node_degree_vec_;
L
lxsbupt 已提交
1044 1045 1046 1047 1048
  std::vector<int> uniq_instance_vec_;
  std::vector<int> total_instance_vec_;
  std::vector<std::vector<std::shared_ptr<phi::Allocation>>> graph_edges_vec_;
  std::vector<std::vector<std::vector<int>>> edges_split_num_vec_;

1049
  int excluded_train_pair_len_;
L
lxsbupt 已提交
1050 1051 1052
  int64_t reindex_table_size_;
  int sage_batch_count_;
  int sage_batch_num_;
D
danleifeng 已提交
1053
  int ins_buf_pair_len_;
L
lxsbupt 已提交
1054

D
danleifeng 已提交
1055 1056 1057 1058 1059 1060 1061
  // size of a d_walk buf
  size_t buf_size_;
  int repeat_time_;
  std::vector<int> window_step_;
  BufState buf_state_;
  int batch_size_;
  int slot_num_;
L
lxsbupt 已提交
1062 1063
  std::vector<int> h_slot_feature_num_map_;
  int fea_num_per_node_;
D
danleifeng 已提交
1064 1065 1066
  int shuffle_seed_;
  int debug_mode_;
  bool gpu_graph_training_;
L
lxsbupt 已提交
1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078
  bool sage_mode_;
  std::vector<int> samples_;
  bool epoch_finish_;
  std::vector<uint64_t> host_vec_;
  std::vector<uint64_t> h_device_keys_len_;
  uint64_t h_train_metapath_keys_len_;
  uint64_t train_table_cap_;
  uint64_t infer_table_cap_;
  uint64_t copy_unique_len_;
  int total_row_;
  size_t infer_node_start_;
  size_t infer_node_end_;
1079 1080 1081
  std::set<int> infer_node_type_index_set_;
  std::string infer_node_type_;
  bool get_degree_;
D
danleifeng 已提交
1082 1083
};

W
Wang Guibao 已提交
1084 1085
class DataFeed {
 public:
1086 1087 1088
  DataFeed() {
    mutex_for_pick_file_ = nullptr;
    file_idx_ = nullptr;
H
hutuxian 已提交
1089 1090
    mutex_for_fea_num_ = nullptr;
    total_fea_num_ = nullptr;
1091
  }
W
Wang Guibao 已提交
1092
  virtual ~DataFeed() {}
H
hutuxian 已提交
1093
  virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
W
Wang Guibao 已提交
1094
  virtual bool CheckFile(const char* filename) {
1095 1096
    PADDLE_THROW(platform::errors::Unimplemented(
        "This function(CheckFile) is not implemented."));
W
Wang Guibao 已提交
1097 1098 1099 1100 1101 1102
  }
  // Set filelist for DataFeed.
  // Pay attention that it must init all readers before call this function.
  // Otherwise, Init() function will init finish_set_filelist_ flag.
  virtual bool SetFileList(const std::vector<std::string>& files);
  virtual bool Start() = 0;
D
dongdaxiang 已提交
1103

W
Wang Guibao 已提交
1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118
  // The trainer calls the Next() function, and the DataFeed will load a new
  // batch to the feed_vec. The return value of this function is the batch
  // size of the current batch.
  virtual int Next() = 0;
  // Get all slots' alias which defined in protofile
  virtual const std::vector<std::string>& GetAllSlotAlias() {
    return all_slots_;
  }
  // Get used slots' alias which defined in protofile
  virtual const std::vector<std::string>& GetUseSlotAlias() {
    return use_slots_;
  }
  // This function is used for binding feed_vec memory
  virtual void AddFeedVar(Variable* var, const std::string& name);

H
hutuxian 已提交
1119 1120 1121
  // This function is used for binding feed_vec memory in a given scope
  virtual void AssignFeedVar(const Scope& scope);

1122 1123 1124 1125 1126 1127 1128
  // This function will do nothing at default
  virtual void SetInputPvChannel(void* channel) {}
  // This function will do nothing at default
  virtual void SetOutputPvChannel(void* channel) {}
  // This function will do nothing at default
  virtual void SetConsumePvChannel(void* channel) {}

1129
  // This function will do nothing at default
J
jiaqi 已提交
1130 1131 1132
  virtual void SetInputChannel(void* channel) {}
  // This function will do nothing at default
  virtual void SetOutputChannel(void* channel) {}
1133
  // This function will do nothing at default
J
jiaqi 已提交
1134
  virtual void SetConsumeChannel(void* channel) {}
1135
  // This function will do nothing at default
1136
  virtual void SetThreadId(int thread_id) {}
1137
  // This function will do nothing at default
1138
  virtual void SetThreadNum(int thread_num) {}
1139 1140
  // This function will do nothing at default
  virtual void SetParseInsId(bool parse_ins_id) {}
1141
  virtual void SetParseUid(bool parse_uid) {}
1142
  virtual void SetParseContent(bool parse_content) {}
1143 1144 1145
  virtual void SetParseLogKey(bool parse_logkey) {}
  virtual void SetEnablePvMerge(bool enable_pv_merge) {}
  virtual void SetCurrentPhase(int current_phase) {}
D
danleifeng 已提交
1146
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
L
lxsbupt 已提交
1147 1148 1149
  virtual void InitGraphResource() {}
  virtual void InitGraphTrainResource() {}
  virtual void SetDeviceKeys(std::vector<uint64_t>* device_keys, int type) {
D
danleifeng 已提交
1150 1151
    gpu_graph_data_generator_.SetDeviceKeys(device_keys, type);
  }
L
lxsbupt 已提交
1152 1153
#endif

D
danleifeng 已提交
1154 1155 1156
  virtual void SetGpuGraphMode(int gpu_graph_mode) {
    gpu_graph_mode_ = gpu_graph_mode;
  }
1157 1158 1159
  virtual void SetFileListMutex(std::mutex* mutex) {
    mutex_for_pick_file_ = mutex;
  }
H
hutuxian 已提交
1160
  virtual void SetFeaNumMutex(std::mutex* mutex) { mutex_for_fea_num_ = mutex; }
1161
  virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
H
hutuxian 已提交
1162
  virtual void SetFeaNum(uint64_t* fea_num) { total_fea_num_ = fea_num; }
1163 1164 1165 1166 1167 1168 1169
  virtual const std::vector<std::string>& GetInsIdVec() const {
    return ins_id_vec_;
  }
  virtual const std::vector<std::string>& GetInsContentVec() const {
    return ins_content_vec_;
  }
  virtual int GetCurBatchSize() { return batch_size_; }
L
lxsbupt 已提交
1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205
  virtual int GetGraphPathNum() {
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
    return gpu_graph_data_generator_.GetPathNum();
#else
    return 0;
#endif
  }

#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  virtual const std::vector<uint64_t>* GetHostVec() {
    return &(gpu_graph_data_generator_.GetHostVec());
  }

  virtual void clear_gpu_mem() { gpu_graph_data_generator_.clear_gpu_mem(); }

  virtual bool get_epoch_finish() {
    return gpu_graph_data_generator_.get_epoch_finish();
  }

  virtual void ResetPathNum() { gpu_graph_data_generator_.ResetPathNum(); }

  virtual void ClearSampleState() {
    gpu_graph_data_generator_.ClearSampleState();
  }

  virtual void ResetEpochFinish() {
    gpu_graph_data_generator_.ResetEpochFinish();
  }

  virtual void DoWalkandSage() {
    PADDLE_THROW(platform::errors::Unimplemented(
        "This function(DoWalkandSage) is not implemented."));
  }
#endif

  virtual bool IsTrainMode() { return train_mode_; }
1206
  virtual void LoadIntoMemory() {
1207 1208
    PADDLE_THROW(platform::errors::Unimplemented(
        "This function(LoadIntoMemory) is not implemented."));
1209
  }
1210 1211 1212 1213
  virtual void SetPlace(const paddle::platform::Place& place) {
    place_ = place;
  }
  virtual const paddle::platform::Place& GetPlace() const { return place_; }
1214

1215 1216 1217 1218 1219
  virtual void DumpWalkPath(std::string dump_path, size_t dump_rate) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "This function(DumpWalkPath) is not implemented."));
  }

W
Wang Guibao 已提交
1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231
 protected:
  // The following three functions are used to check if it is executed in this
  // order:
  //   Init() -> SetFileList() -> Start() -> Next()
  virtual void CheckInit();
  virtual void CheckSetFileList();
  virtual void CheckStart();
  virtual void SetBatchSize(
      int batch);  // batch size will be set in Init() function
  // This function is used to pick one file from the global filelist(thread
  // safe).
  virtual bool PickOneFile(std::string* filename);
1232
  virtual void CopyToFeedTensor(void* dst, const void* src, size_t size);
W
Wang Guibao 已提交
1233

1234 1235 1236
  std::vector<std::string> filelist_;
  size_t* file_idx_;
  std::mutex* mutex_for_pick_file_;
H
hutuxian 已提交
1237 1238 1239
  std::mutex* mutex_for_fea_num_ = nullptr;
  uint64_t* total_fea_num_ = nullptr;
  uint64_t fea_num_ = 0;
W
Wang Guibao 已提交
1240 1241 1242 1243 1244 1245 1246 1247 1248 1249

  // the alias of used slots, and its order is determined by
  // data_feed_desc(proto object)
  std::vector<std::string> use_slots_;
  std::vector<bool> use_slots_is_dense_;

  // the alias of all slots, and its order is determined by data_feed_desc(proto
  // object)
  std::vector<std::string> all_slots_;
  std::vector<std::string> all_slots_type_;
1250
  std::vector<std::vector<int>> use_slots_shape_;
1251 1252
  std::vector<int> inductive_shape_index_;
  std::vector<int> total_dims_without_inductive_;
H
hutuxian 已提交
1253 1254
  // For the inductive shape passed within data
  std::vector<std::vector<int>> multi_inductive_shape_index_;
W
Wang Guibao 已提交
1255 1256 1257 1258
  std::vector<int>
      use_slots_index_;  // -1: not used; >=0: the index of use_slots_

  // The data read by DataFeed will be stored here
1259
  std::vector<phi::DenseTensor*> feed_vec_;
W
Wang Guibao 已提交
1260

1261
  phi::DenseTensor* rank_offset_;
1262

W
Wang Guibao 已提交
1263 1264 1265 1266 1267 1268
  // the batch size defined by user
  int default_batch_size_;
  // current batch size
  int batch_size_;

  bool finish_init_;
1269
  bool finish_set_filelist_;
W
Wang Guibao 已提交
1270
  bool finish_start_;
1271
  std::string pipe_command_;
T
Thunderbrook 已提交
1272 1273
  std::string so_parser_name_;
  std::vector<SlotConf> slot_conf_;
1274 1275
  std::vector<std::string> ins_id_vec_;
  std::vector<std::string> ins_content_vec_;
1276
  platform::Place place_;
1277
  std::string uid_slot_;
1278 1279 1280

  // The input type of pipe reader, 0 for one sample, 1 for one batch
  int input_type_;
D
danleifeng 已提交
1281 1282 1283 1284
  int gpu_graph_mode_ = 0;
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  GraphDataGenerator gpu_graph_data_generator_;
#endif
L
lxsbupt 已提交
1285
  bool train_mode_;
W
Wang Guibao 已提交
1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306
};

// PrivateQueueDataFeed is the base virtual class for ohther DataFeeds.
// It use a read-thread to read file and parse data to a private-queue
// (thread level), and get data from this queue when trainer call Next().
template <typename T>
class PrivateQueueDataFeed : public DataFeed {
 public:
  PrivateQueueDataFeed() {}
  virtual ~PrivateQueueDataFeed() {}
  virtual bool Start();
  virtual int Next();

 protected:
  // The thread implementation function for reading file and parse.
  virtual void ReadThread();
  // This function is used to set private-queue size, and the most
  // efficient when the queue size is close to the batch size.
  virtual void SetQueueSize(int queue_size);
  // The reading and parsing method called in the ReadThread.
  virtual bool ParseOneInstance(T* instance) = 0;
D
dongdaxiang 已提交
1307
  virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
W
Wang Guibao 已提交
1308
  // This function is used to put instance to vec_ins
1309 1310
  virtual void AddInstanceToInsVec(T* vec_ins,
                                   const T& instance,
W
Wang Guibao 已提交
1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322
                                   int index) = 0;
  // This function is used to put ins_vec to feed_vec
  virtual void PutToFeedVec(const T& ins_vec) = 0;

  // The thread for read files
  std::thread read_thread_;
  // using ifstream one line and one line parse is faster
  // than using fread one buffer and one buffer parse.
  //   for a 601M real data:
  //     ifstream one line and one line parse: 6034 ms
  //     fread one buffer and one buffer parse: 7097 ms
  std::ifstream file_;
D
dongdaxiang 已提交
1323
  std::shared_ptr<FILE> fp_;
W
Wang Guibao 已提交
1324
  size_t queue_size_;
1325
  string::LineFileReader reader_;
W
Wang Guibao 已提交
1326
  // The queue for store parsed data
1327
  std::shared_ptr<paddle::framework::ChannelObject<T>> queue_;
W
Wang Guibao 已提交
1328 1329
};

1330
template <typename T>
J
jiaqi 已提交
1331
class InMemoryDataFeed : public DataFeed {
1332 1333 1334
 public:
  InMemoryDataFeed();
  virtual ~InMemoryDataFeed() {}
H
hutuxian 已提交
1335
  virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
1336 1337
  virtual bool Start();
  virtual int Next();
1338 1339 1340 1341
  virtual void SetInputPvChannel(void* channel);
  virtual void SetOutputPvChannel(void* channel);
  virtual void SetConsumePvChannel(void* channel);

J
jiaqi 已提交
1342 1343 1344
  virtual void SetInputChannel(void* channel);
  virtual void SetOutputChannel(void* channel);
  virtual void SetConsumeChannel(void* channel);
1345 1346
  virtual void SetThreadId(int thread_id);
  virtual void SetThreadNum(int thread_num);
1347
  virtual void SetParseInsId(bool parse_ins_id);
1348
  virtual void SetParseUid(bool parse_uid);
1349
  virtual void SetParseContent(bool parse_content);
1350 1351 1352
  virtual void SetParseLogKey(bool parse_logkey);
  virtual void SetEnablePvMerge(bool enable_pv_merge);
  virtual void SetCurrentPhase(int current_phase);
1353
  virtual void LoadIntoMemory();
T
Thunderbrook 已提交
1354
  virtual void LoadIntoMemoryFromSo();
Y
yaoxuefeng 已提交
1355 1356 1357 1358 1359
  virtual void SetRecord(T* records) { records_ = records; }
  int GetDefaultBatchSize() { return default_batch_size_; }
  void AddBatchOffset(const std::pair<int, int>& offset) {
    batch_offsets_.push_back(offset);
  }
X
xujiaqi01 已提交
1360

1361 1362 1363
 protected:
  virtual bool ParseOneInstance(T* instance) = 0;
  virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
1364 1365
  virtual void ParseOneInstanceFromSo(const char* str,
                                      T* instance,
T
Thunderbrook 已提交
1366
                                      CustomParser* parser) {}
1367 1368
  virtual int ParseInstanceFromSo(int len,
                                  const char* str,
T
Thunderbrook 已提交
1369 1370 1371 1372
                                  std::vector<T>* instances,
                                  CustomParser* parser) {
    return 0;
  }
J
jiaqi 已提交
1373
  virtual void PutToFeedVec(const std::vector<T>& ins_vec) = 0;
Y
yaoxuefeng 已提交
1374
  virtual void PutToFeedVec(const T* ins_vec, int num) = 0;
1375

Y
yaoxuefeng 已提交
1376 1377 1378 1379 1380
  std::vector<std::vector<float>> batch_float_feasigns_;
  std::vector<std::vector<uint64_t>> batch_uint64_feasigns_;
  std::vector<std::vector<size_t>> offset_;
  std::vector<bool> visit_;

1381 1382
  int thread_id_;
  int thread_num_;
1383
  bool parse_ins_id_;
1384
  bool parse_uid_;
1385
  bool parse_content_;
1386 1387 1388
  bool parse_logkey_;
  bool enable_pv_merge_;
  int current_phase_{-1};  // only for untest
J
jiaqi 已提交
1389 1390 1391 1392 1393
  std::ifstream file_;
  std::shared_ptr<FILE> fp_;
  paddle::framework::ChannelObject<T>* input_channel_;
  paddle::framework::ChannelObject<T>* output_channel_;
  paddle::framework::ChannelObject<T>* consume_channel_;
1394 1395 1396 1397

  paddle::framework::ChannelObject<PvInstance>* input_pv_channel_;
  paddle::framework::ChannelObject<PvInstance>* output_pv_channel_;
  paddle::framework::ChannelObject<PvInstance>* consume_pv_channel_;
Y
yaoxuefeng 已提交
1398 1399 1400 1401 1402

  std::vector<std::pair<int, int>> batch_offsets_;
  uint64_t offset_index_ = 0;
  bool enable_heterps_ = false;
  T* records_ = nullptr;
1403 1404
};

W
Wang Guibao 已提交
1405 1406 1407 1408 1409
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
class MultiSlotType {
 public:
  MultiSlotType() {}
  ~MultiSlotType() {}
H
hutuxian 已提交
1410
  void Init(const std::string& type, size_t reserved_size = 0) {
W
Wang Guibao 已提交
1411 1412 1413
    CheckType(type);
    if (type_[0] == 'f') {
      float_feasign_.clear();
H
hutuxian 已提交
1414 1415 1416
      if (reserved_size) {
        float_feasign_.reserve(reserved_size);
      }
W
Wang Guibao 已提交
1417 1418
    } else if (type_[0] == 'u') {
      uint64_feasign_.clear();
H
hutuxian 已提交
1419 1420 1421
      if (reserved_size) {
        uint64_feasign_.reserve(reserved_size);
      }
W
Wang Guibao 已提交
1422 1423 1424
    }
    type_ = type;
  }
H
hutuxian 已提交
1425 1426 1427 1428
  void InitOffset(size_t max_batch_size = 0) {
    if (max_batch_size > 0) {
      offset_.reserve(max_batch_size + 1);
    }
W
Wang Guibao 已提交
1429 1430 1431 1432 1433 1434
    offset_.resize(1);
    // LoDTensor' lod is counted from 0, the size of lod
    // is one size larger than the size of data.
    offset_[0] = 0;
  }
  const std::vector<size_t>& GetOffset() const { return offset_; }
1435
  std::vector<size_t>& MutableOffset() { return offset_; }
W
Wang Guibao 已提交
1436 1437 1438 1439 1440 1441 1442 1443
  void AddValue(const float v) {
    CheckFloat();
    float_feasign_.push_back(v);
  }
  void AddValue(const uint64_t v) {
    CheckUint64();
    uint64_feasign_.push_back(v);
  }
H
hutuxian 已提交
1444 1445 1446 1447 1448 1449 1450 1451 1452 1453
  void CopyValues(const float* input, size_t size) {
    CheckFloat();
    float_feasign_.resize(size);
    memcpy(float_feasign_.data(), input, size * sizeof(float));
  }
  void CopyValues(const uint64_t* input, size_t size) {
    CheckUint64();
    uint64_feasign_.resize(size);
    memcpy(uint64_feasign_.data(), input, size * sizeof(uint64_t));
  }
W
Wang Guibao 已提交
1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466
  void AddIns(const MultiSlotType& ins) {
    if (ins.GetType()[0] == 'f') {  // float
      CheckFloat();
      auto& vec = ins.GetFloatData();
      offset_.push_back(offset_.back() + vec.size());
      float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end());
    } else if (ins.GetType()[0] == 'u') {  // uint64
      CheckUint64();
      auto& vec = ins.GetUint64Data();
      offset_.push_back(offset_.back() + vec.size());
      uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
    }
  }
H
hutuxian 已提交
1467 1468 1469 1470 1471 1472 1473 1474
  void AppendValues(const uint64_t* input, size_t size) {
    CheckUint64();
    offset_.push_back(offset_.back() + size);
    uint64_feasign_.insert(uint64_feasign_.end(), input, input + size);
  }
  void AppendValues(const float* input, size_t size) {
    CheckFloat();
    offset_.push_back(offset_.back() + size);
1475

H
hutuxian 已提交
1476 1477
    float_feasign_.insert(float_feasign_.end(), input, input + size);
  }
W
Wang Guibao 已提交
1478
  const std::vector<float>& GetFloatData() const { return float_feasign_; }
1479
  std::vector<float>& MutableFloatData() { return float_feasign_; }
W
Wang Guibao 已提交
1480
  const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
1481
  std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
W
Wang Guibao 已提交
1482
  const std::string& GetType() const { return type_; }
H
hutuxian 已提交
1483
  size_t GetBatchSize() { return offset_.size() - 1; }
1484
  std::string& MutableType() { return type_; }
W
Wang Guibao 已提交
1485

X
xujiaqi01 已提交
1486 1487
  std::string DebugString() {
    std::stringstream ss;
W
wanghuancoder 已提交
1488

1489 1490
    ss << "\ntype: " << type_ << "\n";
    ss << "offset: ";
X
xujiaqi01 已提交
1491 1492 1493 1494
    ss << "[";
    for (const size_t& i : offset_) {
      ss << offset_[i] << ",";
    }
1495
    ss << "]\ndata: [";
X
xujiaqi01 已提交
1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508
    if (type_[0] == 'f') {
      for (const float& i : float_feasign_) {
        ss << i << ",";
      }
    } else {
      for (const uint64_t& i : uint64_feasign_) {
        ss << i << ",";
      }
    }
    ss << "]\n";
    return ss.str();
  }

W
Wang Guibao 已提交
1509 1510
 private:
  void CheckType(const std::string& type) const {
1511 1512
    PADDLE_ENFORCE_EQ((type == "uint64" || type == "float"),
                      true,
1513 1514 1515 1516
                      platform::errors::InvalidArgument(
                          "MultiSlotType error, expect type is uint64 or "
                          "float, but received type is %s.",
                          type));
W
Wang Guibao 已提交
1517 1518
  }
  void CheckFloat() const {
1519
    PADDLE_ENFORCE_EQ(
1520 1521
        type_[0],
        'f',
1522 1523
        platform::errors::InvalidArgument(
            "MultiSlotType error, add %s value to float slot.", type_));
W
Wang Guibao 已提交
1524 1525
  }
  void CheckUint64() const {
1526
    PADDLE_ENFORCE_EQ(
1527 1528
        type_[0],
        'u',
1529 1530
        platform::errors::InvalidArgument(
            "MultiSlotType error, add %s value to uint64 slot.", type_));
W
Wang Guibao 已提交
1531 1532 1533 1534 1535 1536 1537
  }
  std::vector<float> float_feasign_;
  std::vector<uint64_t> uint64_feasign_;
  std::string type_;
  std::vector<size_t> offset_;
};

J
jiaqi 已提交
1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567
// Serializes a MultiSlotType in this order: type string, offsets, float
// values, uint64 values.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const MultiSlotType& ins) {
  ar << ins.GetType();
#ifdef _LINUX
  ar << ins.GetOffset();
#else
  // Outside _LINUX builds the offsets are written element by element as
  // uint64_t: first the count, then each value.
  const std::vector<size_t>& offsets = ins.GetOffset();
  ar << static_cast<uint64_t>(offsets.size());
  for (size_t i = 0; i < offsets.size(); ++i) {
    ar << static_cast<uint64_t>(offsets[i]);
  }
#endif
  ar << ins.GetFloatData();
  ar << ins.GetUint64Data();
  return ar;
}

// Deserializes a MultiSlotType in this order: type string, offsets, float
// values, uint64 values.
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           MultiSlotType& ins) {
  ar >> ins.MutableType();
#ifdef _LINUX
  ar >> ins.MutableOffset();
#else
  // Outside _LINUX builds: a uint64_t count, then each offset as uint64_t.
  std::vector<size_t>& offsets = ins.MutableOffset();
  offsets.resize(ar.template Get<uint64_t>());
  for (size_t i = 0; i < offsets.size(); ++i) {
    uint64_t value;
    ar >> value;
    offsets[i] = static_cast<size_t>(value);
  }
#endif
  ar >> ins.MutableFloatData();
  ar >> ins.MutableUint64Data();
  return ar;
}

1576 1577
struct RecordCandidate {
  std::string ins_id_;
T
Thunderbrook 已提交
1578
  std::unordered_multimap<uint16_t, FeatureFeasign> feas_;
1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590
  size_t shadow_index_ = -1;  // Optimization for Reservoir Sample

  RecordCandidate() {}
  RecordCandidate(const Record& rec,
                  const std::unordered_set<uint16_t>& slot_index_to_replace) {
    for (const auto& fea : rec.uint64_feasigns_) {
      if (slot_index_to_replace.find(fea.slot()) !=
          slot_index_to_replace.end()) {
        feas_.insert({fea.slot(), fea.sign()});
      }
    }
  }
1591 1592

  RecordCandidate& operator=(const Record& rec) {
1593
    feas_.clear();
1594 1595
    ins_id_ = rec.ins_id_;
    for (auto& fea : rec.uint64_feasigns_) {
1596
      feas_.insert({fea.slot(), fea.sign()});
1597 1598 1599 1600 1601 1602 1603 1604
    }
    return *this;
  }
};

class RecordCandidateList {
 public:
  RecordCandidateList() = default;
1605
  RecordCandidateList(const RecordCandidateList&) {}
1606

1607
  size_t Size() { return cur_size_; }
1608 1609 1610
  void ReSize(size_t length);

  void ReInit();
1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622
  void ReInitPass() {
    for (size_t i = 0; i < cur_size_; ++i) {
      if (candidate_list_[i].shadow_index_ != i) {
        candidate_list_[i].ins_id_ =
            candidate_list_[candidate_list_[i].shadow_index_].ins_id_;
        candidate_list_[i].feas_.swap(
            candidate_list_[candidate_list_[i].shadow_index_].feas_);
        candidate_list_[i].shadow_index_ = i;
      }
    }
    candidate_list_.resize(cur_size_);
  }
1623 1624

  void AddAndGet(const Record& record, RecordCandidate* result);
1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646
  void AddAndGet(const Record& record, size_t& index_result) {  // NOLINT
    // std::unique_lock<std::mutex> lock(mutex_);
    size_t index = 0;
    ++total_size_;
    auto fleet_ptr = FleetWrapper::GetInstance();
    if (!full_) {
      candidate_list_.emplace_back(record, slot_index_to_replace_);
      candidate_list_.back().shadow_index_ = cur_size_;
      ++cur_size_;
      full_ = (cur_size_ == capacity_);
    } else {
      index = fleet_ptr->LocalRandomEngine()() % total_size_;
      if (index < capacity_) {
        candidate_list_.emplace_back(record, slot_index_to_replace_);
        candidate_list_[index].shadow_index_ = candidate_list_.size() - 1;
      }
    }
    index = fleet_ptr->LocalRandomEngine()() % cur_size_;
    index_result = candidate_list_[index].shadow_index_;
  }
  const RecordCandidate& Get(size_t index) const {
    PADDLE_ENFORCE_LT(
1647 1648
        index,
        candidate_list_.size(),
1649 1650
        platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
                                     "elements in candidate_list[%lu].",
1651 1652
                                     index,
                                     candidate_list_.size()));
1653 1654 1655 1656 1657 1658
    return candidate_list_[index];
  }
  void SetSlotIndexToReplace(
      const std::unordered_set<uint16_t>& slot_index_to_replace) {
    slot_index_to_replace_ = slot_index_to_replace;
  }
1659 1660

 private:
1661 1662 1663 1664 1665 1666 1667
  size_t capacity_ = 0;
  std::mutex mutex_;
  bool full_ = false;
  size_t cur_size_ = 0;
  size_t total_size_ = 0;
  std::vector<RecordCandidate> candidate_list_;
  std::unordered_set<uint16_t> slot_index_to_replace_;
1668 1669
};

J
jiaqi 已提交
1670 1671
// Serializes a FeatureFeasign by writing its uint64_feasign_ field followed
// by its float_feasign_ field.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const FeatureFeasign& fk) {
  ar << fk.uint64_feasign_;
  ar << fk.float_feasign_;
  return ar;
}

// Deserializes a FeatureFeasign by reading uint64_feasign_ and then
// float_feasign_, in the same order the FeatureFeasign operator<< writes them.
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           FeatureFeasign& fk) {
  ar >> fk.uint64_feasign_;
  ar >> fk.float_feasign_;
  return ar;
}

// Writes a FeatureItem to the archive: its feasign first, then its slot.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const FeatureItem& fi) {
  const auto& feasign = fi.sign();
  ar << feasign;
  const auto& slot_id = fi.slot();
  ar << slot_id;
  return ar;
}

// Reads a FeatureItem from the archive into the references returned by
// sign() and slot(), in that order.
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           FeatureItem& fi) {
  auto& feasign = fi.sign();
  ar >> feasign;
  auto& slot_id = fi.slot();
  ar >> slot_id;
  return ar;
}

// Writes a Record as uint64_feasigns_, float_feasigns_, ins_id_ (this field
// order must stay in sync with the Record operator>>).
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const Record& r) {
  const auto& uint64_part = r.uint64_feasigns_;
  ar << uint64_part;
  const auto& float_part = r.float_feasigns_;
  ar << float_part;
  const auto& instance_id = r.ins_id_;
  ar << instance_id;
  return ar;
}

// Reads a Record back as uint64_feasigns_, float_feasigns_, ins_id_ (the same
// field order used by the Record operator<<).
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           Record& r) {
  auto& uint64_part = r.uint64_feasigns_;
  auto& float_part = r.float_feasigns_;
  auto& instance_id = r.ins_id_;
  ar >> uint64_part;
  ar >> float_part;
  ar >> instance_id;
  return ar;
}

W
Wang Guibao 已提交
1720 1721 1722 1723 1724 1725 1726 1727
// This DataFeed is used to feed multi-slot type data.
// The format of multi-slot type data:
//   [n feasign_0 feasign_1 ... feasign_{n-1}]*
// i.e. a count n followed by n feasigns; the bracketed group repeats
// (presumably once per slot).
class MultiSlotDataFeed
    : public PrivateQueueDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotDataFeed() {}
  virtual ~MultiSlotDataFeed() {}
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual bool CheckFile(const char* filename);

 protected:
  virtual void ReadThread();
  virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
                                   const std::vector<MultiSlotType>& instance,
                                   int index);
  virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
  virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
  virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
};
1740

J
jiaqi 已提交
1741
// In-memory data feed whose instances are Record objects.
class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
 public:
  MultiSlotInMemoryDataFeed() {}
  virtual ~MultiSlotInMemoryDataFeed() {}
  virtual void Init(const DataFeedDesc& data_feed_desc);
  // void SetRecord(Record* records) { records_ = records; }

 protected:
  virtual bool ParseOneInstance(Record* instance);
  virtual bool ParseOneInstanceFromPipe(Record* instance);
  // Intentionally a no-op in this class.
  virtual void ParseOneInstanceFromSo(const char* str,
                                      Record* instance,
                                      CustomParser* parser) {}
  virtual int ParseInstanceFromSo(int len,
                                  const char* str,
                                  std::vector<Record>* instances,
                                  CustomParser* parser);
  virtual void PutToFeedVec(const std::vector<Record>& ins_vec);
  virtual void GetMsgFromLogKey(const std::string& log_key,
                                uint64_t* search_id,
                                uint32_t* cmatch,
                                uint32_t* rank);
  virtual void PutToFeedVec(const Record* ins_vec, int num);
};

Y
yaoxuefeng 已提交
1766 1767 1768
// In-memory data feed whose instances are SlotRecord objects.
class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
 public:
  SlotRecordInMemoryDataFeed() {}
  // pack_ (CUDA + HETERPS builds) is not released here; it is presumably
  // owned by some other component — TODO confirm.
  virtual ~SlotRecordInMemoryDataFeed() {}
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual void LoadIntoMemory();
  void ExpandSlotRecord(SlotRecord* ins);

 protected:
  virtual bool Start();
  virtual int Next();
  // Record-at-a-time parsing hooks are unused here and always fail.
  virtual bool ParseOneInstance(SlotRecord* instance) { return false; }
  virtual bool ParseOneInstanceFromPipe(SlotRecord* instance) { return false; }
  // virtual void ParseOneInstanceFromSo(const char* str, T* instance,
  //                                    CustomParser* parser) {}
  virtual void PutToFeedVec(const std::vector<SlotRecord>& ins_vec) {}

  virtual void LoadIntoMemoryByCommand(void);
  virtual void LoadIntoMemoryByLib(void);
  virtual void LoadIntoMemoryByLine(void);
  virtual void LoadIntoMemoryByFile(void);
  // `channel` must point to a ChannelObject<SlotRecord>; it is cast unchecked.
  virtual void SetInputChannel(void* channel) {
    input_channel_ = static_cast<ChannelObject<SlotRecord>*>(channel);
  }
  bool ParseOneInstance(const std::string& line, SlotRecord* rec);
  virtual void PutToFeedVec(const SlotRecord* ins_vec, int num);
  virtual void AssignFeedVar(const Scope& scope);
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
  void BuildSlotBatchGPU(const int ins_num);
  void FillSlotValueOffset(const int ins_num,
                           const int used_slot_num,
                           size_t* slot_value_offsets,
                           const int* uint64_offsets,
                           const int uint64_slot_size,
                           const int* float_offsets,
                           const int float_slot_size,
                           const UsedSlotGpuType* used_slots);
  void CopyForTensor(const int ins_num,
                     const int used_slot_num,
                     void** dest,
                     const size_t* slot_value_offsets,
                     const uint64_t* uint64_feas,
                     const int* uint64_offsets,
                     const int* uint64_ins_lens,
                     const int uint64_slot_size,
                     const float* float_feas,
                     const int* float_offsets,
                     const int* float_ins_lens,
                     const int float_slot_size,
                     const UsedSlotGpuType* used_slots);
#endif

#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  virtual void InitGraphResource(void);
  virtual void InitGraphTrainResource(void);
  virtual void DoWalkandSage();
#endif
  virtual void DumpWalkPath(std::string dump_path, size_t dump_rate);

  float sample_rate_ = 1.0f;
  int use_slot_size_ = 0;
  int float_use_slot_size_ = 0;
  int uint64_use_slot_size_ = 0;
  std::vector<AllSlotInfo> all_slots_info_;
  std::vector<UsedSlotInfo> used_slots_info_;
  size_t float_total_dims_size_ = 0;
  std::vector<int> float_total_dims_without_inductives_;

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
  MiniBatchGpuPack* pack_ = nullptr;
#endif
};

1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861
// MultiSlotInMemoryDataFeed variant with PvInstance batching and a rank
// offset feed variable.
class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
 public:
  PaddleBoxDataFeed() {}
  virtual ~PaddleBoxDataFeed() {}

 protected:
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual bool Start();
  virtual int Next();
  virtual void AssignFeedVar(const Scope& scope);
  virtual void PutToFeedVec(const std::vector<PvInstance>& pv_vec);
  virtual void PutToFeedVec(const std::vector<Record*>& ins_vec);
  virtual int GetCurrentPhase();
  virtual void GetRankOffset(const std::vector<PvInstance>& pv_vec,
                             int ins_number);
  std::string rank_offset_name_;
  // Given a defined default so it is never read uninitialized; presumably
  // assigned in Init() — TODO confirm.
  int pv_batch_size_ = 0;
};

1864
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
H
hutuxian 已提交
1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912
/// @brief DataFeed that reads directly from files instead of a queue.
///
/// Subclasses supply per-file setup/teardown (Preprocess/Postprocess) and
/// the minibatch parser (ParseOneMiniBatch). Start() always reports success.
template <typename T>
class PrivateInstantDataFeed : public DataFeed {
 public:
  PrivateInstantDataFeed() {}
  virtual ~PrivateInstantDataFeed() = default;
  void Init(const DataFeedDesc& data_feed_desc) override;
  bool Start() override { return true; }
  int Next() override;

 protected:
  /// Buffer holding the current batch.
  std::vector<MultiSlotType> ins_vec_;

  /// Prepares `filename` for reading (for example by opening or mmap-ing it).
  virtual bool Preprocess(const std::string& filename) = 0;

  /// Releases resources acquired by Preprocess (for example closing the
  /// file). Implementations must tolerate being called before Preprocess.
  virtual bool Postprocess() = 0;

  /// Reads and parses the next minibatch.
  virtual bool ParseOneMiniBatch() = 0;

  /// Moves ins_vec_ into the feed variables.
  virtual void PutToFeedVec();
};

/// @brief PrivateInstantDataFeed for multi-slot files.
///
/// The members presumably describe the file being read: fd_ (-1 when no
/// file is open), a buffer_ over its contents, and end_/offset_ positions
/// within that buffer — TODO confirm against the implementation.
class MultiSlotFileInstantDataFeed
    : public PrivateInstantDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotFileInstantDataFeed() {}
  virtual ~MultiSlotFileInstantDataFeed() = default;

 protected:
  int fd_ = -1;
  char* buffer_ = nullptr;
  size_t end_ = 0;
  size_t offset_ = 0;

  bool Preprocess(const std::string& filename) override;
  bool Postprocess() override;
  bool ParseOneMiniBatch() override;
};
#endif

W
Wang Guibao 已提交
1913 1914
}  // namespace framework
}  // namespace paddle