data_feed.h 46.7 KB
Newer Older
W
Wang Guibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

J
jiaqi 已提交
17 18 19 20 21
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif

W
Wang Guibao 已提交
22
#include <fstream>
23
#include <future>  // NOLINT
W
Wang Guibao 已提交
24 25
#include <memory>
#include <mutex>  // NOLINT
26
#include <sstream>
W
Wang Guibao 已提交
27 28
#include <string>
#include <thread>  // NOLINT
29
#include <unordered_map>
30
#include <unordered_set>
31
#include <utility>
32
#include <vector>
W
Wang Guibao 已提交
33

J
jiaqi 已提交
34
#include "paddle/fluid/framework/archive.h"
35
#include "paddle/fluid/framework/blocking_queue.h"
J
jiaqi 已提交
36
#include "paddle/fluid/framework/channel.h"
W
Wang Guibao 已提交
37
#include "paddle/fluid/framework/data_feed.pb.h"
38
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
W
Wang Guibao 已提交
39 40 41
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
Y
yaoxuefeng 已提交
42
#include "paddle/fluid/platform/timer.h"
43
#include "paddle/fluid/string/string_helper.h"
44 45 46 47
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#endif
W
Wang Guibao 已提交
48

Y
yaoxuefeng 已提交
49 50 51 52 53
DECLARE_int32(record_pool_max_size);
DECLARE_int32(slotpool_thread_num);
DECLARE_bool(enable_slotpool_wait_release);
DECLARE_bool(enable_slotrecord_reset_shrink);

W
wanghuancoder 已提交
54 55 56 57 58 59 60 61
namespace paddle {
namespace framework {
class DataFeedDesc;
class Scope;
class Variable;
}  // namespace framework
}  // namespace paddle

62
namespace phi {
63
class DenseTensor;
64
}  // namespace phi
65

W
Wang Guibao 已提交
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
namespace paddle {
namespace framework {

// DataFeed is the base virtual class for all other DataFeeds.
// It is used to read files and parse the data for subsequent trainer.
// Example:
//   DataFeed* reader =
//   paddle::framework::DataFeedFactory::CreateDataFeed(data_feed_name);
//   reader->Init(data_feed_desc); // data_feed_desc is a protobuf object
//   reader->SetFileList(filelist);
//   const std::vector<std::string> & use_slot_alias =
//   reader->GetUseSlotAlias();
//   for (auto name: use_slot_alias){ // for binding memory
//     reader->AddFeedVar(scope->Var(name), name);
//   }
//   reader->Start();
//   while (reader->Next()) {
//      // trainer do something
//   }
Y
yaoxuefeng 已提交
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128

// SlotValues stores a ragged list of per-slot value arrays in a compact
// offsets form: all values are concatenated in `slot_values`, and slot i
// occupies the half-open range [slot_offsets[i], slot_offsets[i + 1]).
template <typename T>
struct SlotValues {
  std::vector<T> slot_values;
  std::vector<uint32_t> slot_offsets;

  // Appends one slot holding `num` values copied from `values`.
  // `values` may be null when `num` is 0; an empty slot still gets an
  // offset entry. The leading 0 offset is inserted on the first call.
  void add_values(const T* values, uint32_t num) {
    if (slot_offsets.empty()) {
      slot_offsets.push_back(0);
    }
    if (num > 0) {
      slot_values.insert(slot_values.end(), values, values + num);
    }
    slot_offsets.push_back(static_cast<uint32_t>(slot_values.size()));
  }

  // Returns a pointer to the first value of slot `idx` and writes the number
  // of values in that slot to `*size`. For an empty slot `*size` is 0 and the
  // returned pointer must not be dereferenced.
  T* get_values(int idx, size_t* size) {
    const uint32_t offset = slot_offsets[idx];
    (*size) = slot_offsets[idx + 1] - offset;
    // data() + offset is well-defined even when offset == slot_values.size()
    // (an empty trailing slot), whereas slot_values[offset] would index past
    // the end of the vector.
    return slot_values.data() + offset;
  }

  // Overwrites slot_offsets with one entry per element of `slot_feasigns`
  // plus a terminating offset, appending each slot's values to slot_values.
  // `fea_num` is only used as a capacity hint for slot_values.
  void add_slot_feasigns(const std::vector<std::vector<T>>& slot_feasigns,
                         uint32_t fea_num) {
    slot_values.reserve(fea_num);
    int slot_num = static_cast<int>(slot_feasigns.size());
    slot_offsets.resize(slot_num + 1);
    for (int i = 0; i < slot_num; ++i) {
      auto& slot_val = slot_feasigns[i];
      slot_offsets[i] = static_cast<uint32_t>(slot_values.size());
      uint32_t num = static_cast<uint32_t>(slot_val.size());
      if (num > 0) {
        slot_values.insert(slot_values.end(), slot_val.begin(), slot_val.end());
      }
    }
    slot_offsets[slot_num] = static_cast<uint32_t>(slot_values.size());
  }

  // Drops all slots; when `shrink` is true also releases the capacity.
  void clear(bool shrink) {
    slot_offsets.clear();
    slot_values.clear();
    if (shrink) {
      slot_values.shrink_to_fit();
      slot_offsets.shrink_to_fit();
    }
  }
};
T
Thunderbrook 已提交
129
// One feature value, either an integer feasign or a float value. The union
// itself does not record which member is active; Record keeps the two kinds
// in separate FeatureItem vectors.
union FeatureFeasign {
  uint64_t uint64_feasign_;
  float float_feasign_;
};

// FeatureItem pairs a FeatureFeasign with its slot id.
// The feasign is stored in a raw char buffer (alignment 1) rather than as a
// FeatureFeasign member, so the struct is not padded out to FeatureFeasign's
// alignment: its size is sizeof(FeatureFeasign) + sizeof(uint16_t).
// NOTE(review): sign() reinterprets that byte buffer in place. Elements of a
// FeatureItem array are therefore not 8-byte aligned, so this relies on the
// target tolerating unaligned loads/stores.
struct FeatureItem {
  FeatureItem() {}
  FeatureItem(FeatureFeasign sign, uint16_t slot) {
    this->sign() = sign;
    this->slot() = slot;
  }
  // Mutable view of the stored feasign.
  FeatureFeasign& sign() {
    return *(reinterpret_cast<FeatureFeasign*>(sign_buffer()));
  }
  // Read-only view of the stored feasign; no const_cast needed here.
  const FeatureFeasign& sign() const {
    return *(reinterpret_cast<const FeatureFeasign*>(sign_));
  }
  uint16_t& slot() { return slot_; }
  const uint16_t& slot() const { return slot_; }

 private:
  char* sign_buffer() const { return const_cast<char*>(sign_); }
  char sign_[sizeof(FeatureFeasign)];
  uint16_t slot_;
};

Y
yaoxuefeng 已提交
157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
// Metadata for one slot in the full slot list.
// NOTE(review): the int members have no default initializers and are
// indeterminate unless whoever creates the struct assigns them.
struct AllSlotInfo {
  std::string slot;
  std::string type;
  // presumably the index into the used-slot list (or negative when the slot
  // is not used) — TODO confirm against the code that fills it
  int used_idx;
  int slot_value_idx;
};
// Metadata for one used slot. The fields appear to mirror DataFeed's per-used-
// slot vectors (use_slots_is_dense_, use_slots_shape_,
// total_dims_without_inductive_, inductive_shape_index_) — TODO confirm.
// NOTE(review): scalar members have no default initializers.
struct UsedSlotInfo {
  int idx;
  int slot_value_idx;
  std::string slot;
  std::string type;
  bool dense;
  std::vector<int> local_shape;
  int total_dims_without_inductive;
  int inductive_shape_index;
};
// One parsed instance in SlotRecord form. Feature values are split into a
// uint64 and a float SlotValues container.
// clear() only empties the two SlotValues containers; search_id, rank,
// cmatch and ins_id_ are left untouched.
struct SlotRecordObject {
  uint64_t search_id;
  uint32_t rank;
  uint32_t cmatch;
  std::string ins_id_;
  SlotValues<uint64_t> slot_uint64_feasigns_;
  SlotValues<float> slot_float_feasigns_;

  // Releases all feature storage, including capacity.
  ~SlotRecordObject() { clear(true); }
  // Empties the feature containers for reuse; whether capacity is released is
  // controlled by FLAGS_enable_slotrecord_reset_shrink.
  void reset(void) { clear(FLAGS_enable_slotrecord_reset_shrink); }
  // Empties both feature containers; `shrink` also releases their capacity.
  void clear(bool shrink) {
    slot_uint64_feasigns_.clear(shrink);
    slot_float_feasigns_.clear(shrink);
  }
};
// Non-owning handle type. Objects are created with make_slotrecord() and
// released with free_slotrecord() or returned to SlotRecordPool().
using SlotRecord = SlotRecordObject*;
189 190 191 192 193 194 195 196 197
// sizeof Record is much less than std::vector<MultiSlotType>
// Record is one parsed instance in FeatureItem form: integer and float
// features are kept in separate vectors, alongside optional id/content/uid
// strings. The scalar fields are zero-initialized so a default-constructed
// Record never carries indeterminate values.
struct Record {
  std::vector<FeatureItem> uint64_feasigns_;
  std::vector<FeatureItem> float_feasigns_;
  std::string ins_id_;
  std::string content_;
  uint64_t search_id = 0;
  uint32_t rank = 0;
  uint32_t cmatch = 0;
  std::string uid_;
};

Y
yaoxuefeng 已提交
201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373
// Allocates raw storage with malloc() and default-constructs a
// SlotRecordObject in it with placement new. Records created here must be
// released with free_slotrecord() (or handed back to SlotRecordPool()),
// never with `delete`.
// Throws std::bad_alloc if malloc() fails, instead of constructing into a
// null pointer.
inline SlotRecord make_slotrecord() {
  void* p = malloc(sizeof(SlotRecordObject));
  if (p == nullptr) {
    throw std::bad_alloc();
  }
  return new (p) SlotRecordObject;
}

// Destroys a record created by make_slotrecord() and returns its storage to
// the C heap. Passing nullptr is a no-op, matching `delete`/free() semantics.
inline void free_slotrecord(SlotRecordObject* p) {
  if (p == nullptr) {
    return;
  }
  p->~SlotRecordObject();
  free(p);
}

// Intrusive free list of T objects.
// release() does NOT destroy the object: the first sizeof(Node*) bytes of the
// object are overwritten with the link to the next free node, so T's leading
// bytes must tolerate being clobbered while the object sits in the list.
// Objects still in the list are passed to `deleter` on clear()/destruction.
// There is no internal locking; callers synchronize externally.
template <class T>
class SlotObjAllocator {
 public:
  explicit SlotObjAllocator(std::function<void(T*)> deleter)
      : free_nodes_(NULL), capacity_(0), deleter_(std::move(deleter)) {}
  ~SlotObjAllocator() { clear(); }

  // The allocator owns every object in its list; a copy would make two
  // allocators hand the same nodes to the deleter.
  SlotObjAllocator(const SlotObjAllocator&) = delete;
  SlotObjAllocator& operator=(const SlotObjAllocator&) = delete;

  // Hands every pooled object to the deleter and empties the list.
  void clear() {
    while (free_nodes_ != NULL) {
      T* obj = reinterpret_cast<T*>(reinterpret_cast<void*>(free_nodes_));
      free_nodes_ = free_nodes_->next;
      deleter_(obj);
      --capacity_;
    }
    CHECK_EQ(capacity_, static_cast<size_t>(0));
  }

  // Pops one pooled object, or returns NULL when the list is empty (instead
  // of dereferencing a null head).
  T* acquire(void) {
    if (free_nodes_ == NULL) {
      return NULL;
    }
    T* x = reinterpret_cast<T*>(reinterpret_cast<void*>(free_nodes_));
    free_nodes_ = free_nodes_->next;
    --capacity_;
    return x;
  }

  // Pushes `x` onto the list; its leading bytes become the list link.
  void release(T* x) {
    Node* node = reinterpret_cast<Node*>(reinterpret_cast<void*>(x));
    node->next = free_nodes_;
    free_nodes_ = node;
    ++capacity_;
  }

  // Number of objects currently held in the list.
  size_t capacity(void) { return capacity_; }

 private:
  // View of a pooled T's storage: the link overlays the object's first bytes.
  struct alignas(T) Node {
    union {
      Node* next;
      char data[sizeof(T)];
    };
  };
  Node* free_nodes_;  // a list
  size_t capacity_;
  std::function<void(T*)> deleter_ = nullptr;
};
static const int OBJPOOL_BLOCK_SIZE = 10000;
class SlotObjPool {
 public:
  SlotObjPool()
      : max_capacity_(FLAGS_record_pool_max_size), alloc_(free_slotrecord) {
    ins_chan_ = MakeChannel<SlotRecord>();
    ins_chan_->SetBlockSize(OBJPOOL_BLOCK_SIZE);
    for (int i = 0; i < FLAGS_slotpool_thread_num; ++i) {
      threads_.push_back(std::thread([this]() { run(); }));
    }
    disable_pool_ = false;
    count_ = 0;
  }
  ~SlotObjPool() {
    ins_chan_->Close();
    for (auto& t : threads_) {
      t.join();
    }
  }
  void disable_pool(bool disable) { disable_pool_ = disable; }
  void set_max_capacity(size_t max_capacity) { max_capacity_ = max_capacity; }
  void get(std::vector<SlotRecord>* output, int n) {
    output->resize(n);
    return get(&(*output)[0], n);
  }
  void get(SlotRecord* output, int n) {
    int size = 0;
    mutex_.lock();
    int left = static_cast<int>(alloc_.capacity());
    if (left > 0) {
      size = (left >= n) ? n : left;
      for (int i = 0; i < size; ++i) {
        output[i] = alloc_.acquire();
      }
    }
    mutex_.unlock();
    count_ += n;
    if (size == n) {
      return;
    }
    for (int i = size; i < n; ++i) {
      output[i] = make_slotrecord();
    }
  }
  void put(std::vector<SlotRecord>* input) {
    size_t size = input->size();
    if (size == 0) {
      return;
    }
    put(&(*input)[0], size);
    input->clear();
  }
  void put(SlotRecord* input, size_t size) {
    CHECK(ins_chan_->WriteMove(size, input) == size);
  }
  void run(void) {
    std::vector<SlotRecord> input;
    while (ins_chan_->ReadOnce(input, OBJPOOL_BLOCK_SIZE)) {
      if (input.empty()) {
        continue;
      }
      // over max capacity
      size_t n = input.size();
      count_ -= n;
      if (disable_pool_ || n + capacity() > max_capacity_) {
        for (auto& t : input) {
          free_slotrecord(t);
        }
      } else {
        for (auto& t : input) {
          t->reset();
        }
        mutex_.lock();
        for (auto& t : input) {
          alloc_.release(t);
        }
        mutex_.unlock();
      }
      input.clear();
    }
  }
  void clear(void) {
    platform::Timer timeline;
    timeline.Start();
    mutex_.lock();
    alloc_.clear();
    mutex_.unlock();
    // wait release channel data
    if (FLAGS_enable_slotpool_wait_release) {
      while (!ins_chan_->Empty()) {
        sleep(1);
      }
    }
    timeline.Pause();
    VLOG(3) << "clear slot pool data size=" << count_.load()
            << ", span=" << timeline.ElapsedSec();
  }
  size_t capacity(void) {
    mutex_.lock();
    size_t total = alloc_.capacity();
    mutex_.unlock();
    return total;
  }

 private:
  size_t max_capacity_;
  Channel<SlotRecord> ins_chan_;
  std::vector<std::thread> threads_;
  std::mutex mutex_;
  SlotObjAllocator<SlotRecordObject> alloc_;
  bool disable_pool_;
  std::atomic<long> count_;  // NOLINT
};

// Returns the process-wide SlotObjPool. The pool is a function-local static:
// it is constructed on the first call and destroyed at program exit.
inline SlotObjPool& SlotRecordPool() {
  static SlotObjPool shared_slot_record_pool;
  return shared_slot_record_pool;
}
374 375 376 377 378 379 380 381 382
struct PvInstanceObject {
  std::vector<Record*> ads;
  void merge_instance(Record* ins) { ads.push_back(ins); }
};

using PvInstance = PvInstanceObject*;

inline PvInstance make_pv_instance() { return new PvInstanceObject(); }

T
Thunderbrook 已提交
383 384 385 386 387 388 389 390 391 392 393 394
// Slot description passed to CustomParser::Init(const std::vector<SlotConf>&)
// and DLManager::Load(name, const std::vector<SlotConf>&).
struct SlotConf {
  std::string name;
  std::string type;
  // presumably the slot's index among used slots (DataFeed::use_slots_index_
  // uses -1 for "not used") — TODO confirm
  int use_slots_index;
  // presumably a 0/1 dense flag — TODO confirm
  int use_slots_is_dense;
};

class CustomParser {
 public:
  CustomParser() {}
  virtual ~CustomParser() {}
  virtual void Init(const std::vector<SlotConf>& slots) = 0;
T
Thunderbrook 已提交
395
  virtual bool Init(const std::vector<AllSlotInfo>& slots) = 0;
T
Thunderbrook 已提交
396
  virtual void ParseOneInstance(const char* str, Record* instance) = 0;
397 398
  virtual int ParseInstance(int len,
                            const char* str,
T
Thunderbrook 已提交
399 400 401
                            std::vector<Record>* instances) {
    return 0;
  };
Y
yaoxuefeng 已提交
402 403 404 405 406 407 408 409 410 411 412 413 414
  virtual bool ParseOneInstance(
      const std::string& line,
      std::function<void(std::vector<SlotRecord>&, int)>
          GetInsFunc) {  // NOLINT
    return true;
  }
  virtual bool ParseFileInstance(
      std::function<int(char* buf, int len)> ReadBuffFunc,
      std::function<void(std::vector<SlotRecord>&, int, int)>
          PullRecordsFunc,  // NOLINT
      int& lines) {         // NOLINT
    return false;
  }
T
Thunderbrook 已提交
415 416
};

417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
// Per-used-slot descriptor; MiniBatchGpuPack keeps an array of these in
// device memory (gpu_slots_ is a CudaBuffer<UsedSlotGpuType>).
struct UsedSlotGpuType {
  // presumably a 0/1 flag: 1 for uint64 slots, 0 for float slots — TODO confirm
  int is_uint64_value;
  int slot_value_idx;
};

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
// Aborts via CHECK unless a CUDA runtime call returned gpuSuccess. The
// argument is parenthesized so that expressions binding looser than ==
// (e.g. `?:` or assignment) are compared as a whole.
#define CUDA_CHECK(val) CHECK((val) == gpuSuccess)
// Owning wrapper around a device array of T allocated with cudaMalloc.
// resize() only ever grows the allocation (the old contents are discarded);
// the buffer is freed on destruction.
template <typename T>
struct CudaBuffer {
  T* cu_buffer;
  uint64_t buf_size;

  // Plain injected-class-name: `CudaBuffer<T>()` as a constructor declarator
  // is ill-formed in C++20.
  CudaBuffer() {
    cu_buffer = NULL;
    buf_size = 0;
  }
  ~CudaBuffer() { free(); }
  // The struct owns cu_buffer; a member-wise copy would cudaFree it twice.
  CudaBuffer(const CudaBuffer&) = delete;
  CudaBuffer& operator=(const CudaBuffer&) = delete;

  T* data() { return cu_buffer; }
  uint64_t size() { return buf_size; }
  // Allocates room for `size` elements (does not free a previous buffer).
  void malloc(uint64_t size) {
    buf_size = size;
    CUDA_CHECK(
        cudaMalloc(reinterpret_cast<void**>(&cu_buffer), size * sizeof(T)));
  }
  void free() {
    if (cu_buffer != NULL) {
      CUDA_CHECK(cudaFree(cu_buffer));
      cu_buffer = NULL;
    }
    buf_size = 0;
  }
  // Ensures capacity for at least `size` elements; reallocates only on growth.
  void resize(uint64_t size) {
    if (size <= buf_size) {
      return;
    }
    free();
    malloc(size);
  }
};
// Owning wrapper around a pinned host array of T (cudaHostAlloc).
// buf_size is the allocated capacity; data_len is the logical length that
// size() reports. resize() only reallocates when growing.
template <typename T>
struct HostBuffer {
  T* host_buffer;
  size_t buf_size;
  size_t data_len;

  // Plain injected-class-name: `HostBuffer<T>()` as a constructor declarator
  // is ill-formed in C++20.
  HostBuffer() {
    host_buffer = NULL;
    buf_size = 0;
    data_len = 0;
  }
  ~HostBuffer() { free(); }
  // The struct owns host_buffer; a member-wise copy would cudaFreeHost twice.
  HostBuffer(const HostBuffer&) = delete;
  HostBuffer& operator=(const HostBuffer&) = delete;

  T* data() { return host_buffer; }
  const T* data() const { return host_buffer; }
  size_t size() const { return data_len; }
  void clear() { free(); }
  // Last logical element; requires size() > 0.
  T& back() { return host_buffer[data_len - 1]; }

  T& operator[](size_t i) { return host_buffer[i]; }
  const T& operator[](size_t i) const { return host_buffer[i]; }
  // Allocates pinned storage for `len` elements (does not free a previous
  // buffer and does not change data_len).
  void malloc(size_t len) {
    buf_size = len;
    CUDA_CHECK(cudaHostAlloc(reinterpret_cast<void**>(&host_buffer),
                             buf_size * sizeof(T),
                             cudaHostAllocDefault));
    CHECK(host_buffer != NULL);
  }
  // Releases the storage and resets the logical length, so size() never
  // reports elements behind a null data() pointer after clear()/free().
  void free() {
    if (host_buffer != NULL) {
      CUDA_CHECK(cudaFreeHost(host_buffer));
      host_buffer = NULL;
    }
    buf_size = 0;
    data_len = 0;
  }
  // Sets the logical length to `size`, growing the allocation if needed
  // (existing contents are discarded on growth).
  void resize(size_t size) {
    if (size <= buf_size) {
      data_len = size;
      return;
    }
    free();
    malloc(size);
    // Set after free(), which resets data_len.
    data_len = size;
  }
};

// Host-side staging arrays for one packed batch. Each is a HostBuffer, i.e.
// pinned memory from cudaHostAlloc. Per-field meanings below are inferred
// from names — TODO confirm against MiniBatchGpuPack::pack_* implementations.
struct BatchCPUValue {
  // uint64 features: presumably per-slot lengths, flattened keys, offsets
  HostBuffer<int> h_uint64_lens;
  HostBuffer<uint64_t> h_uint64_keys;
  HostBuffer<int> h_uint64_offset;

  // float features: same layout as the uint64 group
  HostBuffer<int> h_float_lens;
  HostBuffer<float> h_float_keys;
  HostBuffer<int> h_float_offset;

  // per-instance rank/cmatch and ad offsets
  HostBuffer<int> h_rank;
  HostBuffer<int> h_cmatch;
  HostBuffer<int> h_ad_offset;
};

// Device-side counterparts of BatchCPUValue, one CudaBuffer per field
// (d_* mirrors h_* by name).
struct BatchGPUValue {
  CudaBuffer<int> d_uint64_lens;
  CudaBuffer<uint64_t> d_uint64_keys;
  CudaBuffer<int> d_uint64_offset;

  CudaBuffer<int> d_float_lens;
  CudaBuffer<float> d_float_keys;
  CudaBuffer<int> d_float_offset;

  CudaBuffer<int> d_rank;
  CudaBuffer<int> d_cmatch;
  CudaBuffer<int> d_ad_offset;
};

// Per-device staging object for a batch of SlotRecords. It holds pinned host
// buffers (BatchCPUValue), device buffers (BatchGPUValue) and output
// tensors. Host-to-device copies are issued asynchronously on stream_.
// Only the inline members are defined in this header.
class MiniBatchGpuPack {
 public:
  MiniBatchGpuPack(const paddle::platform::Place& place,
                   const std::vector<UsedSlotInfo>& infos);
  ~MiniBatchGpuPack();
  void reset(const paddle::platform::Place& place);
  void pack_instance(const SlotRecord* ins_vec, int num);
  int ins_num() { return ins_num_; }
  int pv_num() { return pv_num_; }
  BatchGPUValue& value() { return value_; }
  BatchCPUValue& cpu_value() { return buf_; }
  UsedSlotGpuType* get_gpu_slots(void) {
    return reinterpret_cast<UsedSlotGpuType*>(gpu_slots_.data());
  }
  // data() rather than &ins_vec_[0]: indexing an empty vector is undefined.
  SlotRecord* get_records(void) { return ins_vec_.data(); }

  // tensor gpu memory reused: (re)shape the output tensors to the total
  // lengths stored in the last element of the host length buffers.
  void resize_tensor(void) {
    if (used_float_num_ > 0) {
      int float_total_len = buf_.h_float_lens.back();
      if (float_total_len > 0) {
        float_tensor_.mutable_data<float>({float_total_len, 1}, this->place_);
      }
    }
    if (used_uint64_num_ > 0) {
      int uint64_total_len = buf_.h_uint64_lens.back();
      if (uint64_total_len > 0) {
        uint64_tensor_.mutable_data<int64_t>({uint64_total_len, 1},
                                             this->place_);
      }
    }
  }
  LoDTensor& float_tensor(void) { return float_tensor_; }
  LoDTensor& uint64_tensor(void) { return uint64_tensor_; }

  HostBuffer<size_t>& offsets(void) { return offsets_; }
  HostBuffer<void*>& h_tensor_ptrs(void) { return h_tensor_ptrs_; }

  void* gpu_slot_offsets(void) { return gpu_slot_offsets_->ptr(); }

  void* slot_buf_ptr(void) { return slot_buf_ptr_->ptr(); }

  // Ensures gpu_slot_offsets_ holds at least slot_total_bytes. A larger
  // allocation is made first; assigning it then releases the old one.
  void resize_gpu_slot_offsets(const size_t slot_total_bytes) {
    if (gpu_slot_offsets_ == nullptr ||
        gpu_slot_offsets_->size() < slot_total_bytes) {
      gpu_slot_offsets_ = memory::AllocShared(place_, slot_total_bytes);
    }
  }
  // Instance id of record `idx`, read from ins_vec_ when pv merging is
  // enabled and from batch_ins_ otherwise.
  const std::string& get_lineid(int idx) {
    if (enable_pv_) {
      return ins_vec_[idx]->ins_id_;
    }
    return batch_ins_[idx]->ins_id_;
  }

 private:
  void transfer_to_gpu(void);
  void pack_all_data(const SlotRecord* ins_vec, int num);
  void pack_uint64_data(const SlotRecord* ins_vec, int num);
  void pack_float_data(const SlotRecord* ins_vec, int num);

 public:
  // Grows `buf` to `size` elements if needed and enqueues an async copy of
  // `val` on stream_. No-op for size == 0.
  template <typename T>
  void copy_host2device(CudaBuffer<T>* buf, const T* val, size_t size) {
    if (size == 0) {
      return;
    }
    buf->resize(size);
    CUDA_CHECK(cudaMemcpyAsync(
        buf->data(), val, size * sizeof(T), cudaMemcpyHostToDevice, stream_));
  }
  template <typename T>
  void copy_host2device(CudaBuffer<T>* buf, const HostBuffer<T>& val) {
    copy_host2device(buf, val.data(), val.size());
  }

 private:
  paddle::platform::Place place_;
  cudaStream_t stream_;
  BatchGPUValue value_;
  BatchCPUValue buf_;
  int ins_num_ = 0;
  int pv_num_ = 0;

  bool enable_pv_ = false;
  int used_float_num_ = 0;
  int used_uint64_num_ = 0;
  int used_slot_size_ = 0;

  CudaBuffer<UsedSlotGpuType> gpu_slots_;
  std::vector<UsedSlotGpuType> gpu_used_slots_;
  std::vector<SlotRecord> ins_vec_;
  const SlotRecord* batch_ins_ = nullptr;

  // uint64 tensor
  LoDTensor uint64_tensor_;
  // float tensor
  LoDTensor float_tensor_;
  // batch
  HostBuffer<size_t> offsets_;
  HostBuffer<void*> h_tensor_ptrs_;

  std::shared_ptr<phi::Allocation> gpu_slot_offsets_ = nullptr;
  std::shared_ptr<phi::Allocation> slot_buf_ptr_ = nullptr;
};
// Owns at most one MiniBatchGpuPack per device id, created lazily by get().
// There is no internal locking; the intended usage is one thread per device.
class MiniBatchGpuPackMgr {
  static const int MAX_DEVICE_NUM = 16;

 public:
  MiniBatchGpuPackMgr() {
    for (int i = 0; i < MAX_DEVICE_NUM; ++i) {
      pack_list_[i] = nullptr;
    }
  }
  ~MiniBatchGpuPackMgr() {
    for (int i = 0; i < MAX_DEVICE_NUM; ++i) {
      delete pack_list_[i];  // deleting nullptr is a no-op
      pack_list_[i] = nullptr;
    }
  }
  // The manager owns raw pack pointers; a copy would delete them twice.
  MiniBatchGpuPackMgr(const MiniBatchGpuPackMgr&) = delete;
  MiniBatchGpuPackMgr& operator=(const MiniBatchGpuPackMgr&) = delete;

  // one device one thread
  // Returns the pack for place's device, creating it with `infos` on first
  // use and otherwise calling reset(place) on the existing one (`infos` is
  // ignored in that case).
  MiniBatchGpuPack* get(const paddle::platform::Place& place,
                        const std::vector<UsedSlotInfo>& infos) {
    int device_id = place.GetDeviceId();
    // pack_list_ has a fixed number of entries; an out-of-range id would
    // read and write past the end of the array.
    CHECK(device_id >= 0 && device_id < MAX_DEVICE_NUM)
        << "MiniBatchGpuPackMgr: device id " << device_id << " out of range";
    if (pack_list_[device_id] == nullptr) {
      pack_list_[device_id] = new MiniBatchGpuPack(place, infos);
    } else {
      pack_list_[device_id]->reset(place);
    }
    return pack_list_[device_id];
  }

 private:
  MiniBatchGpuPack* pack_list_[MAX_DEVICE_NUM];
};
// Global MiniBatchGpuPackMgr: a function-local static constructed on the
// first call and destroyed at program exit.
inline MiniBatchGpuPackMgr& BatchGpuPackMgr() {
  static MiniBatchGpuPackMgr global_pack_mgr;
  return global_pack_mgr;
}
#endif

T
Thunderbrook 已提交
678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712
// Signature of the "CreateParserObject" factory symbol that a parser plugin
// shared library must export.
typedef paddle::framework::CustomParser* (*CreateParserObjectFunc)();

// DLManager loads CustomParser plugins from shared libraries with dlopen and
// caches one parser per library path. It deletes the parsers it created and
// dlcloses the libraries on Close() or destruction. Only implemented when
// _LINUX is defined; otherwise the methods log and report failure.
class DLManager {
  struct DLHandle {
    void* module;
    paddle::framework::CustomParser* parser;
  };

 public:
  DLManager() {}

  ~DLManager() {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    for (auto& entry : handle_map_) {
      delete entry.second.parser;
      dlclose(entry.second.module);
    }
#endif
  }

  // Unloads library `name`. Returns true if it was not loaded or has been
  // unloaded; false where unsupported.
  bool Close(const std::string& name) {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handle_map_.find(name);
    if (it == handle_map_.end()) {
      return true;
    }
    delete it->second.parser;
    dlclose(it->second.module);
    // Drop the entry so Load() cannot return, and ~DLManager() cannot delete
    // again, the parser that was just destroyed.
    handle_map_.erase(it);
    return true;
#else
    VLOG(0) << "Not implement in windows";
    return false;
#endif
  }

  // Returns the cached parser for `name`, or dlopens it, creates the parser
  // through CreateParserObject and calls Init(conf). Returns nullptr if the
  // library or the factory symbol cannot be loaded.
  paddle::framework::CustomParser* Load(const std::string& name,
                                        const std::vector<SlotConf>& conf) {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handle_map_.find(name);
    if (it != handle_map_.end()) {
      return it->second.parser;
    }

    DLHandle handle;
    handle.module = dlopen(name.c_str(), RTLD_NOW);
    if (handle.module == nullptr) {
      VLOG(0) << "Create so of " << name << " fail, " << dlerror();
      return nullptr;
    }

    CreateParserObjectFunc create_parser_func =
        reinterpret_cast<CreateParserObjectFunc>(
            dlsym(handle.module, "CreateParserObject"));
    if (create_parser_func == nullptr) {
      VLOG(0) << "Symbol CreateParserObject not found in " << name << ", "
              << dlerror();
      dlclose(handle.module);
      return nullptr;
    }
    handle.parser = create_parser_func();
    handle.parser->Init(conf);
    handle_map_.insert({name, handle});

    return handle.parser;
#else
    VLOG(0) << "Not implement in windows";
    return nullptr;
#endif
  }

  // Same as the SlotConf overload, but initializes with AllSlotInfo and
  // terminates the process (exit(-1)) if the library or factory symbol
  // cannot be loaded.
  paddle::framework::CustomParser* Load(const std::string& name,
                                        const std::vector<AllSlotInfo>& conf) {
#ifdef _LINUX
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = handle_map_.find(name);
    if (it != handle_map_.end()) {
      return it->second.parser;
    }

    DLHandle handle;
    handle.module = dlopen(name.c_str(), RTLD_NOW);
    if (handle.module == nullptr) {
      VLOG(0) << "Create so of " << name << " fail";
      exit(-1);
    }

    CreateParserObjectFunc create_parser_func =
        reinterpret_cast<CreateParserObjectFunc>(
            dlsym(handle.module, "CreateParserObject"));
    if (create_parser_func == nullptr) {
      VLOG(0) << "Symbol CreateParserObject not found in " << name;
      exit(-1);
    }
    handle.parser = create_parser_func();
    handle.parser->Init(conf);
    handle_map_.insert({name, handle});

    return handle.parser;
#else
    VLOG(0) << "Not implement in windows";
    return nullptr;
#endif
  }

  // Unloads `name` (if loaded) and loads it again with `conf`.
  paddle::framework::CustomParser* ReLoad(const std::string& name,
                                          const std::vector<SlotConf>& conf) {
    Close(name);
    return Load(name, conf);
  }

 private:
  std::mutex mutex_;
  std::map<std::string, DLHandle> handle_map_;
};

W
Wang Guibao 已提交
779 780
class DataFeed {
 public:
781 782 783
  DataFeed() {
    mutex_for_pick_file_ = nullptr;
    file_idx_ = nullptr;
H
hutuxian 已提交
784 785
    mutex_for_fea_num_ = nullptr;
    total_fea_num_ = nullptr;
786
  }
W
Wang Guibao 已提交
787
  virtual ~DataFeed() {}
H
hutuxian 已提交
788
  virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
W
Wang Guibao 已提交
789
  virtual bool CheckFile(const char* filename) {
790 791
    PADDLE_THROW(platform::errors::Unimplemented(
        "This function(CheckFile) is not implemented."));
W
Wang Guibao 已提交
792 793 794 795 796 797
  }
  // Set filelist for DataFeed.
  // Pay attention that it must init all readers before call this function.
  // Otherwise, Init() function will init finish_set_filelist_ flag.
  virtual bool SetFileList(const std::vector<std::string>& files);
  virtual bool Start() = 0;
D
dongdaxiang 已提交
798

W
Wang Guibao 已提交
799 800 801 802 803 804 805 806 807 808 809 810 811 812 813
  // The trainer calls the Next() function, and the DataFeed will load a new
  // batch to the feed_vec. The return value of this function is the batch
  // size of the current batch.
  virtual int Next() = 0;
  // Get all slots' alias which defined in protofile
  virtual const std::vector<std::string>& GetAllSlotAlias() {
    return all_slots_;
  }
  // Get used slots' alias which defined in protofile
  virtual const std::vector<std::string>& GetUseSlotAlias() {
    return use_slots_;
  }
  // This function is used for binding feed_vec memory
  virtual void AddFeedVar(Variable* var, const std::string& name);

H
hutuxian 已提交
814 815 816
  // This function is used for binding feed_vec memory in a given scope
  virtual void AssignFeedVar(const Scope& scope);

817 818 819 820 821 822 823
  // This function will do nothing at default
  virtual void SetInputPvChannel(void* channel) {}
  // This function will do nothing at default
  virtual void SetOutputPvChannel(void* channel) {}
  // This function will do nothing at default
  virtual void SetConsumePvChannel(void* channel) {}

824
  // This function will do nothing at default
J
jiaqi 已提交
825 826 827
  virtual void SetInputChannel(void* channel) {}
  // This function will do nothing at default
  virtual void SetOutputChannel(void* channel) {}
828
  // This function will do nothing at default
J
jiaqi 已提交
829
  virtual void SetConsumeChannel(void* channel) {}
830
  // This function will do nothing at default
831
  virtual void SetThreadId(int thread_id) {}
832
  // This function will do nothing at default
833
  virtual void SetThreadNum(int thread_num) {}
834 835
  // This function will do nothing at default
  virtual void SetParseInsId(bool parse_ins_id) {}
836
  virtual void SetParseUid(bool parse_uid) {}
837
  virtual void SetParseContent(bool parse_content) {}
838 839 840
  virtual void SetParseLogKey(bool parse_logkey) {}
  virtual void SetEnablePvMerge(bool enable_pv_merge) {}
  virtual void SetCurrentPhase(int current_phase) {}
841 842 843
  virtual void SetFileListMutex(std::mutex* mutex) {
    mutex_for_pick_file_ = mutex;
  }
H
hutuxian 已提交
844
  virtual void SetFeaNumMutex(std::mutex* mutex) { mutex_for_fea_num_ = mutex; }
845
  virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
H
hutuxian 已提交
846
  virtual void SetFeaNum(uint64_t* fea_num) { total_fea_num_ = fea_num; }
847 848 849 850 851 852 853
  virtual const std::vector<std::string>& GetInsIdVec() const {
    return ins_id_vec_;
  }
  virtual const std::vector<std::string>& GetInsContentVec() const {
    return ins_content_vec_;
  }
  virtual int GetCurBatchSize() { return batch_size_; }
854
  virtual void LoadIntoMemory() {
855 856
    PADDLE_THROW(platform::errors::Unimplemented(
        "This function(LoadIntoMemory) is not implemented."));
857
  }
858 859 860 861
  virtual void SetPlace(const paddle::platform::Place& place) {
    place_ = place;
  }
  virtual const paddle::platform::Place& GetPlace() const { return place_; }
862

W
Wang Guibao 已提交
863 864 865 866 867 868 869 870 871 872 873 874
 protected:
  // The following three functions are used to check if it is executed in this
  // order:
  //   Init() -> SetFileList() -> Start() -> Next()
  virtual void CheckInit();
  virtual void CheckSetFileList();
  virtual void CheckStart();
  virtual void SetBatchSize(
      int batch);  // batch size will be set in Init() function
  // This function is used to pick one file from the global filelist(thread
  // safe).
  virtual bool PickOneFile(std::string* filename);
875
  virtual void CopyToFeedTensor(void* dst, const void* src, size_t size);
W
Wang Guibao 已提交
876

877 878 879
  std::vector<std::string> filelist_;
  size_t* file_idx_;
  std::mutex* mutex_for_pick_file_;
H
hutuxian 已提交
880 881 882
  std::mutex* mutex_for_fea_num_ = nullptr;
  uint64_t* total_fea_num_ = nullptr;
  uint64_t fea_num_ = 0;
W
Wang Guibao 已提交
883 884 885 886 887 888 889 890 891 892

  // the alias of used slots, and its order is determined by
  // data_feed_desc(proto object)
  std::vector<std::string> use_slots_;
  std::vector<bool> use_slots_is_dense_;

  // the alias of all slots, and its order is determined by data_feed_desc(proto
  // object)
  std::vector<std::string> all_slots_;
  std::vector<std::string> all_slots_type_;
893
  std::vector<std::vector<int>> use_slots_shape_;
894 895
  std::vector<int> inductive_shape_index_;
  std::vector<int> total_dims_without_inductive_;
H
hutuxian 已提交
896 897
  // For the inductive shape passed within data
  std::vector<std::vector<int>> multi_inductive_shape_index_;
W
Wang Guibao 已提交
898 899 900 901
  std::vector<int>
      use_slots_index_;  // -1: not used; >=0: the index of use_slots_

  // The data read by DataFeed will be stored here
902
  std::vector<LoDTensor*> feed_vec_;
W
Wang Guibao 已提交
903

904 905
  LoDTensor* rank_offset_;

W
Wang Guibao 已提交
906 907 908 909 910 911
  // the batch size defined by user
  int default_batch_size_;
  // current batch size
  int batch_size_;

  bool finish_init_;
912
  bool finish_set_filelist_;
W
Wang Guibao 已提交
913
  bool finish_start_;
914
  std::string pipe_command_;
T
Thunderbrook 已提交
915 916
  std::string so_parser_name_;
  std::vector<SlotConf> slot_conf_;
917 918
  std::vector<std::string> ins_id_vec_;
  std::vector<std::string> ins_content_vec_;
919
  platform::Place place_;
920
  std::string uid_slot_;
921 922 923

  // The input type of pipe reader, 0 for one sample, 1 for one batch
  int input_type_;
W
Wang Guibao 已提交
924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944
};

// PrivateQueueDataFeed is the base virtual class for other DataFeeds.
// It uses a read thread to read files and parse data into a private
// (thread-level) queue, and takes data from this queue when the trainer calls
// Next().
template <typename T>
class PrivateQueueDataFeed : public DataFeed {
 public:
  PrivateQueueDataFeed() {}
  virtual ~PrivateQueueDataFeed() {}
  virtual bool Start();
  virtual int Next();

 protected:
  // The thread implementation function for reading file and parse.
  virtual void ReadThread();
  // This function is used to set private-queue size, and the most
  // efficient when the queue size is close to the batch size.
  virtual void SetQueueSize(int queue_size);
  // The reading and parsing method called in the ReadThread.
  virtual bool ParseOneInstance(T* instance) = 0;
D
dongdaxiang 已提交
945
  virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
W
Wang Guibao 已提交
946
  // This function is used to put instance to vec_ins
947 948
  virtual void AddInstanceToInsVec(T* vec_ins,
                                   const T& instance,
W
Wang Guibao 已提交
949 950 951 952 953 954 955 956 957 958 959 960
                                   int index) = 0;
  // This function is used to put ins_vec to feed_vec
  virtual void PutToFeedVec(const T& ins_vec) = 0;

  // The thread for read files
  std::thread read_thread_;
  // using ifstream one line and one line parse is faster
  // than using fread one buffer and one buffer parse.
  //   for a 601M real data:
  //     ifstream one line and one line parse: 6034 ms
  //     fread one buffer and one buffer parse: 7097 ms
  std::ifstream file_;
D
dongdaxiang 已提交
961
  std::shared_ptr<FILE> fp_;
W
Wang Guibao 已提交
962
  size_t queue_size_;
963
  string::LineFileReader reader_;
W
Wang Guibao 已提交
964
  // The queue for store parsed data
965
  std::shared_ptr<paddle::framework::ChannelObject<T>> queue_;
W
Wang Guibao 已提交
966 967
};

968
template <typename T>
J
jiaqi 已提交
969
class InMemoryDataFeed : public DataFeed {
970 971 972
 public:
  InMemoryDataFeed();
  virtual ~InMemoryDataFeed() {}
H
hutuxian 已提交
973
  virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
974 975
  virtual bool Start();
  virtual int Next();
976 977 978 979
  virtual void SetInputPvChannel(void* channel);
  virtual void SetOutputPvChannel(void* channel);
  virtual void SetConsumePvChannel(void* channel);

J
jiaqi 已提交
980 981 982
  virtual void SetInputChannel(void* channel);
  virtual void SetOutputChannel(void* channel);
  virtual void SetConsumeChannel(void* channel);
983 984
  virtual void SetThreadId(int thread_id);
  virtual void SetThreadNum(int thread_num);
985
  virtual void SetParseInsId(bool parse_ins_id);
986
  virtual void SetParseUid(bool parse_uid);
987
  virtual void SetParseContent(bool parse_content);
988 989 990
  virtual void SetParseLogKey(bool parse_logkey);
  virtual void SetEnablePvMerge(bool enable_pv_merge);
  virtual void SetCurrentPhase(int current_phase);
991
  virtual void LoadIntoMemory();
T
Thunderbrook 已提交
992
  virtual void LoadIntoMemoryFromSo();
Y
yaoxuefeng 已提交
993 994 995 996 997
  virtual void SetRecord(T* records) { records_ = records; }
  int GetDefaultBatchSize() { return default_batch_size_; }
  void AddBatchOffset(const std::pair<int, int>& offset) {
    batch_offsets_.push_back(offset);
  }
X
xujiaqi01 已提交
998

999 1000 1001
 protected:
  virtual bool ParseOneInstance(T* instance) = 0;
  virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
1002 1003
  virtual void ParseOneInstanceFromSo(const char* str,
                                      T* instance,
T
Thunderbrook 已提交
1004
                                      CustomParser* parser) {}
1005 1006
  virtual int ParseInstanceFromSo(int len,
                                  const char* str,
T
Thunderbrook 已提交
1007 1008 1009 1010
                                  std::vector<T>* instances,
                                  CustomParser* parser) {
    return 0;
  }
J
jiaqi 已提交
1011
  virtual void PutToFeedVec(const std::vector<T>& ins_vec) = 0;
Y
yaoxuefeng 已提交
1012
  virtual void PutToFeedVec(const T* ins_vec, int num) = 0;
1013

Y
yaoxuefeng 已提交
1014 1015 1016 1017 1018
  std::vector<std::vector<float>> batch_float_feasigns_;
  std::vector<std::vector<uint64_t>> batch_uint64_feasigns_;
  std::vector<std::vector<size_t>> offset_;
  std::vector<bool> visit_;

1019 1020
  int thread_id_;
  int thread_num_;
1021
  bool parse_ins_id_;
1022
  bool parse_uid_;
1023
  bool parse_content_;
1024 1025 1026
  bool parse_logkey_;
  bool enable_pv_merge_;
  int current_phase_{-1};  // only for untest
J
jiaqi 已提交
1027 1028 1029 1030 1031
  std::ifstream file_;
  std::shared_ptr<FILE> fp_;
  paddle::framework::ChannelObject<T>* input_channel_;
  paddle::framework::ChannelObject<T>* output_channel_;
  paddle::framework::ChannelObject<T>* consume_channel_;
1032 1033 1034 1035

  paddle::framework::ChannelObject<PvInstance>* input_pv_channel_;
  paddle::framework::ChannelObject<PvInstance>* output_pv_channel_;
  paddle::framework::ChannelObject<PvInstance>* consume_pv_channel_;
Y
yaoxuefeng 已提交
1036 1037 1038 1039 1040

  std::vector<std::pair<int, int>> batch_offsets_;
  uint64_t offset_index_ = 0;
  bool enable_heterps_ = false;
  T* records_ = nullptr;
1041 1042
};

W
Wang Guibao 已提交
1043 1044 1045 1046 1047
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
class MultiSlotType {
 public:
  MultiSlotType() {}
  ~MultiSlotType() {}
H
hutuxian 已提交
1048
  void Init(const std::string& type, size_t reserved_size = 0) {
W
Wang Guibao 已提交
1049 1050 1051
    CheckType(type);
    if (type_[0] == 'f') {
      float_feasign_.clear();
H
hutuxian 已提交
1052 1053 1054
      if (reserved_size) {
        float_feasign_.reserve(reserved_size);
      }
W
Wang Guibao 已提交
1055 1056
    } else if (type_[0] == 'u') {
      uint64_feasign_.clear();
H
hutuxian 已提交
1057 1058 1059
      if (reserved_size) {
        uint64_feasign_.reserve(reserved_size);
      }
W
Wang Guibao 已提交
1060 1061 1062
    }
    type_ = type;
  }
H
hutuxian 已提交
1063 1064 1065 1066
  void InitOffset(size_t max_batch_size = 0) {
    if (max_batch_size > 0) {
      offset_.reserve(max_batch_size + 1);
    }
W
Wang Guibao 已提交
1067 1068 1069 1070 1071 1072
    offset_.resize(1);
    // LoDTensor' lod is counted from 0, the size of lod
    // is one size larger than the size of data.
    offset_[0] = 0;
  }
  const std::vector<size_t>& GetOffset() const { return offset_; }
1073
  std::vector<size_t>& MutableOffset() { return offset_; }
W
Wang Guibao 已提交
1074 1075 1076 1077 1078 1079 1080 1081
  void AddValue(const float v) {
    CheckFloat();
    float_feasign_.push_back(v);
  }
  void AddValue(const uint64_t v) {
    CheckUint64();
    uint64_feasign_.push_back(v);
  }
H
hutuxian 已提交
1082 1083 1084 1085 1086 1087 1088 1089 1090 1091
  void CopyValues(const float* input, size_t size) {
    CheckFloat();
    float_feasign_.resize(size);
    memcpy(float_feasign_.data(), input, size * sizeof(float));
  }
  void CopyValues(const uint64_t* input, size_t size) {
    CheckUint64();
    uint64_feasign_.resize(size);
    memcpy(uint64_feasign_.data(), input, size * sizeof(uint64_t));
  }
W
Wang Guibao 已提交
1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104
  void AddIns(const MultiSlotType& ins) {
    if (ins.GetType()[0] == 'f') {  // float
      CheckFloat();
      auto& vec = ins.GetFloatData();
      offset_.push_back(offset_.back() + vec.size());
      float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end());
    } else if (ins.GetType()[0] == 'u') {  // uint64
      CheckUint64();
      auto& vec = ins.GetUint64Data();
      offset_.push_back(offset_.back() + vec.size());
      uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
    }
  }
H
hutuxian 已提交
1105 1106 1107 1108 1109 1110 1111 1112
  void AppendValues(const uint64_t* input, size_t size) {
    CheckUint64();
    offset_.push_back(offset_.back() + size);
    uint64_feasign_.insert(uint64_feasign_.end(), input, input + size);
  }
  void AppendValues(const float* input, size_t size) {
    CheckFloat();
    offset_.push_back(offset_.back() + size);
1113

H
hutuxian 已提交
1114 1115
    float_feasign_.insert(float_feasign_.end(), input, input + size);
  }
W
Wang Guibao 已提交
1116
  const std::vector<float>& GetFloatData() const { return float_feasign_; }
1117
  std::vector<float>& MutableFloatData() { return float_feasign_; }
W
Wang Guibao 已提交
1118
  const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
1119
  std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
W
Wang Guibao 已提交
1120
  const std::string& GetType() const { return type_; }
H
hutuxian 已提交
1121
  size_t GetBatchSize() { return offset_.size() - 1; }
1122
  std::string& MutableType() { return type_; }
W
Wang Guibao 已提交
1123

X
xujiaqi01 已提交
1124 1125
  std::string DebugString() {
    std::stringstream ss;
W
wanghuancoder 已提交
1126

1127 1128
    ss << "\ntype: " << type_ << "\n";
    ss << "offset: ";
X
xujiaqi01 已提交
1129 1130 1131 1132
    ss << "[";
    for (const size_t& i : offset_) {
      ss << offset_[i] << ",";
    }
1133
    ss << "]\ndata: [";
X
xujiaqi01 已提交
1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146
    if (type_[0] == 'f') {
      for (const float& i : float_feasign_) {
        ss << i << ",";
      }
    } else {
      for (const uint64_t& i : uint64_feasign_) {
        ss << i << ",";
      }
    }
    ss << "]\n";
    return ss.str();
  }

W
Wang Guibao 已提交
1147 1148
 private:
  void CheckType(const std::string& type) const {
1149 1150
    PADDLE_ENFORCE_EQ((type == "uint64" || type == "float"),
                      true,
1151 1152 1153 1154
                      platform::errors::InvalidArgument(
                          "MultiSlotType error, expect type is uint64 or "
                          "float, but received type is %s.",
                          type));
W
Wang Guibao 已提交
1155 1156
  }
  void CheckFloat() const {
1157
    PADDLE_ENFORCE_EQ(
1158 1159
        type_[0],
        'f',
1160 1161
        platform::errors::InvalidArgument(
            "MultiSlotType error, add %s value to float slot.", type_));
W
Wang Guibao 已提交
1162 1163
  }
  void CheckUint64() const {
1164
    PADDLE_ENFORCE_EQ(
1165 1166
        type_[0],
        'u',
1167 1168
        platform::errors::InvalidArgument(
            "MultiSlotType error, add %s value to uint64 slot.", type_));
W
Wang Guibao 已提交
1169 1170 1171 1172 1173 1174 1175
  }
  std::vector<float> float_feasign_;
  std::vector<uint64_t> uint64_feasign_;
  std::string type_;
  std::vector<size_t> offset_;
};

J
jiaqi 已提交
1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205
// Serializes a MultiSlotType as: type, offsets, float data, uint64 data.
// Outside _LINUX the offsets are written as an explicit uint64 count
// followed by uint64 values (presumably because size_t is not a supported
// archive element type on those platforms).
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const MultiSlotType& ins) {
  ar << ins.GetType();
#ifdef _LINUX
  ar << ins.GetOffset();
#else
  const auto& offset = ins.GetOffset();
  ar << static_cast<uint64_t>(offset.size());
  for (const size_t x : offset) {
    ar << static_cast<uint64_t>(x);
  }
#endif
  ar << ins.GetFloatData();
  ar << ins.GetUint64Data();
  return ar;
}

// Deserializes a MultiSlotType written by the operator<< above, in the same
// field order and with the same per-platform offset encoding.
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           MultiSlotType& ins) {
  ar >> ins.MutableType();
#ifdef _LINUX
  ar >> ins.MutableOffset();
#else
  auto& offset = ins.MutableOffset();
  offset.resize(ar.template Get<uint64_t>());
  for (size_t& x : offset) {
    uint64_t t = 0;
    ar >> t;
    x = static_cast<size_t>(t);
  }
#endif
  ar >> ins.MutableFloatData();
  ar >> ins.MutableUint64Data();
  return ar;
}

1214 1215
struct RecordCandidate {
  std::string ins_id_;
T
Thunderbrook 已提交
1216
  std::unordered_multimap<uint16_t, FeatureFeasign> feas_;
1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228
  size_t shadow_index_ = -1;  // Optimization for Reservoir Sample

  RecordCandidate() {}
  RecordCandidate(const Record& rec,
                  const std::unordered_set<uint16_t>& slot_index_to_replace) {
    for (const auto& fea : rec.uint64_feasigns_) {
      if (slot_index_to_replace.find(fea.slot()) !=
          slot_index_to_replace.end()) {
        feas_.insert({fea.slot(), fea.sign()});
      }
    }
  }
1229 1230

  RecordCandidate& operator=(const Record& rec) {
1231
    feas_.clear();
1232 1233
    ins_id_ = rec.ins_id_;
    for (auto& fea : rec.uint64_feasigns_) {
1234
      feas_.insert({fea.slot(), fea.sign()});
1235 1236 1237 1238 1239 1240 1241 1242
    }
    return *this;
  }
};

class RecordCandidateList {
 public:
  RecordCandidateList() = default;
1243
  RecordCandidateList(const RecordCandidateList&) {}
1244

1245
  size_t Size() { return cur_size_; }
1246 1247 1248
  void ReSize(size_t length);

  void ReInit();
1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260
  void ReInitPass() {
    for (size_t i = 0; i < cur_size_; ++i) {
      if (candidate_list_[i].shadow_index_ != i) {
        candidate_list_[i].ins_id_ =
            candidate_list_[candidate_list_[i].shadow_index_].ins_id_;
        candidate_list_[i].feas_.swap(
            candidate_list_[candidate_list_[i].shadow_index_].feas_);
        candidate_list_[i].shadow_index_ = i;
      }
    }
    candidate_list_.resize(cur_size_);
  }
1261 1262

  void AddAndGet(const Record& record, RecordCandidate* result);
1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284
  void AddAndGet(const Record& record, size_t& index_result) {  // NOLINT
    // std::unique_lock<std::mutex> lock(mutex_);
    size_t index = 0;
    ++total_size_;
    auto fleet_ptr = FleetWrapper::GetInstance();
    if (!full_) {
      candidate_list_.emplace_back(record, slot_index_to_replace_);
      candidate_list_.back().shadow_index_ = cur_size_;
      ++cur_size_;
      full_ = (cur_size_ == capacity_);
    } else {
      index = fleet_ptr->LocalRandomEngine()() % total_size_;
      if (index < capacity_) {
        candidate_list_.emplace_back(record, slot_index_to_replace_);
        candidate_list_[index].shadow_index_ = candidate_list_.size() - 1;
      }
    }
    index = fleet_ptr->LocalRandomEngine()() % cur_size_;
    index_result = candidate_list_[index].shadow_index_;
  }
  const RecordCandidate& Get(size_t index) const {
    PADDLE_ENFORCE_LT(
1285 1286
        index,
        candidate_list_.size(),
1287 1288
        platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
                                     "elements in candidate_list[%lu].",
1289 1290
                                     index,
                                     candidate_list_.size()));
1291 1292 1293 1294 1295 1296
    return candidate_list_[index];
  }
  void SetSlotIndexToReplace(
      const std::unordered_set<uint16_t>& slot_index_to_replace) {
    slot_index_to_replace_ = slot_index_to_replace;
  }
1297 1298

 private:
1299 1300 1301 1302 1303 1304 1305
  size_t capacity_ = 0;
  std::mutex mutex_;
  bool full_ = false;
  size_t cur_size_ = 0;
  size_t total_size_ = 0;
  std::vector<RecordCandidate> candidate_list_;
  std::unordered_set<uint16_t> slot_index_to_replace_;
1306 1307
};

J
jiaqi 已提交
1308 1309
// Archive I/O for FeatureFeasign: the uint64 member is transferred first,
// then the float member, in the same order in both directions.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const FeatureFeasign& fk) {
  return ar << fk.uint64_feasign_ << fk.float_feasign_;
}

template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           FeatureFeasign& fk) {
  return ar >> fk.uint64_feasign_ >> fk.float_feasign_;
}

// Archive I/O for FeatureItem: sign first, then slot.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const FeatureItem& fi) {
  return ar << fi.sign() << fi.slot();
}

template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           FeatureItem& fi) {
  return ar >> fi.sign() >> fi.slot();
}

// Archive I/O for Record: uint64 feasigns, float feasigns, then ins_id_.
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const Record& r) {
  return ar << r.uint64_feasigns_ << r.float_feasigns_ << r.ins_id_;
}

template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           Record& r) {
  return ar >> r.uint64_feasigns_ >> r.float_feasigns_ >> r.ins_id_;
}

W
Wang Guibao 已提交
1358 1359 1360 1361 1362 1363 1364 1365
// MultiSlotDataFeed feeds multi-slot data through a private reading queue;
// each queued instance is a std::vector<MultiSlotType>.
// Multi-slot input format:
//   [n feasign_0 feasign_1 ... feasign_n]*
class MultiSlotDataFeed
    : public PrivateQueueDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotDataFeed() {}
  virtual ~MultiSlotDataFeed() {}
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual bool CheckFile(const char* filename);

 protected:
  virtual void ReadThread();
  virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
                                   const std::vector<MultiSlotType>& instance,
                                   int index);
  virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
  virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
  virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
};
1378

J
jiaqi 已提交
1379
// MultiSlotInMemoryDataFeed is the InMemoryDataFeed specialization for
// Record instances.
class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
 public:
  MultiSlotInMemoryDataFeed() {}
  virtual ~MultiSlotInMemoryDataFeed() {}
  virtual void Init(const DataFeedDesc& data_feed_desc);
  // SetRecord is inherited unchanged from InMemoryDataFeed<Record>.

 protected:
  virtual bool ParseOneInstance(Record* instance);
  virtual bool ParseOneInstanceFromPipe(Record* instance);
  // No-op override (empty body).
  virtual void ParseOneInstanceFromSo(const char* str,
                                      Record* instance,
                                      CustomParser* parser) {}
  virtual int ParseInstanceFromSo(int len,
                                  const char* str,
                                  std::vector<Record>* instances,
                                  CustomParser* parser);
  virtual void PutToFeedVec(const std::vector<Record>& ins_vec);
  // Presumably decodes search_id / cmatch / rank from log_key — see the .cc.
  virtual void GetMsgFromLogKey(const std::string& log_key,
                                uint64_t* search_id,
                                uint32_t* cmatch,
                                uint32_t* rank);
  virtual void PutToFeedVec(const Record* ins_vec, int num);
};

Y
yaoxuefeng 已提交
1404 1405 1406
// SlotRecordInMemoryDataFeed is the InMemoryDataFeed specialization for
// SlotRecord instances; with CUDA + HETERPS it also carries helpers for
// assembling batches through a MiniBatchGpuPack.
class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
 public:
  SlotRecordInMemoryDataFeed() {}
  virtual ~SlotRecordInMemoryDataFeed() {
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
    // Only clears the pointer; nothing is released here.
    // NOTE(review): presumably pack_ is owned elsewhere — confirm.
    pack_ = nullptr;
#endif
  }
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual void LoadIntoMemory();
  void ExpandSlotRecord(SlotRecord* ins);

 protected:
  virtual bool Start();
  virtual int Next();
  // The base-class parsing hooks are stubs for SlotRecord: always false.
  virtual bool ParseOneInstance(SlotRecord* instance) { return false; }
  virtual bool ParseOneInstanceFromPipe(SlotRecord* instance) { return false; }
  // No-op for SlotRecord.
  virtual void PutToFeedVec(const std::vector<SlotRecord>& ins_vec) {}

  // Loading strategies (defined out of line).
  virtual void LoadIntoMemoryByCommand();
  virtual void LoadIntoMemoryByLib();
  virtual void LoadIntoMemoryByLine();
  virtual void LoadIntoMemoryByFile();
  virtual void SetInputChannel(void* channel) {
    input_channel_ = static_cast<ChannelObject<SlotRecord>*>(channel);
  }
  // Parses one text line into *rec.
  bool ParseOneInstance(const std::string& line, SlotRecord* rec);
  virtual void PutToFeedVec(const SlotRecord* ins_vec, int num);
  virtual void AssignFeedVar(const Scope& scope);
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
  void BuildSlotBatchGPU(const int ins_num);
  void FillSlotValueOffset(const int ins_num,
                           const int used_slot_num,
                           size_t* slot_value_offsets,
                           const int* uint64_offsets,
                           const int uint64_slot_size,
                           const int* float_offsets,
                           const int float_slot_size,
                           const UsedSlotGpuType* used_slots);
  void CopyForTensor(const int ins_num,
                     const int used_slot_num,
                     void** dest,
                     const size_t* slot_value_offsets,
                     const uint64_t* uint64_feas,
                     const int* uint64_offsets,
                     const int* uint64_ins_lens,
                     const int uint64_slot_size,
                     const float* float_feas,
                     const int* float_offsets,
                     const int* float_ins_lens,
                     const int float_slot_size,
                     const UsedSlotGpuType* used_slots);
#endif
  float sample_rate_ = 1.0f;
  int use_slot_size_ = 0;
  int float_use_slot_size_ = 0;
  int uint64_use_slot_size_ = 0;
  std::vector<AllSlotInfo> all_slots_info_;
  std::vector<UsedSlotInfo> used_slots_info_;
  size_t float_total_dims_size_ = 0;
  std::vector<int> float_total_dims_without_inductives_;

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
  MiniBatchGpuPack* pack_ = nullptr;
#endif
};

1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491
// PaddleBoxDataFeed extends MultiSlotInMemoryDataFeed with PvInstance-based
// batching and rank-offset generation (GetRankOffset).
class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
 public:
  PaddleBoxDataFeed() {}
  virtual ~PaddleBoxDataFeed() {}

 protected:
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual bool Start();
  virtual int Next();
  virtual void AssignFeedVar(const Scope& scope);
  virtual void PutToFeedVec(const std::vector<PvInstance>& pv_vec);
  virtual void PutToFeedVec(const std::vector<Record*>& ins_vec);
  virtual int GetCurrentPhase();
  virtual void GetRankOffset(const std::vector<PvInstance>& pv_vec,
                             int ins_number);
  // Presumably the name of the variable that receives rank offsets.
  std::string rank_offset_name_;
  // Zero until configured (presumably by Init) instead of indeterminate.
  int pv_batch_size_ = 0;
};

1494
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
H
hutuxian 已提交
1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542
// PrivateInstantDataFeed is the abstract base for feeds that parse whole
// mini-batches (ParseOneMiniBatch) into ins_vec_. Subclasses supply per-file
// setup/teardown via Preprocess/Postprocess.
template <typename T>
class PrivateInstantDataFeed : public DataFeed {
 public:
  PrivateInstantDataFeed() {}
  virtual ~PrivateInstantDataFeed() {}
  void Init(const DataFeedDesc& data_feed_desc) override;
  // No up-front work is needed; always reports success.
  bool Start() override { return true; }
  int Next() override;

 protected:
  // Buffer for the batched data.
  std::vector<MultiSlotType> ins_vec_;

  // Per-file setup for `filename`, e.g. opening or mmap-ing it.
  virtual bool Preprocess(const std::string& filename) = 0;

  // Releases per-file resources such as an open file.
  // NOTICE: must be safe to call before Preprocess.
  virtual bool Postprocess() = 0;

  // Reads and parses one mini-batch.
  virtual bool ParseOneMiniBatch() = 0;

  // Moves ins_vec_ into feed_vec_.
  virtual void PutToFeedVec();
};

// MultiSlotFileInstantDataFeed is the file-backed PrivateInstantDataFeed for
// multi-slot data.
class MultiSlotFileInstantDataFeed
    : public PrivateInstantDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotFileInstantDataFeed() {}
  virtual ~MultiSlotFileInstantDataFeed() {}

 protected:
  // Descriptor of the current file; -1 presumably means "none open".
  int fd_ = -1;
  // Data buffer for the current file (presumably mapped by Preprocess).
  char* buffer_ = nullptr;
  // Buffer end and current parse position (assumed byte offsets).
  size_t end_ = 0;
  size_t offset_ = 0;

  bool Preprocess(const std::string& filename) override;

  bool Postprocess() override;

  bool ParseOneMiniBatch() override;
};
#endif

W
Wang Guibao 已提交
1543 1544
}  // namespace framework
}  // namespace paddle