data_set.cc 22.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

15
#include "paddle/fluid/framework/data_set.h"
16
#include <algorithm>
D
dongdaxiang 已提交
17
#include <random>
18
#include <unordered_map>
19 20 21
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
22
#include "paddle/fluid/framework/data_feed_factory.h"
23
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
24
#include "paddle/fluid/framework/io/fs.h"
25
#include "paddle/fluid/platform/timer.h"
26
#include "xxhash.h"  // NOLINT
27

D
dongdaxiang 已提交
28 29 30 31 32
// Every platform other than Windows and macOS is treated as Linux;
// ReceiveFromClient compiles its message-handling body only when _LINUX is set.
// NOTE(review): identifiers starting with an underscore followed by an
// uppercase letter are reserved for the implementation; a project-prefixed
// macro name would be safer.
#if !defined(_WIN32) && !defined(__APPLE__)
#define _LINUX
#endif

33 34 35
namespace paddle {
namespace framework {

X
xjqbest 已提交
36
// constructor
37
template <typename T>
D
dongdaxiang 已提交
38
DatasetImpl<T>::DatasetImpl() {
J
jiaqi 已提交
39
  VLOG(3) << "DatasetImpl<T>::DatasetImpl() constructor";
D
dongdaxiang 已提交
40
  thread_num_ = 1;
41
  trainer_num_ = 1;
J
jiaqi 已提交
42
  channel_num_ = 1;
43
  file_idx_ = 0;
J
jiaqi 已提交
44 45 46
  cur_channel_ = 0;
  fleet_send_batch_size_ = 80000;
  fleet_send_sleep_seconds_ = 2;
47 48 49 50
  merge_by_insid_ = false;
  erase_duplicate_feas_ = true;
  keep_unmerged_ins_ = true;
  min_merge_size_ = 2;
D
dongdaxiang 已提交
51
}
52

X
xjqbest 已提交
53
// set filelist, file_idx_ will reset to zero.
54 55
template <typename T>
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
56
  VLOG(3) << "filelist size: " << filelist.size();
57
  filelist_ = filelist;
58
  file_idx_ = 0;
59 60
}

X
xjqbest 已提交
61
// set expect thread num. actually it may change
62 63
template <typename T>
void DatasetImpl<T>::SetThreadNum(int thread_num) {
64
  VLOG(3) << "SetThreadNum thread_num=" << thread_num;
65 66 67
  thread_num_ = thread_num;
}

X
xjqbest 已提交
68 69 70
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerNum
71
template <typename T>
X
xujiaqi01 已提交
72 73
void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
  trainer_num_ = trainer_num;
74 75
}

X
xjqbest 已提交
76 77 78 79 80 81 82 83
// Batch size for GlobalShuffle: it becomes the input channel's block size, so
// it presumably bounds how many records each send batch reads. Set before a
// global shuffle. Call CreateReaders before SetFleetSendBatchSize.
template <typename T>
void DatasetImpl<T>::SetFleetSendBatchSize(int64_t size) {
  this->fleet_send_batch_size_ = size;
}

84 85 86
// Remember the HDFS name node and ugi, and install the matching
// "hadoop fs -D fs.default.name=<name> -D hadoop.job.ugi=<ugi>" command
// prefix via hdfs_set_command.
// NOTE(review): both values are spliced in unescaped; the command is presumably
// run through a shell by the fs helpers, so they must come from trusted config.
template <typename T>
void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
                                   const std::string& fs_ugi) {
  fs_name_ = fs_name;
  fs_ugi_ = fs_ugi;
  std::string cmd = std::string("hadoop fs") + " -D fs.default.name=" +
                    fs_name + " -D hadoop.job.ugi=" + fs_ugi;
  paddle::framework::hdfs_set_command(cmd);
}
94

95 96
// Parse the text-format DataFeedDesc that configures the readers built by
// CreateReaders.
// A malformed description is a configuration error. After a failed parse,
// data_feed_desc_ cannot be trusted, and readers built from it would silently
// use the wrong slots. So fail loudly here, as CreateReaders does for other
// invalid configuration.
template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
  const bool parsed = google::protobuf::TextFormat::ParseFromString(
      data_feed_desc_str, &data_feed_desc_);
  CHECK(parsed) << "failed to parse data_feed_desc: " << data_feed_desc_str;
}

101
template <typename T>
J
jiaqi 已提交
102 103 104 105
void DatasetImpl<T>::SetChannelNum(int channel_num) {
  channel_num_ = channel_num;
}

106 107 108 109 110 111 112 113 114 115 116
// Enable merging of records that share an instance id (performed by
// MultiSlotDataset::MergeByInsId) and store its options:
//   merge_slot_list      - names of the slots whose features get merged
//   erase_duplicate_feas - de-duplicate the merged features
//   min_merge_size       - smallest group of equal ins ids that is merged
//   keep_unmerged_ins    - keep records of groups too small to merge
template <typename T>
void DatasetImpl<T>::SetMergeByInsId(
    const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas,
    int min_merge_size, bool keep_unmerged_ins) {
  this->merge_slots_list_ = merge_slot_list;
  this->erase_duplicate_feas_ = erase_duplicate_feas;
  this->min_merge_size_ = min_merge_size;
  this->keep_unmerged_ins_ = keep_unmerged_ins;
  this->merge_by_insid_ = true;
}

J
jiaqi 已提交
117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
// Return non-owning raw pointers to the readers built by CreateReaders.
// The pointers are valid only while readers_ (or another owner) keeps the
// readers alive; DestroyReaders and ReleaseMemory both swap readers_ away.
template <typename T>
std::vector<paddle::framework::DataFeed*> DatasetImpl<T>::GetReaders() {
  std::vector<paddle::framework::DataFeed*> ret;
  ret.reserve(readers_.size());
  // Bind by const reference: copying each shared_ptr would cost an atomic
  // refcount increment/decrement per element for nothing.
  for (const auto& reader : readers_) {
    ret.push_back(reader.get());
  }
  return ret;
}

// Lazily create the dataset's channels: one input channel plus channel_num_
// output and channel_num_ consume channels. Anything that already exists is
// left untouched, so repeated calls are harmless.
template <typename T>
void DatasetImpl<T>::CreateChannel() {
  if (!input_channel_) {
    input_channel_ = paddle::framework::MakeChannel<T>();
  }
  // Fill `chans` with channel_num_ fresh channels unless it already has some.
  auto populate = [this](std::vector<paddle::framework::Channel<T>>* chans) {
    if (!chans->empty()) {
      return;
    }
    chans->reserve(channel_num_);
    for (int k = 0; k < channel_num_; ++k) {
      chans->push_back(paddle::framework::MakeChannel<T>());
    }
  };
  populate(&multi_output_channel_);
  populate(&multi_consume_channel_);
}

146 147 148 149 150 151 152 153 154 155 156 157
// Route client-to-client messages of type 0 (the type GlobalShuffle sends)
// to ReceiveFromClient. Call this before workers exchange any data.
// NOTE(review): the handler captures `this`, so the dataset must outlive the
// registration.
template <typename T>
void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
  auto fleet = FleetWrapper::GetInstance();
  VLOG(3) << "RegisterClientToClientMsgHandler";
  auto on_message = [this](int msg_type, int client_id,
                           const std::string& msg) -> int {
    return this->ReceiveFromClient(msg_type, client_id, msg);
  };
  fleet->RegisterClientToClientMsgHandler(0, std::move(on_message));
  VLOG(3) << "RegisterClientToClientMsgHandler done";
}

X
xjqbest 已提交
158 159
// load data into memory, Dataset hold this memory,
// which will later be fed into readers' channel
160 161 162
template <typename T>
void DatasetImpl<T>::LoadIntoMemory() {
  VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
163 164
  platform::Timer timeline;
  timeline.Start();
165 166
  std::vector<std::thread> load_threads;
  for (int64_t i = 0; i < thread_num_; ++i) {
D
dongdaxiang 已提交
167 168
    load_threads.push_back(std::thread(
        &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get()));
169 170 171 172
  }
  for (std::thread& t : load_threads) {
    t.join();
  }
J
jiaqi 已提交
173 174 175
  input_channel_->Close();
  int64_t in_chan_size = input_channel_->Size();
  input_channel_->SetBlockSize(in_chan_size / thread_num_ + 1);
176 177
  timeline.Pause();
  VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end"
J
jiaqi 已提交
178
          << ", memory data size=" << input_channel_->Size()
179
          << ", cost time=" << timeline.ElapsedSec() << " seconds";
180 181
}

J
jiaqi 已提交
182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
// Asynchronous counterpart of LoadIntoMemory: start one loading thread per
// reader slot and return immediately. WaitPreLoadDone joins them and closes
// the input channel.
// NOTE(review): clearing preload_threads_ while an earlier batch is still
// joinable calls std::terminate. Do not call this twice without an
// intervening WaitPreLoadDone.
template <typename T>
void DatasetImpl<T>::PreLoadIntoMemory() {
  VLOG(3) << "DatasetImpl<T>::PreLoadIntoMemory() begin";
  preload_threads_.clear();
  for (int64_t tid = 0; tid < thread_num_; ++tid) {
    auto* reader = readers_[tid].get();
    preload_threads_.emplace_back(&paddle::framework::DataFeed::LoadIntoMemory,
                                  reader);
  }
  VLOG(3) << "DatasetImpl<T>::PreLoadIntoMemory() end";
}

// Block until every thread started by PreLoadIntoMemory has finished. Then
// close the input channel and set its block size to size / thread_num_ + 1,
// as LoadIntoMemory does.
// Only joinable threads are joined, and the finished handles are dropped
// afterwards. Calling this again therefore never joins a non-joinable thread,
// which would throw std::system_error.
template <typename T>
void DatasetImpl<T>::WaitPreLoadDone() {
  VLOG(3) << "DatasetImpl<T>::WaitPreLoadDone() begin";
  for (std::thread& t : preload_threads_) {
    if (t.joinable()) {
      t.join();
    }
  }
  // Every handle is now non-joinable; release them.
  preload_threads_.clear();
  input_channel_->Close();
  int64_t in_chan_size = input_channel_->Size();
  input_channel_->SetBlockSize(in_chan_size / thread_num_ + 1);
  VLOG(3) << "DatasetImpl<T>::WaitPreLoadDone() end";
}

205 206 207 208
// Release all in-memory data: clear and drop the input channel, then every
// output and consume channel, then the readers. Each vector is swapped with an
// empty one so its storage is freed too, not just its elements.
template <typename T>
void DatasetImpl<T>::ReleaseMemory() {
  VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
  if (input_channel_) {
    input_channel_->Clear();
    input_channel_ = nullptr;
  }
  // Clear every live channel in `chans`, drop it, then free the vector itself.
  auto release_all = [](std::vector<paddle::framework::Channel<T>>* chans) {
    for (auto& chan : *chans) {
      if (chan) {
        chan->Clear();
        chan = nullptr;
      }
    }
    std::vector<paddle::framework::Channel<T>>().swap(*chans);
  };
  release_all(&multi_output_channel_);
  release_all(&multi_consume_channel_);
  std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
  VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}

X
xjqbest 已提交
233
// do local shuffle
234 235 236
template <typename T>
void DatasetImpl<T>::LocalShuffle() {
  VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
237 238
  platform::Timer timeline;
  timeline.Start();
239

J
jiaqi 已提交
240 241 242
  if (!input_channel_ || input_channel_->Size() == 0) {
    VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, no data to shuffle";
    return;
243
  }
J
jiaqi 已提交
244 245 246 247 248 249 250 251 252 253 254
  auto fleet_ptr = FleetWrapper::GetInstance();
  input_channel_->Close();
  std::vector<T> data;
  input_channel_->ReadAll(data);
  std::shuffle(data.begin(), data.end(), fleet_ptr->LocalRandomEngine());
  input_channel_->Open();
  input_channel_->Write(std::move(data));
  data.clear();
  data.shrink_to_fit();
  input_channel_->Close();

255 256 257
  timeline.Pause();
  VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, cost time="
          << timeline.ElapsedSec() << " seconds";
258 259
}

260 261 262
// Shuffle records across all trainers.
//  1. Shuffle the local records (as LocalShuffle does). Then set the input
//     channel's block size to fleet_send_batch_size_, so each Read below
//     returns one send batch.
//  2. thread_num_ threads drain the input channel concurrently. Each record is
//     serialized into the archive of its destination trainer: random, or
//     XXH64(ins_id_) when merge_by_insid_ so equal ins ids meet on one trainer.
//     Each non-empty archive is sent as a type-0 client-to-client message,
//     handled by ReceiveFromClient once RegisterClientToClientMsgHandler has
//     run.
//  3. The drained input channel is cleared.
// trainer_num_ must be positive because it is the modulus for destinations.
template <typename T>
void DatasetImpl<T>::GlobalShuffle() {
  VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
  platform::Timer timeline;
  timeline.Start();
  auto fleet_ptr = FleetWrapper::GetInstance();

  if (!input_channel_ || input_channel_->Size() == 0) {
    VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, no data to shuffle";
    return;
  }
  // A non-positive trainer count would make `% trainer_num_` a division by
  // zero (or a wrap to a huge unsigned modulus).
  CHECK(trainer_num_ > 0) << "trainer num should > 0";

  // local shuffle
  input_channel_->Close();
  std::vector<T> data;
  input_channel_->ReadAll(data);
  std::shuffle(data.begin(), data.end(), fleet_ptr->LocalRandomEngine());
  input_channel_->Open();
  input_channel_->Write(std::move(data));
  data.clear();
  data.shrink_to_fit();

  input_channel_->Close();
  input_channel_->SetBlockSize(fleet_send_batch_size_);
  VLOG(3) << "DatasetImpl<T>::GlobalShuffle() input_channel_ size "
          << input_channel_->Size();

  // Destination trainer of one record.
  auto get_client_id = [this, fleet_ptr](const T& rec) -> size_t {
    if (!this->merge_by_insid_) {
      return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
    }
    // Hash the instance id so all records sharing it land on one trainer.
    return XXH64(rec.ins_id_.data(), rec.ins_id_.length(), 0) %
           this->trainer_num_;
  };

  auto global_shuffle_func = [this, get_client_id]() {
    auto fleet_ptr = FleetWrapper::GetInstance();
    std::vector<T> batch;
    while (this->input_channel_->Read(batch)) {
      std::vector<paddle::framework::BinaryArchive> ars(this->trainer_num_);
      for (auto& t : batch) {
        auto client_id = get_client_id(t);
        ars[client_id] << t;
      }
      // Visit destination trainers in a random order (presumably so that
      // concurrent senders do not all hit the same trainer first).
      std::vector<int> send_index(this->trainer_num_);
      for (int i = 0; i < this->trainer_num_; ++i) {
        send_index[i] = i;
      }
      std::shuffle(send_index.begin(), send_index.end(),
                   fleet_ptr->LocalRandomEngine());
      std::vector<std::future<int32_t>> total_status;
      for (int i : send_index) {
        if (ars[i].Length() == 0) {
          continue;
        }
        std::string msg(ars[i].Buffer(), ars[i].Length());
        auto ret = fleet_ptr->SendClientToClientMsg(0, i, msg);
        total_status.push_back(std::move(ret));
      }
      // Finish every send of this batch before reading the next one.
      for (auto& status : total_status) {
        status.wait();
      }
      ars.clear();
      ars.shrink_to_fit();
      batch.clear();
      batch.shrink_to_fit();
      // Throttle between batches.
      sleep(this->fleet_send_sleep_seconds_);
    }
  };

  VLOG(3) << "start global shuffle threads";
  std::vector<std::thread> global_shuffle_threads;
  for (int i = 0; i < thread_num_; ++i) {
    global_shuffle_threads.push_back(std::thread(global_shuffle_func));
  }
  for (std::thread& t : global_shuffle_threads) {
    t.join();
  }
  global_shuffle_threads.clear();
  global_shuffle_threads.shrink_to_fit();
  input_channel_->Clear();
  timeline.Pause();
  VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
          << timeline.ElapsedSec() << " seconds";
}

348 349
// Create thread_num_ readers from data_feed_desc_. Does nothing if readers
// already exist. Every reader shares:
//   - the file list;
//   - the file cursor file_idx_ and its mutex;
//   - input_channel_, when it exists.
// Readers are assigned to channel pairs round-robin. cur_channel_ decides
// which vector of a pair is the reader's output and which its consume channel;
// DestroyReaders toggles it, so the roles alternate between reader generations.
template <typename T>
void DatasetImpl<T>::CreateReaders() {
  VLOG(3) << "Calling CreateReaders()";
  VLOG(3) << "thread num in Dataset: " << thread_num_;
  VLOG(3) << "Filelist size in Dataset: " << filelist_.size();
  VLOG(3) << "channel num in Dataset: " << channel_num_;
  CHECK(thread_num_ > 0) << "thread num should > 0";
  CHECK(channel_num_ > 0) << "channel num should > 0";
  CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num";
  VLOG(3) << "readers size: " << readers_.size();
  if (readers_.size() != 0) {
    VLOG(3) << "readers_.size() = " << readers_.size()
            << ", will not create again";
    return;
  }
  VLOG(3) << "data feed class name: " << data_feed_desc_.name();
  // Channel pairs exist only after CreateChannel. Bound the index by both
  // vectors because each branch below indexes both of them. Using size_t keeps
  // the comparisons unsigned on both sides.
  const size_t paired_channels =
      std::min(multi_output_channel_.size(), multi_consume_channel_.size());
  const size_t channel_count = static_cast<size_t>(channel_num_);
  size_t channel_idx = 0;
  for (int i = 0; i < thread_num_; ++i) {
    readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
    auto& reader = readers_.back();
    reader->Init(data_feed_desc_);
    reader->SetThreadId(i);
    reader->SetThreadNum(thread_num_);
    reader->SetFileListMutex(&mutex_for_pick_file_);
    reader->SetFileListIndex(&file_idx_);
    reader->SetFileList(filelist_);
    reader->SetParseInsId(merge_by_insid_);
    if (input_channel_ != nullptr) {
      reader->SetInputChannel(input_channel_.get());
    }
    if (channel_idx < paired_channels) {
      if (cur_channel_ == 0) {
        reader->SetOutputChannel(multi_output_channel_[channel_idx].get());
        reader->SetConsumeChannel(multi_consume_channel_[channel_idx].get());
      } else {
        reader->SetOutputChannel(multi_consume_channel_[channel_idx].get());
        reader->SetConsumeChannel(multi_output_channel_[channel_idx].get());
      }
    }
    ++channel_idx;
    if (channel_idx >= channel_count) {
      channel_idx = 0;
    }
  }
  VLOG(3) << "readers size: " << readers_.size();
}

392 393 394
// Drop all readers (freeing the vector's storage as well) and rewind the
// shared file cursor. Also toggle cur_channel_, so the next CreateReaders
// swaps the roles of the output and consume channels.
template <typename T>
void DatasetImpl<T>::DestroyReaders() {
  VLOG(3) << "Calling DestroyReaders()";
  VLOG(3) << "readers size1: " << readers_.size();
  {
    // Swap into a local empty vector; our references to the readers are
    // released when `retired` goes out of scope.
    std::vector<std::shared_ptr<paddle::framework::DataFeed>> retired;
    retired.swap(readers_);
  }
  VLOG(3) << "readers size: " << readers_.size();
  file_idx_ = 0;
  cur_channel_ = 1 - cur_channel_;
}

402 403
// Number of records currently held in the input channel.
// Returns 0 when there is no input channel, e.g. before CreateChannel or after
// ReleaseMemory, which resets it to nullptr. The old code dereferenced a null
// channel in that case.
template <typename T>
int64_t DatasetImpl<T>::GetMemoryDataSize() {
  if (!input_channel_) {
    return 0;
  }
  return input_channel_->Size();
}

// Total number of records across all output and consume channels.
// NOTE(review): assumes multi_consume_channel_ has at least as many entries as
// multi_output_channel_. CreateChannel builds both with channel_num_ entries.
template <typename T>
int64_t DatasetImpl<T>::GetShuffleDataSize() {
  int64_t total = 0;
  size_t idx = 0;
  while (idx < multi_output_channel_.size()) {
    total += multi_output_channel_[idx]->Size() +
             multi_consume_channel_[idx]->Size();
    ++idx;
  }
  return total;
}

416 417
// Handler for type-0 client-to-client messages, registered by
// RegisterClientToClientMsgHandler; GlobalShuffle on another trainer is one
// sender. `msg` is a BinaryArchive of serialized records. The decoded records
// are written to one randomly chosen output channel. Always returns 0. The
// body is compiled only when _LINUX is defined; otherwise messages are ignored.
template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
                                      const std::string& msg) {
#ifdef _LINUX
  VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
          << ", client_id=" << client_id << ", msg length=" << msg.length();
  if (msg.length() == 0) {
    return 0;
  }
  paddle::framework::BinaryArchive ar;
  // SetReadBuffer wants a mutable pointer. With a null deleter it presumably
  // only reads the buffer and does not take ownership; confirm in BinaryArchive.
  ar.SetReadBuffer(const_cast<char*>(msg.c_str()), msg.length(), nullptr);
  if (ar.Cursor() == ar.Finish()) {
    return 0;
  }
  std::vector<T> data;
  while (ar.Cursor() < ar.Finish()) {
    data.push_back(ar.Get<T>());
  }
  CHECK(ar.Cursor() == ar.Finish());

  // Messages can arrive before CreateChannel or after ReleaseMemory, or after
  // channel_num_ changed. Fail with a clear message instead of dividing by zero
  // or indexing past multi_output_channel_.
  CHECK(channel_num_ > 0) << "channel num should > 0";
  auto fleet_ptr = FleetWrapper::GetInstance();
  int64_t index = fleet_ptr->LocalRandomEngine()() % channel_num_;
  VLOG(3) << "ramdom index=" << index;
  CHECK(index < static_cast<int64_t>(multi_output_channel_.size()))
      << "ReceiveFromClient: output channel " << index
      << " does not exist, output channel count="
      << multi_output_channel_.size();
  multi_output_channel_[index]->Write(std::move(data));

  data.clear();
  data.shrink_to_fit();
#endif
  return 0;
}

447
// explicit instantiation
J
jiaqi 已提交
448
// Explicitly instantiate DatasetImpl for T = Record. This emits every
// DatasetImpl<T> member defined in this .cc file for other translation units.
template class DatasetImpl<Record>;
449

450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648
// Merge records that share the same instance id (ins_id_).
//
// All records are drained from the output channels and sorted by ins_id_.
// Each run of records with equal ins_id_ is then handled as follows:
//   * Runs shorter than min_merge_size_ are kept unchanged when
//     keep_unmerged_ins_ is true, and dropped otherwise.
//   * Otherwise the first record of the run is kept. The other records'
//     features in merge slots are appended to it. Merge slots are the used,
//     non-dense slots named in merge_slots_list_.
//   * With erase_duplicate_feas_, all merge-slot features (including the first
//     record's own) are also sorted and de-duplicated by (sign, slot). They are
//     placed before the record's non-merge features.
// The results are shuffled and split roughly evenly over the output channels.
// Does nothing unless SetMergeByInsId enabled merge_by_insid_.
void MultiSlotDataset::MergeByInsId() {
  VLOG(3) << "MultiSlotDataset::MergeByInsId begin";
  if (!merge_by_insid_) {
    VLOG(3) << "merge_by_insid=false, will not MergeByInsId";
    return;
  }
  // Read-only use: bind by reference instead of copying the proto message.
  const auto& multi_slot_desc = data_feed_desc_.multi_slot_desc();
  std::unordered_map<int, bool> merge_slots;
  std::vector<std::string> use_slots;
  std::vector<bool> use_slots_is_dense;
  for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    if (slot.is_used()) {
      use_slots.push_back(slot.name());
      use_slots_is_dense.push_back(slot.is_dense());
    }
  }
  // merge_slots is keyed by position in use_slots and later looked up with
  // FeatureItem::slot(). Feature slot ids are therefore presumably indices
  // into the used slots; confirm against the data feed.
  for (size_t i = 0; i < use_slots.size(); ++i) {
    // currently, we don't merge dense slots
    if (std::find(merge_slots_list_.begin(), merge_slots_list_.end(),
                  use_slots[i]) != merge_slots_list_.end() &&
        !use_slots_is_dense[i]) {
      merge_slots[static_cast<int>(i)] = true;
    }
  }
  CHECK(multi_output_channel_.size() != 0);  // NOLINT
  auto channel_data = paddle::framework::MakeChannel<Record>();
  VLOG(3) << "multi_output_channel_.size() " << multi_output_channel_.size();
  for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
    std::vector<Record> vec_data;
    multi_output_channel_[i]->Close();
    multi_output_channel_[i]->ReadAll(vec_data);
    channel_data->Write(std::move(vec_data));
    vec_data.clear();
    vec_data.shrink_to_fit();
    multi_output_channel_[i]->Clear();
  }
  channel_data->Close();
  std::vector<Record> recs;
  recs.reserve(channel_data->Size());
  channel_data->ReadAll(recs);
  channel_data->Clear();
  std::sort(recs.begin(), recs.end(), [](const Record& a, const Record& b) {
    return a.ins_id_ < b.ins_id_;
  });

  // Order features by (sign, slot) so duplicates are adjacent for std::unique.
  auto sort_cmp_uint64 = [](const FeatureItem& a, const FeatureItem& b) {
    auto& a_sign = a.sign().uint64_feasign_;
    auto& b_sign = b.sign().uint64_feasign_;
    return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
  };
  auto sort_cmp_float = [](const FeatureItem& a, const FeatureItem& b) {
    auto& a_sign = a.sign().float_feasign_;
    auto& b_sign = b.sign().float_feasign_;
    return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
  };
  // Two features are duplicates when sign and slot match. Features of the same
  // non-merge slot also compare equal.
  auto unique_eq_uint64 = [&merge_slots](const FeatureItem& a,
                                         const FeatureItem& b) {
    if (a.slot() == b.slot() &&
        merge_slots.find(a.slot()) == merge_slots.end()) {
      return true;
    }
    auto& a_sign = a.sign().uint64_feasign_;
    auto& b_sign = b.sign().uint64_feasign_;
    return a_sign == b_sign && a.slot() == b.slot();
  };
  auto unique_eq_float = [&merge_slots](const FeatureItem& a,
                                        const FeatureItem& b) {
    if (a.slot() == b.slot() &&
        merge_slots.find(a.slot()) == merge_slots.end()) {
      return true;
    }
    auto& a_sign = a.sign().float_feasign_;
    auto& b_sign = b.sign().float_feasign_;
    return a_sign == b_sign && a.slot() == b.slot();
  };

  std::vector<Record> results;
  VLOG(3) << "recs.size() " << recs.size();
  for (size_t i = 0; i < recs.size();) {
    size_t j = i + 1;
    while (j < recs.size() && recs[j].ins_id_ == recs[i].ins_id_) {
      j++;
    }
    // Compare in a signed type. A negative min_merge_size_ converted to size_t
    // would become huge and classify every run as too small to merge.
    if (static_cast<int64_t>(j - i) < min_merge_size_) {
      if (keep_unmerged_ins_) {
        for (size_t k = i; k < j; ++k) {
          results.push_back(std::move(recs[k]));
        }
      }
      i = j;
      continue;
    }

    std::vector<FeatureItem> merge_uint64_feasigns;
    std::vector<FeatureItem> merge_float_feasigns;
    Record rec = std::move(recs[i]);

    for (size_t k = i + 1; k < j; k++) {
      for (auto& feature : recs[k].uint64_feasigns_) {
        if (merge_slots.find(feature.slot()) != merge_slots.end()) {
          merge_uint64_feasigns.push_back(std::move(feature));
        }
      }
      for (auto& feature : recs[k].float_feasigns_) {
        if (merge_slots.find(feature.slot()) != merge_slots.end()) {
          merge_float_feasigns.push_back(std::move(feature));
        }
      }
      // Free the merged record's memory right away.
      recs[k] = Record();
    }
    i = j;

    if (!erase_duplicate_feas_) {
      rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
                                  merge_uint64_feasigns.begin(),
                                  merge_uint64_feasigns.end());
      rec.float_feasigns_.insert(rec.float_feasigns_.end(),
                                 merge_float_feasigns.begin(),
                                 merge_float_feasigns.end());
    } else {
      std::vector<FeatureItem> not_merge_uint64_feasigns;
      std::vector<FeatureItem> not_merge_float_feasigns;

      for (auto& feature : rec.uint64_feasigns_) {
        if (merge_slots.find(feature.slot()) != merge_slots.end()) {
          merge_uint64_feasigns.push_back(std::move(feature));
        } else {
          not_merge_uint64_feasigns.push_back(std::move(feature));
        }
      }
      for (auto& feature : rec.float_feasigns_) {
        if (merge_slots.find(feature.slot()) != merge_slots.end()) {
          merge_float_feasigns.push_back(std::move(feature));
        } else {
          not_merge_float_feasigns.push_back(std::move(feature));
        }
      }
      rec.uint64_feasigns_.clear();
      rec.float_feasigns_.clear();

      // erase duplicate uint64 feasigns
      std::sort(merge_uint64_feasigns.begin(), merge_uint64_feasigns.end(),
                sort_cmp_uint64);
      merge_uint64_feasigns.erase(
          std::unique(merge_uint64_feasigns.begin(),
                      merge_uint64_feasigns.end(), unique_eq_uint64),
          merge_uint64_feasigns.end());
      rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
                                  merge_uint64_feasigns.begin(),
                                  merge_uint64_feasigns.end());
      rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
                                  not_merge_uint64_feasigns.begin(),
                                  not_merge_uint64_feasigns.end());

      // erase duplicate float feasigns
      std::sort(merge_float_feasigns.begin(), merge_float_feasigns.end(),
                sort_cmp_float);
      merge_float_feasigns.erase(
          std::unique(merge_float_feasigns.begin(), merge_float_feasigns.end(),
                      unique_eq_float),
          merge_float_feasigns.end());
      rec.float_feasigns_.insert(rec.float_feasigns_.end(),
                                 merge_float_feasigns.begin(),
                                 merge_float_feasigns.end());
      rec.float_feasigns_.insert(rec.float_feasigns_.end(),
                                 not_merge_float_feasigns.begin(),
                                 not_merge_float_feasigns.end());
    }
    // `rec` is not used again; move instead of copying all its features.
    results.push_back(std::move(rec));
  }
  VLOG(3) << "results size " << results.size();
  results.shrink_to_fit();

  auto fleet_ptr = FleetWrapper::GetInstance();
  std::shuffle(results.begin(), results.end(), fleet_ptr->LocalRandomEngine());
  channel_data->Open();
  channel_data->Write(std::move(results));
  channel_data->Close();
  results.clear();
  results.shrink_to_fit();
  VLOG(3) << "channel data size " << channel_data->Size();
  // One Read per output channel below; this block size lets channel_num_
  // reads drain everything.
  channel_data->SetBlockSize(channel_data->Size() / channel_num_ + 1);
  VLOG(3) << "channel data block size " << channel_data->BlockSize();
  for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
    std::vector<Record> vec_data;
    channel_data->Read(vec_data);
    multi_output_channel_[i]->Open();
    multi_output_channel_[i]->Write(std::move(vec_data));
    vec_data.clear();
    vec_data.shrink_to_fit();
  }
  CHECK(channel_data->Size() == 0);  // NOLINT
  channel_data->Clear();
  VLOG(3) << "MultiSlotDataset::MergeByInsId end";
}

D
dongdaxiang 已提交
649 650
}  // end namespace framework
}  // end namespace paddle