data_feed.cc 37.2 KB
Newer Older
W
Wang Guibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

D
dongdaxiang 已提交
15 16 17 18 19
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif

20
#include "paddle/fluid/framework/data_feed.h"
D
dongdaxiang 已提交
21
#ifdef _LINUX
D
dongdaxiang 已提交
22
#include <stdio_ext.h>
H
hutuxian 已提交
23 24 25
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
D
dongdaxiang 已提交
26
#endif
27
#include <utility>
28
#include "gflags/gflags.h"
W
Wang Guibao 已提交
29 30 31
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
32 33
#include "io/fs.h"
#include "io/shell.h"
W
Wang Guibao 已提交
34 35
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
36
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
37
#include "paddle/fluid/platform/timer.h"
W
Wang Guibao 已提交
38 39 40 41

namespace paddle {
namespace framework {

42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
void RecordCandidateList::ReSize(size_t length) {
  _mutex.lock();
  _capacity = length;
  CHECK(_capacity > 0);  // NOLINT
  _candidate_list.clear();
  _candidate_list.resize(_capacity);
  _full = false;
  _cur_size = 0;
  _total_size = 0;
  _mutex.unlock();
}

void RecordCandidateList::ReInit() {
  _mutex.lock();
  _full = false;
  _cur_size = 0;
  _total_size = 0;
  _mutex.unlock();
}

void RecordCandidateList::AddAndGet(const Record& record,
                                    RecordCandidate* result) {
  _mutex.lock();
  size_t index = 0;
  ++_total_size;
  auto fleet_ptr = FleetWrapper::GetInstance();
  if (!_full) {
    _candidate_list[_cur_size++] = record;
    _full = (_cur_size == _capacity);
  } else {
    CHECK(_cur_size == _capacity);
    index = fleet_ptr->LocalRandomEngine()() % _total_size;
    if (index < _capacity) {
      _candidate_list[index] = record;
    }
  }
  index = fleet_ptr->LocalRandomEngine()() % _cur_size;
  *result = _candidate_list[index];
  _mutex.unlock();
}

W
Wang Guibao 已提交
83 84 85 86
// Binds every used slot named `name` to the LoDTensor inside `var`, or
// detaches it (nullptr) when `var` is null. Names that match no used slot are
// silently ignored.
void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
  CheckInit();
  for (size_t slot = 0; slot < use_slots_.size(); ++slot) {
    if (use_slots_[slot] != name) {
      continue;
    }
    feed_vec_[slot] = (var != nullptr) ? var->GetMutable<LoDTensor>() : nullptr;
  }
}

bool DataFeed::SetFileList(const std::vector<std::string>& files) {
97
  std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
W
Wang Guibao 已提交
98
  CheckInit();
99 100
  // Do not set finish_set_filelist_ flag,
  // since a user may set file many times after init reader
W
Wang Guibao 已提交
101 102 103 104 105 106 107 108 109 110 111 112
  filelist_.assign(files.begin(), files.end());

  finish_set_filelist_ = true;
  return true;
}

// Sets the maximum number of instances Next() gathers per batch.
// Throws via PADDLE_ENFORCE when `batch_size` is not positive.
void DataFeed::SetBatchSize(int batch_size) {
  PADDLE_ENFORCE(batch_size > 0, "Illegal batch size: %d.", batch_size);
  default_batch_size_ = batch_size;
}

bool DataFeed::PickOneFile(std::string* filename) {
113 114 115 116 117 118
  PADDLE_ENFORCE(mutex_for_pick_file_ != nullptr,
                 "should call SetFileListMutex before PickOneFile");
  PADDLE_ENFORCE(file_idx_ != nullptr,
                 "should call SetFileListIndex before PickOneFile");
  std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
  if (*file_idx_ == filelist_.size()) {
119
    VLOG(3) << "DataFeed::PickOneFile no more file to pick";
W
Wang Guibao 已提交
120 121
    return false;
  }
122 123
  VLOG(3) << "file_idx_=" << *file_idx_;
  *filename = filelist_[(*file_idx_)++];
W
Wang Guibao 已提交
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
  return true;
}

// Throws via PADDLE_ENFORCE unless `finish_init_` is set (Init() sets it as
// its last step).
void DataFeed::CheckInit() {
  PADDLE_ENFORCE(finish_init_, "Initialization did not succeed.");
}

// Throws via PADDLE_ENFORCE unless `finish_set_filelist_` is set (SetFileList()
// sets it).
void DataFeed::CheckSetFileList() {
  PADDLE_ENFORCE(finish_set_filelist_, "Set filelist did not succeed.");
}

// Throws via PADDLE_ENFORCE unless `finish_start_` is set (the Start()
// implementations set it).
void DataFeed::CheckStart() {
  PADDLE_ENFORCE(finish_start_, "Datafeed has not started running yet.");
}

H
hutuxian 已提交
139 140 141 142 143 144 145
void DataFeed::AssignFeedVar(const Scope& scope) {
  CheckInit();
  for (size_t i = 0; i < use_slots_.size(); ++i) {
    feed_vec_[i] = scope.FindVar(use_slots_[i])->GetMutable<LoDTensor>();
  }
}

146 147 148 149 150 151 152 153 154 155 156 157
// Copies `size` bytes from host memory `src` into the feed tensor buffer
// `dst`. It uses memcpy when `place_` is a CPU place and a host-to-device
// cudaMemcpy otherwise. Without CUDA support, a non-CPU place throws.
//
// The cudaMemcpy status is checked so a failed copy surfaces as an error
// instead of silently feeding garbage.
void DataFeed::CopyToFeedTensor(void* dst, const void* src, size_t size) {
  if (platform::is_cpu_place(this->place_)) {
    memcpy(dst, src, size);
  } else {
#ifdef PADDLE_WITH_CUDA
    cudaError_t err = cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
    PADDLE_ENFORCE(err == cudaSuccess, "cudaMemcpy to feed tensor failed: %s",
                   cudaGetErrorString(err));
#else
    PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option");
#endif
  }
}

W
Wang Guibao 已提交
158 159 160 161
// Validates and stores the queue capacity, then installs a newly made channel
// whose capacity is `queue_size`, replacing the previous `queue_`.
template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
  PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size);
  queue_size_ = queue_size;
  auto channel = paddle::framework::MakeChannel<T>();
  channel->SetCapacity(queue_size);
  queue_ = std::move(channel);
}

// Requires a file list, spawns ReadThread() on a detached thread and marks
// the feed as started. Always returns true.
// NOTE(review): the detached thread captures `this`; the feed must outlive it.
template <typename T>
bool PrivateQueueDataFeed<T>::Start() {
  CheckSetFileList();
  auto reader_entry = [this] { this->ReadThread(); };
  read_thread_ = std::thread(reader_entry);
  read_thread_.detach();
  finish_start_ = true;
  return true;
}

// Producer loop: for every file obtained from PickOneFile(), opens it through
// `pipe_command_` and puts each parsed instance into `queue_`, then closes
// `queue_` once no file is left. Does nothing on non-Linux builds.
template <typename T>
void PrivateQueueDataFeed<T>::ReadThread() {
#ifdef _LINUX
  std::string filename;
  while (PickOneFile(&filename)) {
    int err_no = 0;
    fp_ = fs_open_read(filename, &err_no, pipe_command_);
    // Same guard as MultiSlotDataFeed::ReadThread: a failed open must not be
    // dereferenced by __fsetlocking below.
    CHECK(fp_ != nullptr);
    __fsetlocking(&*fp_, FSETLOCKING_BYCALLER);
    T instance;
    while (ParseOneInstanceFromPipe(&instance)) {
      queue_->Put(instance);
    }
  }
  queue_->Close();
#endif
}

template <typename T>
int PrivateQueueDataFeed<T>::Next() {
X
xjqbest 已提交
195
#ifdef _LINUX
W
Wang Guibao 已提交
196 197 198 199
  CheckStart();
  int index = 0;
  T ins_vec;
  while (index < default_batch_size_) {
200 201
    T instance;
    if (!queue_->Get(instance)) {
W
Wang Guibao 已提交
202 203 204 205 206 207 208 209 210
      break;
    }
    AddInstanceToInsVec(&ins_vec, instance, index++);
  }
  batch_size_ = index;
  if (batch_size_ != 0) {
    PutToFeedVec(ins_vec);
  }
  return batch_size_;
X
xjqbest 已提交
211 212 213
#else
  return 0;
#endif
W
Wang Guibao 已提交
214 215
}

216
// explicit instantiation
W
Wang Guibao 已提交
217 218
template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;

219 220
// Leaves every reader-state member unconfigured: no file index/mutex, no open
// file, no channels, a single thread with id 0, and neither the ins_id nor
// the content field parsed. These are set later through the Set* methods.
template <typename T>
InMemoryDataFeed<T>::InMemoryDataFeed() {
  // File-list state.
  this->mutex_for_pick_file_ = nullptr;
  this->file_idx_ = nullptr;
  this->fp_ = nullptr;
  // Threading.
  this->thread_num_ = 1;
  this->thread_id_ = 0;
  // Optional leading fields of each line.
  this->parse_content_ = false;
  this->parse_ins_id_ = false;
  // Channels.
  this->consume_channel_ = nullptr;
  this->output_channel_ = nullptr;
  this->input_channel_ = nullptr;
}

// Prepares the feed for Next(). On Linux it requires a file list and, when
// the output channel is empty but the input channel holds data, moves all of
// the input channel's contents into the output channel. It always marks the
// feed as started and returns true.
//
// The channels are CHECKed (as Next() does) instead of being dereferenced
// while still unset.
template <typename T>
bool InMemoryDataFeed<T>::Start() {
#ifdef _LINUX
  this->CheckSetFileList();
  CHECK(input_channel_ != nullptr);
  CHECK(output_channel_ != nullptr);
  if (output_channel_->Size() == 0 && input_channel_->Size() != 0) {
    std::vector<T> data;
    input_channel_->Read(data);
    output_channel_->Write(std::move(data));
  }
#endif
  this->finish_start_ = true;
  return true;
}

// Takes up to default_batch_size_ instances from `output_channel_`, forwards
// each one to `consume_channel_`, and publishes the batch via PutToFeedVec().
// Returns the number of instances taken; 0 means the output channel is
// exhausted (and always 0 on non-Linux builds).
template <typename T>
int InMemoryDataFeed<T>::Next() {
#ifdef _LINUX
  this->CheckStart();
  CHECK(output_channel_ != nullptr);
  CHECK(consume_channel_ != nullptr);
  VLOG(3) << "output_channel_ size=" << output_channel_->Size()
          << ", consume_channel_ size=" << consume_channel_->Size()
          << ", thread_id=" << thread_id_;
  int index = 0;
  T instance;
  std::vector<T> ins_vec;
  ins_vec.reserve(this->default_batch_size_);
  while (index < this->default_batch_size_) {
    // The Size() check is kept (presumably so Get() is never called on an
    // empty channel). Get()'s result is still checked in case the channel
    // was drained or closed in between; otherwise a stale or moved-from
    // `instance` would be fed.
    if (output_channel_->Size() == 0 || !output_channel_->Get(instance)) {
      break;
    }
    ins_vec.push_back(instance);
    ++index;
    consume_channel_->Put(std::move(instance));
  }
  this->batch_size_ = index;
  VLOG(3) << "batch_size_=" << this->batch_size_
          << ", thread_id=" << thread_id_;
  if (this->batch_size_ != 0) {
    PutToFeedVec(ins_vec);
  } else {
    VLOG(3) << "finish reading, output_channel_ size="
            << output_channel_->Size()
            << ", consume_channel_ size=" << consume_channel_->Size()
            << ", thread_id=" << thread_id_;
  }
  return this->batch_size_;
#else
  return 0;
#endif
}

286
template <typename T>
J
jiaqi 已提交
287 288 289 290 291 292 293
void InMemoryDataFeed<T>::SetInputChannel(void* channel) {
  input_channel_ = static_cast<paddle::framework::ChannelObject<T>*>(channel);
}

// Installs the channel that Next() reads batches from. `channel` is a
// type-erased ChannelObject<T>*; it is cast back without any check.
template <typename T>
void InMemoryDataFeed<T>::SetOutputChannel(void* channel) {
  auto* typed_channel =
      static_cast<paddle::framework::ChannelObject<T>*>(channel);
  output_channel_ = typed_channel;
}

template <typename T>
J
jiaqi 已提交
297 298
void InMemoryDataFeed<T>::SetConsumeChannel(void* channel) {
  consume_channel_ = static_cast<paddle::framework::ChannelObject<T>*>(channel);
299 300 301 302 303 304 305 306 307 308 309 310
}

// Records this reader's thread id (used in the VLOG diagnostics).
template <typename T>
void InMemoryDataFeed<T>::SetThreadId(int thread_id) {
  this->thread_id_ = thread_id;
}

// Records the total number of reader threads.
template <typename T>
void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
  this->thread_num_ = thread_num;
}

311 312 313 314 315
// Enables/disables parsing of the leading "content" field of each line.
template <typename T>
void InMemoryDataFeed<T>::SetParseContent(bool parse_content) {
  this->parse_content_ = parse_content;
}

316 317 318 319 320
// Enables/disables parsing of the leading instance-id field of each line.
template <typename T>
void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
  this->parse_ins_id_ = parse_ins_id;
}

321 322
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
D
dongdaxiang 已提交
323
#ifdef _LINUX
X
xujiaqi01 已提交
324
  VLOG(3) << "LoadIntoMemory() begin, thread_id=" << thread_id_;
325
  std::string filename;
J
jiaqi 已提交
326
  while (this->PickOneFile(&filename)) {
X
xujiaqi01 已提交
327 328
    VLOG(3) << "PickOneFile, filename=" << filename
            << ", thread_id=" << thread_id_;
329
    int err_no = 0;
J
jiaqi 已提交
330 331 332 333
    this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_);
    CHECK(this->fp_ != nullptr);
    __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);
    paddle::framework::ChannelWriter<T> writer(input_channel_);
334
    T instance;
335 336
    platform::Timer timeline;
    timeline.Start();
D
dongdaxiang 已提交
337
    while (ParseOneInstanceFromPipe(&instance)) {
J
jiaqi 已提交
338 339
      writer << std::move(instance);
      instance = T();
340
    }
J
jiaqi 已提交
341
    writer.Flush();
342
    timeline.Pause();
343 344
    VLOG(3) << "LoadIntoMemory() read all lines, file=" << filename
            << ", cost time=" << timeline.ElapsedSec()
345
            << " seconds, thread_id=" << thread_id_;
346
  }
X
xujiaqi01 已提交
347
  VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_;
D
dongdaxiang 已提交
348
#endif
349 350
}

351
// explicit instantiation
J
jiaqi 已提交
352
template class InMemoryDataFeed<Record>;
353

W
Wang Guibao 已提交
354 355 356 357 358 359 360 361 362 363 364
// Configures the reader from `data_feed_desc`:
//  - batch size, and a queue size of 100 batches;
//  - per-slot name/type tables and the all-slot -> used-slot index map;
//  - the used slots: names, density, and declared shapes;
//  - for dense used slots, the product of positive dims and the position of
//    the single -1 ("inductive") dim, which PutToFeedVec() fills in at
//    runtime;
//  - the pipe command.
// Sets finish_init_ only after everything succeeded.
void MultiSlotDataFeed::Init(
    const paddle::framework::DataFeedDesc& data_feed_desc) {
  finish_init_ = false;
  finish_set_filelist_ = false;
  finish_start_ = false;

  PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
                 "Multi_slot_desc has not been set.");
  paddle::framework::MultiSlotDesc multi_slot_desc =
      data_feed_desc.multi_slot_desc();
  SetBatchSize(data_feed_desc.batch_size());
  // temporarily set queue size = batch size * 100
  SetQueueSize(data_feed_desc.batch_size() * 100);
  size_t all_slot_num = multi_slot_desc.slots_size();
  all_slots_.resize(all_slot_num);
  all_slots_type_.resize(all_slot_num);
  use_slots_index_.resize(all_slot_num);
  // The dense-shape tables are indexed by *used*-slot position. That is how
  // PutToFeedVec() reads them, alongside use_slots_is_dense_ and
  // use_slots_shape_. Indexing them by all-slot position mismatched as soon
  // as an unused slot preceded a dense one. Unwritten tail entries keep
  // neutral values.
  total_dims_without_inductive_.assign(all_slot_num, 1);
  inductive_shape_index_.assign(all_slot_num, -1);
  use_slots_.clear();
  use_slots_is_dense_.clear();
  // Cleared too, so re-initialising does not leave stale shapes ahead of the
  // new ones.
  use_slots_shape_.clear();
  for (size_t i = 0; i < all_slot_num; ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    all_slots_[i] = slot.name();
    all_slots_type_[i] = slot.type();
    use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
    if (!slot.is_used()) {
      continue;
    }
    const size_t use_idx = use_slots_.size();
    use_slots_.push_back(all_slots_[i]);
    use_slots_is_dense_.push_back(slot.is_dense());
    std::vector<int> local_shape;
    local_shape.reserve(slot.shape_size());
    for (int j = 0; j < slot.shape_size(); ++j) {
      const int dim = slot.shape(j);
      local_shape.push_back(dim);
      if (!slot.is_dense()) {
        continue;
      }
      if (dim > 0) {
        total_dims_without_inductive_[use_idx] *= dim;
      }
      if (dim == -1) {
        inductive_shape_index_[use_idx] = j;
      }
    }
    use_slots_shape_.push_back(std::move(local_shape));
  }
  feed_vec_.resize(use_slots_.size());
  pipe_command_ = data_feed_desc.pipe_command();
  finish_init_ = true;
}

D
dongdaxiang 已提交
407
// Producer loop: opens every file from PickOneFile() through
// `pipe_command_`, puts each parsed instance into `queue_`, and logs the
// per-file instance count. It closes `queue_` when no file is left. Does
// nothing on non-Linux builds.
void MultiSlotDataFeed::ReadThread() {
#ifdef _LINUX
  std::string filename;
  while (PickOneFile(&filename)) {
    int err_no = 0;
    fp_ = fs_open_read(filename, &err_no, pipe_command_);
    CHECK(fp_ != nullptr);
    __fsetlocking(&*fp_, FSETLOCKING_BYCALLER);
    std::vector<MultiSlotType> instance;
    int parsed = 0;
    for (; ParseOneInstanceFromPipe(&instance); ++parsed) {
      queue_->Put(instance);
    }
    VLOG(3) << "filename: " << filename << " inst num: " << parsed;
  }
  queue_->Close();
#endif
}

W
Wang Guibao 已提交
427
// Validates `filename` against the MultiSlot text format for the slots
// configured in Init(). Each line must hold, for every slot in order, a
// positive id count followed by that many values of the slot's type
// ("float" or "uint64"). Only whitespace may follow the last slot.
//
// Returns true when every line passes (and unconditionally on non-Linux
// builds). Otherwise it logs the first problem together with its line number
// and returns false.
bool MultiSlotDataFeed::CheckFile(const char* filename) {
#ifdef _LINUX
  CheckInit();  // get info of slots
  std::ifstream fin(filename);
  if (!fin.good()) {
    VLOG(1) << "error: open file<" << filename << "> fail";
    return false;
  }
  std::string line;
  int instance_cout = 0;
  // Logs where the offending line is; shared by every error path below.
  auto report_line = [&]() {
    VLOG(0) << "please check line<" << instance_cout << "> in file<"
            << filename << ">";
  };
  std::string all_slots_alias = "";
  for (const auto& alias : all_slots_) {
    all_slots_alias += alias + " ";
  }
  std::string use_slots_alias = "";
  for (const auto& alias : use_slots_) {
    use_slots_alias += alias + " ";
  }
  VLOG(3) << "total slots num: " << all_slots_.size();
  VLOG(3) << "total slots alias: " << all_slots_alias;
  VLOG(3) << "used slots num: " << use_slots_.size();
  VLOG(3) << "used slots alias: " << use_slots_alias;
  while (getline(fin, line)) {
    ++instance_cout;
    const char* str = line.c_str();
    char* endptr = const_cast<char*>(str);
    int len = line.length();
    for (size_t i = 0; i < all_slots_.size(); ++i) {
      // strto* only ever *set* errno, so it has to be cleared before each call;
      // otherwise a stale ERANGE is reported against a perfectly valid value.
      errno = 0;
      auto num = strtol(endptr, &endptr, 10);
      if (num < 0) {
        VLOG(0) << "error: the number of ids is a negative number: " << num;
        report_line();
        return false;
      } else if (num == 0) {
        VLOG(0)
            << "error: the number of ids can not be zero, you need "
               "padding it in data generator; or if there is something wrong"
               " with the data, please check if the data contains unresolvable "
               "characters.";
        report_line();
        return false;
      } else if (errno == ERANGE || num > INT_MAX) {
        VLOG(0) << "error: the number of ids greater than INT_MAX";
        report_line();
        return false;
      }
      if (all_slots_type_[i] == "float") {
        for (int j = 0; j < num; ++j) {
          errno = 0;
          strtof(endptr, &endptr);
          if (errno == ERANGE) {
            VLOG(0) << "error: the value is out of the range of "
                       "representable values for float";
            report_line();
            return false;
          }
          if (j + 1 != num && endptr - str == len) {
            VLOG(0) << "error: there is a wrong with the number of ids.";
            report_line();
            return false;
          }
        }
      } else if (all_slots_type_[i] == "uint64") {
        for (int j = 0; j < num; ++j) {
          errno = 0;
          strtoull(endptr, &endptr, 10);
          if (errno == ERANGE) {
            VLOG(0) << "error: the value is out of the range of "
                       "representable values for uint64_t";
            report_line();
            return false;
          }
          if (j + 1 != num && endptr - str == len) {
            VLOG(0) << "error: there is a wrong with the number of ids.";
            report_line();
            return false;
          }
        }
      } else {
        VLOG(0) << "error: this type<" << all_slots_type_[i]
                << "> is not supported";
        return false;
      }
    }
    // Trailing whitespace is tolerated: when a Hadoop reduce task outputs a
    // single field, Hadoop appends a '\t' to the line by default (disable
    // with `-D mapred.textoutputformat.ignoreseparator=true`). That does not
    // affect the data, so only non-space trailing characters are an error.
    while (endptr - str != len) {
      // Cast to unsigned char: passing a negative char to isspace is UB.
      if (!isspace(static_cast<unsigned char>(*(endptr++)))) {
        VLOG(0)
            << "error: there is some extra characters at the end of the line.";
        report_line();
        return false;
      }
    }
  }
  VLOG(3) << "instances cout: " << instance_cout;
  VLOG(3) << "The file format is correct";
#endif
  return true;
}

D
dongdaxiang 已提交
540 541
// Reads the next line from the pipe behind `fp_` and parses it into
// `instance`, which holds one MultiSlotType per used slot, placed via
// use_slots_index_. Unused slots are skipped token by token. Returns false
// when no further line can be read; non-Linux builds always return false.
bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
    std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
  thread_local string::LineFileReader reader;

  if (!reader.getline(&*(fp_.get()))) {
    return false;
  }
  int use_slots_num = use_slots_.size();
  instance->resize(use_slots_num);

  const char* str = reader.get();
  char* endptr = const_cast<char*>(str);
  int pos = 0;
  for (size_t i = 0; i < use_slots_index_.size(); ++i) {
    int idx = use_slots_index_[i];
    int num = strtol(&str[pos], &endptr, 10);
    PADDLE_ENFORCE(
        num,
        "The number of ids can not be zero, you need padding "
        "it in data generator; or if there is something wrong with "
        "the data, please check if the data contains unresolvable "
        "characters.\nplease check this error line: %s",
        str);
    if (idx != -1) {
      (*instance)[idx].Init(all_slots_type_[i]);
      if ((*instance)[idx].GetType()[0] == 'f') {  // float
        for (int j = 0; j < num; ++j) {
          float feasign = strtof(endptr, &endptr);
          (*instance)[idx].AddValue(feasign);
        }
      } else if ((*instance)[idx].GetType()[0] == 'u') {  // uint64
        for (int j = 0; j < num; ++j) {
          uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
          (*instance)[idx].AddValue(feasign);
        }
      }
      pos = endptr - str;
    } else {
      // Skip the count token plus its `num` values. Each pass moves `pos` to
      // the next ' ' after it, like line.find_first_of(' ', pos + 1). It
      // stops at the terminating '\0' instead of scanning past the buffer.
      // The previous `while (line[pos + 1] != ' ')` never got past the first
      // token.
      for (int j = 0; j <= num && str[pos] != '\0'; ++j) {
        ++pos;
        while (str[pos] != ' ' && str[pos] != '\0') {
          ++pos;
        }
      }
    }
  }
  return true;
#else
  // Nothing is parsed here, so do not report an instance (matches
  // MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe).
  return false;
#endif
}

W
Wang Guibao 已提交
596
// Reads the next line from `file_` and parses it into `instance`, which holds
// one MultiSlotType per used slot, placed via use_slots_index_. Returns true
// when a line was parsed and false at end of input (always false on
// non-Linux builds).
bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
  std::string line;
  if (!getline(file_, line)) {
    return false;
  }
  int use_slots_num = use_slots_.size();
  instance->resize(use_slots_num);
  // parse line
  const char* str = line.c_str();
  char* endptr = const_cast<char*>(str);
  int pos = 0;
  for (size_t i = 0; i < use_slots_index_.size(); ++i) {
    int idx = use_slots_index_[i];
    int num = strtol(&str[pos], &endptr, 10);
    PADDLE_ENFORCE(
        num,
        "The number of ids can not be zero, you need padding "
        "it in data generator; or if there is something wrong with "
        "the data, please check if the data contains unresolvable "
        "characters.\nplease check this error line: %s",
        str);

    if (idx != -1) {
      (*instance)[idx].Init(all_slots_type_[i]);
      if ((*instance)[idx].GetType()[0] == 'f') {  // float
        for (int j = 0; j < num; ++j) {
          float feasign = strtof(endptr, &endptr);
          (*instance)[idx].AddValue(feasign);
        }
      } else if ((*instance)[idx].GetType()[0] == 'u') {  // uint64
        for (int j = 0; j < num; ++j) {
          uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
          (*instance)[idx].AddValue(feasign);
        }
      }
      pos = endptr - str;
    } else {
      // Skip the count token plus its `num` values (each pass moves `pos` to
      // the next ' '). This stops at '\0' instead of storing find_first_of's
      // npos into an int, where it would wrap to -1.
      for (int j = 0; j <= num && str[pos] != '\0'; ++j) {
        ++pos;
        while (str[pos] != ' ' && str[pos] != '\0') {
          ++pos;
        }
      }
    }
  }
  // A line was read and parsed. The old code fell through to `return false`
  // here, reporting every successful parse as end of input.
  return true;
#else
  return false;
#endif
}

// Appends `instance` to the batch being assembled in `*ins_vec`.
// For the first instance of a batch (index == 0), the batch is resized to one
// entry per slot and each entry is re-initialised with the instance's type
// (Init) and fresh offsets (InitOffset). Does nothing on non-Linux builds.
void MultiSlotDataFeed::AddInstanceToInsVec(
    std::vector<MultiSlotType>* ins_vec,
    const std::vector<MultiSlotType>& instance, int index) {
#ifdef _LINUX
  auto& batch = *ins_vec;
  const size_t slot_count = instance.size();
  if (index == 0) {
    batch.resize(slot_count);
    for (size_t s = 0; s < slot_count; ++s) {
      batch[s].Init(instance[s].GetType());
      batch[s].InitOffset();
    }
  }
  for (size_t s = 0; s < slot_count; ++s) {
    batch[s].AddIns(instance[s]);
  }
#endif
}

// Copies the assembled batch into the bound feed tensors, one per used slot;
// slots with no bound tensor are skipped. Each tensor is shaped
// {total_instance, 1} (float as float, uint64 stored as int64) and gets the
// batch offsets as its LoD. Dense slots are then resized to their declared
// shape, with the -1 ("inductive") dim computed from the element count.
void MultiSlotDataFeed::PutToFeedVec(
    const std::vector<MultiSlotType>& ins_vec) {
#ifdef _LINUX
  for (size_t i = 0; i < use_slots_.size(); ++i) {
    if (feed_vec_[i] == nullptr) {
      continue;
    }
    const auto& type = ins_vec[i].GetType();
    const auto& offset = ins_vec[i].GetOffset();
    int total_instance = static_cast<int>(offset.back());

    // feasign.data() instead of &feasign[0] (UB on an empty vector), and no
    // copy at all when there is nothing to copy.
    if (type[0] == 'f') {  // float
      const auto& feasign = ins_vec[i].GetFloatData();
      float* tensor_ptr =
          feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
      if (total_instance > 0) {
        CopyToFeedTensor(tensor_ptr, feasign.data(),
                         total_instance * sizeof(float));
      }
    } else if (type[0] == 'u') {  // uint64
      // no uint64_t type in paddlepaddle
      const auto& feasign = ins_vec[i].GetUint64Data();
      int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
          {total_instance, 1}, this->place_);
      if (total_instance > 0) {
        CopyToFeedTensor(tensor_ptr, feasign.data(),
                         total_instance * sizeof(int64_t));
      }
    }

    LoD data_lod{offset};
    feed_vec_[i]->set_lod(data_lod);
    if (use_slots_is_dense_[i]) {
      if (inductive_shape_index_[i] != -1) {
        use_slots_shape_[i][inductive_shape_index_[i]] =
            total_instance / total_dims_without_inductive_[i];
      }
      feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
    }
  }
#endif
}

700 701 702 703 704 705 706 707 708 709 710 711 712 713 714
// Configures the in-memory reader from `data_feed_desc`:
//  - batch size;
//  - per-slot name/type tables and the all-slot -> used-slot index map;
//  - the used slots: names, density, and declared shapes;
//  - for dense used slots, the product of positive dims and the position of
//    the single -1 ("inductive") dim;
//  - the pipe command.
// Sets finish_init_ only after everything succeeded.
void MultiSlotInMemoryDataFeed::Init(
    const paddle::framework::DataFeedDesc& data_feed_desc) {
  finish_init_ = false;
  finish_set_filelist_ = false;
  finish_start_ = false;

  PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
                 "Multi_slot_desc has not been set.");
  paddle::framework::MultiSlotDesc multi_slot_desc =
      data_feed_desc.multi_slot_desc();
  SetBatchSize(data_feed_desc.batch_size());
  size_t all_slot_num = multi_slot_desc.slots_size();
  all_slots_.resize(all_slot_num);
  all_slots_type_.resize(all_slot_num);
  use_slots_index_.resize(all_slot_num);
  // The dense-shape tables are indexed by *used*-slot position, like
  // use_slots_is_dense_ / use_slots_shape_ and like
  // MultiSlotDataFeed::PutToFeedVec reads them.
  // NOTE(review): this class's PutToFeedVec is outside this view; it is
  // assumed to use the same used-slot indexing — confirm.
  total_dims_without_inductive_.assign(all_slot_num, 1);
  inductive_shape_index_.assign(all_slot_num, -1);
  use_slots_.clear();
  use_slots_is_dense_.clear();
  // Cleared too, so re-initialising does not leave stale shapes ahead of the
  // new ones.
  use_slots_shape_.clear();
  for (size_t i = 0; i < all_slot_num; ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    all_slots_[i] = slot.name();
    all_slots_type_[i] = slot.type();
    use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
    if (!slot.is_used()) {
      continue;
    }
    const size_t use_idx = use_slots_.size();
    use_slots_.push_back(all_slots_[i]);
    use_slots_is_dense_.push_back(slot.is_dense());
    std::vector<int> local_shape;
    local_shape.reserve(slot.shape_size());
    for (int j = 0; j < slot.shape_size(); ++j) {
      const int dim = slot.shape(j);
      local_shape.push_back(dim);
      if (!slot.is_dense()) {
        continue;
      }
      if (dim > 0) {
        total_dims_without_inductive_[use_idx] *= dim;
      }
      if (dim == -1) {
        inductive_shape_index_[use_idx] = j;
      }
    }
    use_slots_shape_.push_back(std::move(local_shape));
  }
  feed_vec_.resize(use_slots_.size());
  pipe_command_ = data_feed_desc.pipe_command();
  finish_init_ = true;
}

J
jiaqi 已提交
751
// Reads the next line from the pipe behind `fp_` into `instance`.
// Optional leading fields, each written as "1 <token>", are parsed first:
// ins_id when parse_ins_id_ is set, then content when parse_content_ is set.
// Then, for every used slot, its values are appended as FeatureItems tagged
// with the used-slot index. Zero values are dropped for non-dense slots.
// Unused slots are skipped. Returns false when no line can be read (and on
// non-Linux builds).
bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
#ifdef _LINUX
  thread_local string::LineFileReader reader;

  if (!reader.getline(&*(fp_.get()))) {
    return false;
  }
  const char* str = reader.get();
  char* endptr = const_cast<char*>(str);
  int pos = 0;

  // Parses a "1 <token>" field starting at `pos`, returns <token>, and leaves
  // `pos` at the start of the next field. Never scans past the terminating
  // '\0' (the old loops did when a line ended right after the field).
  auto read_single_token = [&]() {
    int num = strtol(&str[pos], &endptr, 10);
    CHECK(num == 1);  // NOLINT
    pos = endptr - str;
    if (str[pos] != '\0') {
      ++pos;  // step over the separator
    }
    size_t len = 0;
    while (str[pos + len] != ' ' && str[pos + len] != '\0') {
      ++len;
    }
    std::string token(str + pos, len);
    pos += len;
    if (str[pos] != '\0') {
      ++pos;  // step over the separator
    }
    return token;
  };
  if (parse_ins_id_) {
    instance->ins_id_ = read_single_token();
    VLOG(3) << "ins_id " << instance->ins_id_;
  }
  if (parse_content_) {
    instance->content_ = read_single_token();
    VLOG(3) << "content " << instance->content_;
  }

  for (size_t i = 0; i < use_slots_index_.size(); ++i) {
    int idx = use_slots_index_[i];
    int num = strtol(&str[pos], &endptr, 10);
    PADDLE_ENFORCE(
        num,
        "The number of ids can not be zero, you need padding "
        "it in data generator; or if there is something wrong with "
        "the data, please check if the data contains unresolvable "
        "characters.\nplease check this error line: %s",
        str);
    if (idx != -1) {
      // use_slots_is_dense_ has one entry per *used* slot, so it is indexed
      // by `idx`. Indexing by the all-slot `i` read the wrong slot and could
      // go out of range.
      const bool is_dense = use_slots_is_dense_[idx];
      if (all_slots_type_[i][0] == 'f') {  // float
        for (int j = 0; j < num; ++j) {
          float feasign = strtof(endptr, &endptr);
          // if float feasign is equal to zero, ignore it
          // except when slot is dense
          if (fabs(feasign) < 1e-6 && !is_dense) {
            continue;
          }
          FeatureKey f;
          f.float_feasign_ = feasign;
          instance->float_feasigns_.push_back(FeatureItem(f, idx));
        }
      } else if (all_slots_type_[i][0] == 'u') {  // uint64
        for (int j = 0; j < num; ++j) {
          uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
          // if uint64 feasign is equal to zero, ignore it
          // except when slot is dense
          if (feasign == 0 && !is_dense) {
            continue;
          }
          FeatureKey f;
          f.uint64_feasign_ = feasign;
          instance->uint64_feasigns_.push_back(FeatureItem(f, idx));
        }
      }
      pos = endptr - str;
    } else {
      // Skip the count token plus its `num` values. Each pass moves `pos` to
      // the next ' ' after it and stops at '\0'. The previous
      // `while (line[pos + 1] != ' ')` never got past the first token.
      for (int j = 0; j <= num && str[pos] != '\0'; ++j) {
        ++pos;
        while (str[pos] != ' ' && str[pos] != '\0') {
          ++pos;
        }
      }
    }
  }
  instance->float_feasigns_.shrink_to_fit();
  instance->uint64_feasigns_.shrink_to_fit();
  return true;
#else
  return false;
#endif
}

J
jiaqi 已提交
842
// Reads one text line from file_ and parses it into *instance.
// The line holds every declared slot in order as "<num> <v_1> ... <v_num>".
// Values of used slots (use_slots_index_[i] != -1) are appended to
// instance->float_feasigns_ / uint64_feasigns_ tagged with the used-slot index;
// unused slots are skipped token by token.
// Returns true when a line was parsed, false at end of input (or on non-Linux
// builds, where this parser is compiled out).
bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
#ifdef _LINUX
  std::string line;
  if (getline(file_, line)) {
    VLOG(3) << line;
    // parse line
    const char* str = line.c_str();
    char* endptr = const_cast<char*>(str);
    int pos = 0;
    for (size_t i = 0; i < use_slots_index_.size(); ++i) {
      int idx = use_slots_index_[i];
      int num = strtol(&str[pos], &endptr, 10);
      PADDLE_ENFORCE(
          num,
          "The number of ids can not be zero, you need padding "
          "it in data generator; or if there is something wrong with "
          "the data, please check if the data contains unresolvable "
          "characters.\nplease check this error line: %s",
          str);

      if (idx != -1) {
        // Zero feasigns are ignored for sparse slots only: a dense slot must
        // keep every value so its data still matches the declared shape
        // (the same rule the preceding parser in this file applies).
        // use_slots_is_dense_ is indexed by the used-slot index idx, not i.
        const bool is_dense = use_slots_is_dense_[idx];
        if (all_slots_type_[i][0] == 'f') {  // float
          for (int j = 0; j < num; ++j) {
            float feasign = strtof(endptr, &endptr);
            if (fabs(feasign) < 1e-6 && !is_dense) {
              continue;
            }
            FeatureKey f;
            f.float_feasign_ = feasign;
            instance->float_feasigns_.push_back(FeatureItem(f, idx));
          }
        } else if (all_slots_type_[i][0] == 'u') {  // uint64
          for (int j = 0; j < num; ++j) {
            uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
            if (feasign == 0 && !is_dense) {
              continue;
            }
            FeatureKey f;
            f.uint64_feasign_ = feasign;
            instance->uint64_feasigns_.push_back(FeatureItem(f, idx));
          }
        }
        pos = endptr - str;
      } else {
        // Unused slot: advance past the count token and its num values.
        for (int j = 0; j <= num; ++j) {
          const size_t next_space = line.find_first_of(' ', pos + 1);
          if (next_space == std::string::npos) {
            // Ran off the end of the line. Park pos at the terminating NUL so
            // a following slot fails the id-count check above instead of
            // reading from a wrapped-around (or negative) position.
            pos = static_cast<int>(line.size());
            break;
          }
          pos = static_cast<int>(next_space);
        }
      }
    }
    instance->float_feasigns_.shrink_to_fit();
    instance->uint64_feasigns_.shrink_to_fit();
    return true;
  } else {
    return false;
  }
#endif
  return false;
}

J
jiaqi 已提交
901 902
// Assembles a mini-batch from parsed records and writes it into feed_vec_.
// For every used slot this gathers the values of all instances into one flat
// buffer (an instance with no value for a slot gets a single 0), records the
// per-instance end offsets as the tensor's LoD, copies the values into the
// feed tensor and, for dense slots, resizes it to the slot's shape.
// Instance ids and contents are collected into ins_id_vec_/ins_content_vec_.
void MultiSlotInMemoryDataFeed::PutToFeedVec(
    const std::vector<Record>& ins_vec) {
#ifdef _LINUX
  const size_t used_slot_num = use_slots_.size();

  // FeatureItem::slot() and every per-slot buffer below use the used-slot
  // index (use_slots_index_[i], see ParseOneInstance), while all_slots_type_
  // is indexed over *all* declared slots. Translate once so the type of used
  // slot j is looked up correctly even when some declared slots are unused;
  // when every slot is used this mapping is the identity.
  std::vector<char> slot_type(used_slot_num, '\0');
  for (size_t i = 0; i < use_slots_index_.size(); ++i) {
    const int idx = use_slots_index_[i];
    if (idx >= 0 && static_cast<size_t>(idx) < used_slot_num) {
      slot_type[idx] = all_slots_type_[i][0];
    }
  }

  std::vector<std::vector<float>> batch_float_feasigns(used_slot_num);
  std::vector<std::vector<uint64_t>> batch_uint64_feasigns(used_slot_num);
  // offset[j] becomes the LoD of slot j: offset[j][k] is where instance k
  // starts in the flat buffer of slot j.
  std::vector<std::vector<size_t>> offset(used_slot_num,
                                          std::vector<size_t>{0});
  // visit[j]: did the current instance carry any value for slot j?
  std::vector<bool> visit(used_slot_num, false);

  ins_content_vec_.clear();
  ins_content_vec_.reserve(ins_vec.size());
  ins_id_vec_.clear();
  ins_id_vec_.reserve(ins_vec.size());

  for (const auto& r : ins_vec) {
    ins_id_vec_.push_back(r.ins_id_);
    ins_content_vec_.push_back(r.content_);
    for (const auto& item : r.float_feasigns_) {
      batch_float_feasigns[item.slot()].push_back(item.sign().float_feasign_);
      visit[item.slot()] = true;
    }
    for (const auto& item : r.uint64_feasigns_) {
      batch_uint64_feasigns[item.slot()].push_back(item.sign().uint64_feasign_);
      visit[item.slot()] = true;
    }
    for (size_t j = 0; j < used_slot_num; ++j) {
      const char type = slot_type[j];
      if (visit[j]) {
        visit[j] = false;
      } else {
        // fill slot value with default value 0
        if (type == 'f') {  // float
          batch_float_feasigns[j].push_back(0.0);
        } else if (type == 'u') {  // uint64
          batch_uint64_feasigns[j].push_back(0);
        }
      }
      // get offset of this ins in this slot
      if (type == 'f') {  // float
        offset[j].push_back(batch_float_feasigns[j].size());
      } else if (type == 'u') {  // uint64
        offset[j].push_back(batch_uint64_feasigns[j].size());
      }
    }
  }

  for (size_t i = 0; i < used_slot_num; ++i) {
    if (feed_vec_[i] == nullptr) {
      continue;
    }
    int total_instance = static_cast<int>(offset[i].back());
    const char type = slot_type[i];
    if (type == 'f') {  // float
      float* feasign = batch_float_feasigns[i].data();
      float* tensor_ptr =
          feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
      CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(float));
    } else if (type == 'u') {  // uint64
      // no uint64_t type in paddlepaddle
      uint64_t* feasign = batch_uint64_feasigns[i].data();
      int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
          {total_instance, 1}, this->place_);
      CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t));
    }
    LoD data_lod{offset[i]};
    feed_vec_[i]->set_lod(data_lod);
    if (use_slots_is_dense_[i]) {
      // NOTE(review): inductive_shape_index_ and total_dims_without_inductive_
      // are indexed by the used-slot index i here; confirm Init sizes and
      // fills them per used slot rather than per declared slot.
      if (inductive_shape_index_[i] != -1) {
        use_slots_shape_[i][inductive_shape_index_[i]] =
            total_instance / total_dims_without_inductive_[i];
      }
      feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
    }
  }
#endif
}

H
hutuxian 已提交
980 981 982 983 984 985 986 987 988 989
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
template <typename T>
void PrivateInstantDataFeed<T>::PutToFeedVec() {
  // Hand every used slot's parsed mini-batch to its feed tensor: copy the
  // values, attach the per-instance offsets as LoD and, for dense slots,
  // check the element count against the declared shape before reshaping.
  for (size_t slot_id = 0; slot_id < use_slots_.size(); ++slot_id) {
    auto& slot_data = ins_vec_[slot_id];
    const char type_tag = slot_data.GetType()[0];
    const auto& lod_offset = slot_data.GetOffset();
    const int value_count = static_cast<int>(lod_offset.back());

    switch (type_tag) {
      case 'f': {
        const auto& values = slot_data.GetFloatData();
        float* dst = feed_vec_[slot_id]->mutable_data<float>(
            {value_count, 1}, this->place_);
        CopyToFeedTensor(dst, &values[0], value_count * sizeof(float));
        break;
      }
      case 'u': {
        // Paddle has no uint64 tensor type, so the ids are stored as int64.
        const auto& values = slot_data.GetUint64Data();
        int64_t* dst = feed_vec_[slot_id]->mutable_data<int64_t>(
            {value_count, 1}, this->place_);
        CopyToFeedTensor(dst, &values[0], value_count * sizeof(int64_t));
        break;
      }
      default:
        break;
    }

    LoD lod{lod_offset};
    feed_vec_[slot_id]->set_lod(lod);
    if (!use_slots_is_dense_[slot_id]) {
      continue;
    }
    int64_t declared_size = 1;
    for (const auto extent : use_slots_shape_[slot_id]) {
      declared_size *= extent;
    }
    PADDLE_ENFORCE(
        declared_size == value_count,
        "The actual data size of slot[%s] doesn't match its declaration",
        use_slots_[slot_id].c_str());
    feed_vec_[slot_id]->Resize(framework::make_ddim(use_slots_shape_[slot_id]));
  }
}

template <typename T>
int PrivateInstantDataFeed<T>::Next() {
  // Produce the next mini-batch. When the current file is exhausted, release
  // it and open the next one from the file list; returns the batch size, or
  // -1 when no further file can be picked or prepared.
  if (!ParseOneMiniBatch()) {
    Postprocess();
    std::string next_file;
    if (!PickOneFile(&next_file) || !Preprocess(next_file)) {
      return -1;
    }
    const bool parsed = ParseOneMiniBatch();
    PADDLE_ENFORCE(parsed, "Fail to parse mini-batch data");
  }
  PutToFeedVec();
  return ins_vec_[0].GetBatchSize();
}

// Configures the feed from data_feed_desc.multi_slot_desc().
// Per declared slot i: all_slots_[i], all_slots_type_[i], use_slots_index_[i]
// (position among used slots, or -1 when unused) and, for used dense slots,
// multi_inductive_shape_index_[i] (positions of -1 dims in the shape).
// Per used slot: use_slots_, use_slots_is_dense_, use_slots_shape_.
// Safe to call again: every container is rebuilt from scratch.
template <typename T>
void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) {
  finish_init_ = false;
  finish_set_filelist_ = false;
  finish_start_ = false;

  PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
                 "Multi_slot_desc has not been set.");
  paddle::framework::MultiSlotDesc multi_slot_desc =
      data_feed_desc.multi_slot_desc();
  SetBatchSize(data_feed_desc.batch_size());
  size_t all_slot_num = multi_slot_desc.slots_size();
  all_slots_.resize(all_slot_num);
  all_slots_type_.resize(all_slot_num);
  use_slots_index_.resize(all_slot_num);
  // Containers filled with push_back must start empty; otherwise a repeated
  // Init() would append to the previous configuration and misalign indices.
  multi_inductive_shape_index_.clear();
  multi_inductive_shape_index_.resize(all_slot_num);
  use_slots_.clear();
  use_slots_is_dense_.clear();
  use_slots_shape_.clear();
  for (size_t i = 0; i < all_slot_num; ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    all_slots_[i] = slot.name();
    all_slots_type_[i] = slot.type();
    use_slots_index_[i] =
        slot.is_used() ? static_cast<int>(use_slots_.size()) : -1;
    if (!slot.is_used()) {
      continue;
    }
    use_slots_.push_back(all_slots_[i]);
    use_slots_is_dense_.push_back(slot.is_dense());
    std::vector<int> local_shape;
    for (int j = 0; j < slot.shape_size(); ++j) {
      const int dim = slot.shape(j);
      // A -1 dim of a dense slot is "inductive": its size is read from the
      // data file (see ParseOneMiniBatch).
      if (slot.is_dense() && dim == -1) {
        multi_inductive_shape_index_[i].push_back(j);
      }
      local_shape.push_back(dim);
    }
    use_slots_shape_.push_back(std::move(local_shape));
  }
  feed_vec_.resize(use_slots_.size());
  ins_vec_.resize(use_slots_.size());

  finish_init_ = true;
}

// Explicitly instantiate PrivateInstantDataFeed for std::vector<MultiSlotType>.
// Its member templates are defined in this .cc file, so this instantiation is
// what makes them available to other translation units.
template class PrivateInstantDataFeed<std::vector<MultiSlotType>>;

// Opens `filename` read-only and memory-maps the whole file into buffer_
// ([0, end_)), resetting offset_ to 0. On any failure the descriptor is
// closed, fd_/buffer_/end_ are reset, and an enforce error is raised, so no
// descriptor leaks and Postprocess() never sees a failed mapping.
bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
  fd_ = open(filename.c_str(), O_RDONLY);
  PADDLE_ENFORCE(fd_ != -1, "Fail to open file: %s", filename.c_str());

  struct stat sb;
  if (fstat(fd_, &sb) != 0) {
    const int stat_errno = errno;  // capture before close() may change it
    close(fd_);
    fd_ = -1;
    PADDLE_ENFORCE(false, "Fail to stat file %s: %s", filename.c_str(),
                   strerror(stat_errno));
  }
  end_ = static_cast<size_t>(sb.st_size);

  // mmap rejects a zero length, so an empty file also fails here.
  void* mapped = mmap(nullptr, end_, PROT_READ, MAP_PRIVATE, fd_, 0);
  if (mapped == MAP_FAILED) {
    const int mmap_errno = errno;  // capture before close() may change it
    close(fd_);
    fd_ = -1;
    end_ = 0;
    buffer_ = nullptr;
    PADDLE_ENFORCE(false, "Fail to mmap file %s: %s", filename.c_str(),
                   strerror(mmap_errno));
  }
  buffer_ = reinterpret_cast<char*>(mapped);

  offset_ = 0;
  return true;
}

// Releases the current file: unmaps buffer_ (if it holds a real mapping),
// closes fd_ and resets the read cursor. Idempotent; always returns true.
bool MultiSlotFileInstantDataFeed::Postprocess() {
  // buffer_ can hold MAP_FAILED if an earlier mmap failed after the pointer
  // was stored; that is not a mapping and must not be passed to munmap.
  if (buffer_ != nullptr && buffer_ != MAP_FAILED) {
    munmap(buffer_, end_);
  }
  buffer_ = nullptr;
  if (fd_ != -1) {
    close(fd_);
    fd_ = -1;
    end_ = 0;
    offset_ = 0;
  }
  return true;
}

bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
  if (offset_ == end_) {
    return false;
  }

  batch_size_ = 0;
  while (batch_size_ < default_batch_size_ && offset_ < end_) {
    for (size_t i = 0; i < use_slots_index_.size(); ++i) {
      int idx = use_slots_index_[i];
      char type = all_slots_type_[i][0];

      uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_);
      PADDLE_ENFORCE(
          num,
          "The number of ids can not be zero, you need padding "
          "it in data generator; or if there is something wrong with "
          "the data, please check if the data contains unresolvable "
          "characters.");
      offset_ += sizeof(uint16_t);

      if (idx != -1) {
        int inductive_size = multi_inductive_shape_index_[i].size();
        if (UNLIKELY(batch_size_ == 0)) {
          ins_vec_[idx].Init(all_slots_type_[i], default_batch_size_ * num);
          ins_vec_[idx].InitOffset(default_batch_size_);
          uint64_t* inductive_shape =
              reinterpret_cast<uint64_t*>(buffer_ + offset_);
          for (int inductive_id = 0; inductive_id < inductive_size;
               ++inductive_id) {
            use_slots_shape_[i][multi_inductive_shape_index_[i][inductive_id]] =
                static_cast<int>(*(inductive_shape + inductive_id));
          }
        }
        num -= inductive_size;
        offset_ += sizeof(uint64_t) * inductive_size;

        if (type == 'f') {
          ins_vec_[idx].AppendValues(
              reinterpret_cast<float*>(buffer_ + offset_), num);
          offset_ += num * sizeof(float);
        } else if (type == 'u') {
          ins_vec_[idx].AppendValues(
              reinterpret_cast<uint64_t*>(buffer_ + offset_), num);
          offset_ += num * sizeof(uint64_t);
        }
      } else {
        if (type == 'f') {
          offset_ += num * sizeof(float);
        } else if (type == 'u') {
          offset_ += num * sizeof(uint64_t);
        }
      }
    }
    ++batch_size_;
    // OPTIMIZE: It is better to insert check codes between instances for format
    // checking
  }

  PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_,
                 "offset_ != end_");
  return true;
}
#endif

W
Wang Guibao 已提交
1181 1182
}  // namespace framework
}  // namespace paddle