/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/framework/data_set.h"
// NOTE(review): this translation unit reached review with every angle-bracket
// token stripped (include targets, template parameter lists, container
// element types).  The four standard headers below are reconstructed from
// actual usage in this file (std::shuffle/std::sort, random engines,
// unordered containers) -- confirm against the upstream file.
#include <algorithm>
#include <random>
#include <unordered_map>
#include <unordered_set>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/timer.h"
#include "xxhash.h"  // NOLINT

#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif

namespace paddle {
namespace framework {

// constructor: initialize every scalar member to its documented default.
template <typename T>
DatasetImpl<T>::DatasetImpl() {
  VLOG(3) << "DatasetImpl::DatasetImpl() constructor";
  thread_num_ = 1;
  trainer_num_ = 1;
  channel_num_ = 1;
  file_idx_ = 0;
  cur_channel_ = 0;
  fleet_send_batch_size_ = 1024;
  fleet_send_sleep_seconds_ = 0;
  merge_by_insid_ = false;
  merge_size_ = 2;
  parse_ins_id_ = false;
  parse_content_ = false;
  preload_thread_num_ = 0;
  global_index_ = 0;
}

// set filelist, file_idx_ will reset to zero.
template <typename T>
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
  VLOG(3) << "filelist size: " << filelist.size();
  filelist_ = filelist;
  file_idx_ = 0;
}

// set expect thread num. actually it may change
template <typename T>
void DatasetImpl<T>::SetThreadNum(int thread_num) {
  VLOG(3) << "SetThreadNum thread_num=" << thread_num;
  thread_num_ = thread_num;
}

// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerNum
template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
  trainer_num_ = trainer_num;
}

// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetFleetSendBatchSize
template <typename T>
void DatasetImpl<T>::SetFleetSendBatchSize(int64_t size) {
  fleet_send_batch_size_ = size;
}

// Record the HDFS name/ugi pair and register the resulting hadoop command
// line with the filesystem helper so readers can open remote files.
template <typename T>
void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
                                   const std::string& fs_ugi) {
  fs_name_ = fs_name;
  fs_ugi_ = fs_ugi;
  std::string cmd = std::string("hadoop fs");
  cmd += " -D fs.default.name=" + fs_name;
  cmd += " -D hadoop.job.ugi=" + fs_ugi;
  paddle::framework::hdfs_set_command(cmd);
}

// Parse the text-format protobuf describing the data feed layout.
template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
  google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
                                                &data_feed_desc_);
}

template <typename T>
void DatasetImpl<T>::SetChannelNum(int channel_num) {
  channel_num_ = channel_num;
}

template <typename T>
void DatasetImpl<T>::SetParseInsId(bool parse_ins_id) {
  parse_ins_id_ = parse_ins_id;
}

template <typename T>
void DatasetImpl<T>::SetParseContent(bool parse_content) {
  parse_content_ = parse_content;
}

// Enabling merge-by-ins-id implies instance ids must be parsed as well.
template <typename T>
void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
  merge_by_insid_ = true;
  parse_ins_id_ = true;
  merge_size_ = merge_size;
}

// Toggle feature-evaluation (slots shuffle) mode and size its record cache.
template <typename T>
void DatasetImpl<T>::SetFeaEval(bool fea_eval, int record_candidate_size) {
  slots_shuffle_fea_eval_ = fea_eval;
  slots_shuffle_rclist_.ReSize(record_candidate_size);
  VLOG(3) << "SetFeaEval fea eval mode: " << fea_eval
          << " with record candidate size: " << record_candidate_size;
}

// Return non-owning pointers to the readers; ownership stays with readers_.
template <typename T>
std::vector<paddle::framework::DataFeed*> DatasetImpl<T>::GetReaders() {
  std::vector<paddle::framework::DataFeed*> ret;
  ret.reserve(readers_.size());
  for (auto i : readers_) {
    ret.push_back(i.get());
  }
  return ret;
}
template void DatasetImpl::CreateChannel() { if (input_channel_ == nullptr) { input_channel_ = paddle::framework::MakeChannel(); } if (multi_output_channel_.size() == 0) { multi_output_channel_.reserve(channel_num_); for (int i = 0; i < channel_num_; ++i) { multi_output_channel_.push_back(paddle::framework::MakeChannel()); } } if (multi_consume_channel_.size() == 0) { multi_consume_channel_.reserve(channel_num_); for (int i = 0; i < channel_num_; ++i) { multi_consume_channel_.push_back(paddle::framework::MakeChannel()); } } } // if sent message between workers, should first call this function template void DatasetImpl::RegisterClientToClientMsgHandler() { auto fleet_ptr = FleetWrapper::GetInstance(); VLOG(3) << "RegisterClientToClientMsgHandler"; fleet_ptr->RegisterClientToClientMsgHandler( 0, [this](int msg_type, int client_id, const std::string& msg) -> int { return this->ReceiveFromClient(msg_type, client_id, msg); }); VLOG(3) << "RegisterClientToClientMsgHandler done"; } // load data into memory, Dataset hold this memory, // which will later be fed into readers' channel template void DatasetImpl::LoadIntoMemory() { VLOG(3) << "DatasetImpl::LoadIntoMemory() begin"; platform::Timer timeline; timeline.Start(); std::vector load_threads; for (int64_t i = 0; i < thread_num_; ++i) { load_threads.push_back(std::thread( &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get())); } for (std::thread& t : load_threads) { t.join(); } input_channel_->Close(); int64_t in_chan_size = input_channel_->Size(); input_channel_->SetBlockSize(in_chan_size / thread_num_ + 1); timeline.Pause(); VLOG(3) << "DatasetImpl::LoadIntoMemory() end" << ", memory data size=" << input_channel_->Size() << ", cost time=" << timeline.ElapsedSec() << " seconds"; } template void DatasetImpl::PreLoadIntoMemory() { VLOG(3) << "DatasetImpl::PreLoadIntoMemory() begin"; if (preload_thread_num_ != 0) { CHECK(preload_thread_num_ == preload_readers_.size()); preload_threads_.clear(); for (int64_t i = 
0; i < preload_thread_num_; ++i) { preload_threads_.push_back( std::thread(&paddle::framework::DataFeed::LoadIntoMemory, preload_readers_[i].get())); } } else { CHECK(thread_num_ == readers_.size()); preload_threads_.clear(); for (int64_t i = 0; i < thread_num_; ++i) { preload_threads_.push_back(std::thread( &paddle::framework::DataFeed::LoadIntoMemory, readers_[i].get())); } } VLOG(3) << "DatasetImpl::PreLoadIntoMemory() end"; } template void DatasetImpl::WaitPreLoadDone() { VLOG(3) << "DatasetImpl::WaitPreLoadDone() begin"; for (std::thread& t : preload_threads_) { t.join(); } input_channel_->Close(); int64_t in_chan_size = input_channel_->Size(); input_channel_->SetBlockSize(in_chan_size / thread_num_ + 1); VLOG(3) << "DatasetImpl::WaitPreLoadDone() end"; } // release memory data template void DatasetImpl::ReleaseMemory() { VLOG(3) << "DatasetImpl::ReleaseMemory() begin"; if (input_channel_) { input_channel_->Clear(); input_channel_ = nullptr; } for (size_t i = 0; i < multi_output_channel_.size(); ++i) { if (!multi_output_channel_[i]) { continue; } multi_output_channel_[i]->Clear(); multi_output_channel_[i] = nullptr; } std::vector>().swap(multi_output_channel_); for (size_t i = 0; i < multi_consume_channel_.size(); ++i) { if (!multi_consume_channel_[i]) { continue; } multi_consume_channel_[i]->Clear(); multi_consume_channel_[i] = nullptr; } std::vector>().swap(multi_consume_channel_); std::vector>().swap(readers_); VLOG(3) << "DatasetImpl::ReleaseMemory() end"; } // do local shuffle template void DatasetImpl::LocalShuffle() { VLOG(3) << "DatasetImpl::LocalShuffle() begin"; platform::Timer timeline; timeline.Start(); if (!input_channel_ || input_channel_->Size() == 0) { VLOG(3) << "DatasetImpl::LocalShuffle() end, no data to shuffle"; return; } auto fleet_ptr = FleetWrapper::GetInstance(); input_channel_->Close(); std::vector data; input_channel_->ReadAll(data); std::shuffle(data.begin(), data.end(), fleet_ptr->LocalRandomEngine()); input_channel_->Open(); 
input_channel_->Write(std::move(data)); data.clear(); data.shrink_to_fit(); input_channel_->Close(); timeline.Pause(); VLOG(3) << "DatasetImpl::LocalShuffle() end, cost time=" << timeline.ElapsedSec() << " seconds"; } template void DatasetImpl::GlobalShuffle(int thread_num) { VLOG(3) << "DatasetImpl::GlobalShuffle() begin"; platform::Timer timeline; timeline.Start(); auto fleet_ptr = FleetWrapper::GetInstance(); if (!input_channel_ || input_channel_->Size() == 0) { VLOG(3) << "DatasetImpl::GlobalShuffle() end, no data to shuffle"; return; } // local shuffle input_channel_->Close(); std::vector data; input_channel_->ReadAll(data); std::shuffle(data.begin(), data.end(), fleet_ptr->LocalRandomEngine()); input_channel_->Open(); input_channel_->Write(std::move(data)); data.clear(); data.shrink_to_fit(); input_channel_->Close(); input_channel_->SetBlockSize(fleet_send_batch_size_); VLOG(3) << "DatasetImpl::GlobalShuffle() input_channel_ size " << input_channel_->Size(); auto get_client_id = [this, fleet_ptr](const T& data) -> size_t { if (!this->merge_by_insid_) { return fleet_ptr->LocalRandomEngine()() % this->trainer_num_; } else { return XXH64(data.ins_id_.data(), data.ins_id_.length(), 0) % this->trainer_num_; } }; auto global_shuffle_func = [this, get_client_id]() { auto fleet_ptr = FleetWrapper::GetInstance(); std::vector data; while (this->input_channel_->Read(data)) { std::vector ars(this->trainer_num_); for (auto& t : data) { auto client_id = get_client_id(t); ars[client_id] << t; } std::vector> total_status; std::vector send_index(this->trainer_num_); for (int i = 0; i < this->trainer_num_; ++i) { send_index[i] = i; } std::shuffle(send_index.begin(), send_index.end(), fleet_ptr->LocalRandomEngine()); for (auto index = 0u; index < this->trainer_num_; ++index) { int i = send_index[index]; if (ars[i].Length() == 0) { continue; } std::string msg(ars[i].Buffer(), ars[i].Length()); auto ret = fleet_ptr->SendClientToClientMsg(0, i, msg); 
total_status.push_back(std::move(ret)); } for (auto& t : total_status) { t.wait(); } ars.clear(); ars.shrink_to_fit(); data.clear(); data.shrink_to_fit(); // currently we find bottleneck is server not able to handle large data // in time, so we can remove this sleep and set fleet_send_batch_size to // 1024, and set server thread to 24. if (fleet_send_sleep_seconds_ != 0) { sleep(this->fleet_send_sleep_seconds_); } } }; std::vector global_shuffle_threads; if (thread_num == -1) { thread_num = thread_num_; } VLOG(3) << "start global shuffle threads, num = " << thread_num; for (int i = 0; i < thread_num; ++i) { global_shuffle_threads.push_back(std::thread(global_shuffle_func)); } for (std::thread& t : global_shuffle_threads) { t.join(); } global_shuffle_threads.clear(); global_shuffle_threads.shrink_to_fit(); input_channel_->Clear(); timeline.Pause(); VLOG(3) << "DatasetImpl::GlobalShuffle() end, cost time=" << timeline.ElapsedSec() << " seconds"; } template void DatasetImpl::DynamicAdjustChannelNum(int channel_num) { if (channel_num_ == channel_num) { VLOG(3) << "DatasetImpl::DynamicAdjustChannelNum channel_num_=" << channel_num_ << ", channel_num_=channel_num, no need to adjust"; return; } VLOG(3) << "adjust channel num from " << channel_num_ << " to " << channel_num; channel_num_ = channel_num; std::vector>* origin_channels = nullptr; std::vector>* other_channels = nullptr; // find out which channel (output or consume) has data int cur_channel = 0; uint64_t output_channels_data_size = 0; uint64_t consume_channels_data_size = 0; CHECK(multi_output_channel_.size() == multi_consume_channel_.size()); for (int i = 0; i < multi_output_channel_.size(); ++i) { output_channels_data_size += multi_output_channel_[i]->Size(); consume_channels_data_size += multi_consume_channel_[i]->Size(); } if (output_channels_data_size != 0) { CHECK(consume_channels_data_size == 0); // NOLINT cur_channel = 0; } else { CHECK(output_channels_data_size == 0); // NOLINT cur_channel = 1; } if 
(cur_channel == 0) { origin_channels = &multi_output_channel_; other_channels = &multi_consume_channel_; } else { origin_channels = &multi_consume_channel_; other_channels = &multi_output_channel_; } CHECK(origin_channels != nullptr); // NOLINT CHECK(other_channels != nullptr); // NOLINT paddle::framework::Channel total_data_channel = paddle::framework::MakeChannel(); std::vector> new_channels; std::vector> new_other_channels; std::vector local_vec; for (int i = 0; i < origin_channels->size(); ++i) { local_vec.clear(); (*origin_channels)[i]->Close(); (*origin_channels)[i]->ReadAll(local_vec); total_data_channel->Write(std::move(local_vec)); } total_data_channel->Close(); total_data_channel->SetBlockSize(total_data_channel->Size() / channel_num + 1); for (int i = 0; i < channel_num; ++i) { local_vec.clear(); total_data_channel->Read(local_vec); new_other_channels.push_back(paddle::framework::MakeChannel()); new_channels.push_back(paddle::framework::MakeChannel()); new_channels[i]->Write(std::move(local_vec)); } total_data_channel->Clear(); origin_channels->clear(); other_channels->clear(); *origin_channels = new_channels; *other_channels = new_other_channels; new_channels.clear(); new_other_channels.clear(); std::vector>().swap(new_channels); std::vector>().swap(new_other_channels); local_vec.clear(); std::vector().swap(local_vec); VLOG(3) << "adjust channel num done"; } template void DatasetImpl::DynamicAdjustReadersNum(int thread_num) { if (thread_num_ == thread_num) { VLOG(3) << "DatasetImpl::DynamicAdjustReadersNum thread_num_=" << thread_num_ << ", thread_num_=thread_num, no need to adjust"; return; } VLOG(3) << "adjust readers num from " << thread_num_ << " to " << thread_num; thread_num_ = thread_num; std::vector>().swap(readers_); CreateReaders(); VLOG(3) << "adjust readers num done"; } template void DatasetImpl::SetFleetSendSleepSeconds(int seconds) { fleet_send_sleep_seconds_ = seconds; } template void DatasetImpl::CreateReaders() { VLOG(3) << "Calling 
CreateReaders()"; VLOG(3) << "thread num in Dataset: " << thread_num_; VLOG(3) << "Filelist size in Dataset: " << filelist_.size(); VLOG(3) << "channel num in Dataset: " << channel_num_; CHECK(thread_num_ > 0) << "thread num should > 0"; CHECK(channel_num_ > 0) << "channel num should > 0"; CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num"; VLOG(3) << "readers size: " << readers_.size(); if (readers_.size() != 0) { VLOG(3) << "readers_.size() = " << readers_.size() << ", will not create again"; return; } VLOG(3) << "data feed class name: " << data_feed_desc_.name(); int channel_idx = 0; for (int i = 0; i < thread_num_; ++i) { readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name())); readers_[i]->Init(data_feed_desc_); readers_[i]->SetThreadId(i); readers_[i]->SetThreadNum(thread_num_); readers_[i]->SetFileListMutex(&mutex_for_pick_file_); readers_[i]->SetFileListIndex(&file_idx_); readers_[i]->SetFileList(filelist_); readers_[i]->SetParseInsId(parse_ins_id_); readers_[i]->SetParseContent(parse_content_); if (input_channel_ != nullptr) { readers_[i]->SetInputChannel(input_channel_.get()); } if (cur_channel_ == 0 && channel_idx < multi_output_channel_.size()) { readers_[i]->SetOutputChannel(multi_output_channel_[channel_idx].get()); readers_[i]->SetConsumeChannel(multi_consume_channel_[channel_idx].get()); } else if (channel_idx < multi_output_channel_.size()) { readers_[i]->SetOutputChannel(multi_consume_channel_[channel_idx].get()); readers_[i]->SetConsumeChannel(multi_output_channel_[channel_idx].get()); } ++channel_idx; if (channel_idx >= channel_num_) { channel_idx = 0; } } VLOG(3) << "readers size: " << readers_.size(); } template void DatasetImpl::DestroyReaders() { VLOG(3) << "Calling DestroyReaders()"; VLOG(3) << "readers size1: " << readers_.size(); std::vector>().swap(readers_); VLOG(3) << "readers size: " << readers_.size(); file_idx_ = 0; cur_channel_ = 1 - cur_channel_; } template void 
DatasetImpl::SetPreLoadThreadNum(int thread_num) { preload_thread_num_ = thread_num; } template void DatasetImpl::CreatePreLoadReaders() { VLOG(3) << "Begin CreatePreLoadReaders"; if (preload_thread_num_ == 0) { preload_thread_num_ = thread_num_; } CHECK(preload_thread_num_ > 0) << "thread num should > 0"; CHECK(input_channel_ != nullptr); preload_readers_.clear(); for (int i = 0; i < preload_thread_num_; ++i) { preload_readers_.push_back( DataFeedFactory::CreateDataFeed(data_feed_desc_.name())); preload_readers_[i]->Init(data_feed_desc_); preload_readers_[i]->SetThreadId(i); preload_readers_[i]->SetThreadNum(preload_thread_num_); preload_readers_[i]->SetFileListMutex(&mutex_for_pick_file_); preload_readers_[i]->SetFileListIndex(&file_idx_); preload_readers_[i]->SetFileList(filelist_); preload_readers_[i]->SetParseInsId(parse_ins_id_); preload_readers_[i]->SetParseContent(parse_content_); preload_readers_[i]->SetInputChannel(input_channel_.get()); preload_readers_[i]->SetOutputChannel(nullptr); preload_readers_[i]->SetConsumeChannel(nullptr); } VLOG(3) << "End CreatePreLoadReaders"; } template void DatasetImpl::DestroyPreLoadReaders() { VLOG(3) << "Begin DestroyPreLoadReaders"; preload_readers_.clear(); std::vector>().swap( preload_readers_); file_idx_ = 0; VLOG(3) << "End DestroyPreLoadReaders"; } template int64_t DatasetImpl::GetMemoryDataSize() { return input_channel_->Size(); } template int64_t DatasetImpl::GetShuffleDataSize() { int64_t sum = 0; for (size_t i = 0; i < multi_output_channel_.size(); ++i) { sum += multi_output_channel_[i]->Size() + multi_consume_channel_[i]->Size(); } return sum; } template int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, const std::string& msg) { #ifdef _LINUX VLOG(3) << "ReceiveFromClient msg_type=" << msg_type << ", client_id=" << client_id << ", msg length=" << msg.length(); if (msg.length() == 0) { return 0; } paddle::framework::BinaryArchive ar; ar.SetReadBuffer(const_cast(msg.c_str()), msg.length(), 
nullptr); if (ar.Cursor() == ar.Finish()) { return 0; } std::vector data; while (ar.Cursor() < ar.Finish()) { data.push_back(ar.Get()); } CHECK(ar.Cursor() == ar.Finish()); auto fleet_ptr = FleetWrapper::GetInstance(); // not use random because it doesn't perform well here. // to make sure each channel get data equally, we just put data to // channel one by one. // int64_t index = fleet_ptr->LocalRandomEngine()() % channel_num_; int64_t index = 0; { std::unique_lock lk(global_index_mutex_); index = global_index_++; } index = index % channel_num_; VLOG(3) << "ramdom index=" << index; multi_output_channel_[index]->Write(std::move(data)); data.clear(); data.shrink_to_fit(); #endif return 0; } // explicit instantiation template class DatasetImpl; void MultiSlotDataset::MergeByInsId() { VLOG(3) << "MultiSlotDataset::MergeByInsId begin"; if (!merge_by_insid_) { VLOG(3) << "merge_by_insid=false, will not MergeByInsId"; return; } auto multi_slot_desc = data_feed_desc_.multi_slot_desc(); std::vector use_slots; for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) { const auto& slot = multi_slot_desc.slots(i); if (slot.is_used()) { use_slots.push_back(slot.name()); } } CHECK(multi_output_channel_.size() != 0); // NOLINT auto channel_data = paddle::framework::MakeChannel(); VLOG(3) << "multi_output_channel_.size() " << multi_output_channel_.size(); for (size_t i = 0; i < multi_output_channel_.size(); ++i) { std::vector vec_data; multi_output_channel_[i]->Close(); multi_output_channel_[i]->ReadAll(vec_data); channel_data->Write(std::move(vec_data)); vec_data.clear(); vec_data.shrink_to_fit(); multi_output_channel_[i]->Clear(); } channel_data->Close(); std::vector recs; recs.reserve(channel_data->Size()); channel_data->ReadAll(recs); channel_data->Clear(); std::sort(recs.begin(), recs.end(), [](const Record& a, const Record& b) { return a.ins_id_ < b.ins_id_; }); std::vector results; uint64_t drop_ins_num = 0; std::unordered_set all_int64; std::unordered_set all_float; 
std::unordered_set local_uint64; std::unordered_set local_float; VLOG(3) << "recs.size() " << recs.size(); for (size_t i = 0; i < recs.size();) { size_t j = i + 1; while (j < recs.size() && recs[j].ins_id_ == recs[i].ins_id_) { j++; } if (merge_size_ > 0 && j - i != merge_size_) { drop_ins_num += j - i; LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i << ", because merge_size=" << merge_size_; i = j; continue; } all_int64.clear(); all_float.clear(); bool has_conflict_slot = false; uint16_t conflict_slot = 0; Record rec; rec.ins_id_ = recs[i].ins_id_; rec.content_ = recs[i].content_; for (size_t k = i; k < j; k++) { local_uint64.clear(); local_float.clear(); for (auto& feature : recs[k].uint64_feasigns_) { uint16_t slot = feature.slot(); if (all_int64.find(slot) != all_int64.end()) { has_conflict_slot = true; conflict_slot = slot; break; } local_uint64.insert(slot); rec.uint64_feasigns_.push_back(std::move(feature)); } if (has_conflict_slot) { break; } all_int64.insert(local_uint64.begin(), local_uint64.end()); for (auto& feature : recs[k].float_feasigns_) { uint16_t slot = feature.slot(); if (all_float.find(slot) != all_float.end()) { has_conflict_slot = true; conflict_slot = slot; break; } local_float.insert(slot); rec.float_feasigns_.push_back(std::move(feature)); } if (has_conflict_slot) { break; } all_float.insert(local_float.begin(), local_float.end()); } if (has_conflict_slot) { LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i << ", because conflict_slot=" << use_slots[conflict_slot]; drop_ins_num += j - i; } else { results.push_back(std::move(rec)); } i = j; } std::vector().swap(recs); VLOG(3) << "results size " << results.size(); LOG(WARNING) << "total drop ins num: " << drop_ins_num; results.shrink_to_fit(); auto fleet_ptr = FleetWrapper::GetInstance(); std::shuffle(results.begin(), results.end(), fleet_ptr->LocalRandomEngine()); channel_data->Open(); channel_data->Write(std::move(results)); channel_data->Close(); 
results.clear(); results.shrink_to_fit(); VLOG(3) << "channel data size " << channel_data->Size(); channel_data->SetBlockSize(channel_data->Size() / channel_num_ + 1); VLOG(3) << "channel data block size " << channel_data->BlockSize(); for (size_t i = 0; i < multi_output_channel_.size(); ++i) { std::vector vec_data; channel_data->Read(vec_data); multi_output_channel_[i]->Open(); multi_output_channel_[i]->Write(std::move(vec_data)); vec_data.clear(); vec_data.shrink_to_fit(); } CHECK(channel_data->Size() == 0); // NOLINT channel_data->Clear(); VLOG(3) << "MultiSlotDataset::MergeByInsId end"; } void MultiSlotDataset::GetRandomData(const std::set& slots_to_replace, std::vector* result) { int debug_erase_cnt = 0; int debug_push_cnt = 0; auto multi_slot_desc = data_feed_desc_.multi_slot_desc(); slots_shuffle_rclist_.ReInit(); for (const auto& rec : slots_shuffle_original_data_) { RecordCandidate rand_rec; Record new_rec = rec; slots_shuffle_rclist_.AddAndGet(rec, &rand_rec); for (auto it = new_rec.uint64_feasigns_.begin(); it != new_rec.uint64_feasigns_.end();) { if (slots_to_replace.find(it->slot()) != slots_to_replace.end()) { it = new_rec.uint64_feasigns_.erase(it); debug_erase_cnt += 1; } else { ++it; } } for (auto slot : slots_to_replace) { auto range = rand_rec.feas.equal_range(slot); for (auto it = range.first; it != range.second; ++it) { new_rec.uint64_feasigns_.push_back({it->second, it->first}); debug_push_cnt += 1; } } result->push_back(std::move(new_rec)); } VLOG(2) << "erase feasign num: " << debug_erase_cnt << " repush feasign num: " << debug_push_cnt; } // slots shuffle to input_channel_ with needed-shuffle slots void MultiSlotDataset::SlotsShuffle( const std::set& slots_to_replace) { int out_channel_size = 0; if (cur_channel_ == 0) { for (size_t i = 0; i < multi_output_channel_.size(); ++i) { out_channel_size += multi_output_channel_[i]->Size(); } } else { for (size_t i = 0; i < multi_consume_channel_.size(); ++i) { out_channel_size += 
multi_consume_channel_[i]->Size(); } } VLOG(2) << "DatasetImpl::SlotsShuffle() begin with input channel size: " << input_channel_->Size() << " output channel size: " << out_channel_size; if (!slots_shuffle_fea_eval_) { VLOG(3) << "DatasetImpl::SlotsShuffle() end," "fea eval mode off, need to set on for slots shuffle"; return; } if ((!input_channel_ || input_channel_->Size() == 0) && slots_shuffle_original_data_.size() == 0 && out_channel_size == 0) { VLOG(3) << "DatasetImpl::SlotsShuffle() end, no data to slots shuffle"; return; } platform::Timer timeline; timeline.Start(); auto multi_slot_desc = data_feed_desc_.multi_slot_desc(); std::set index_slots; for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) { std::string cur_slot = multi_slot_desc.slots(i).name(); if (slots_to_replace.find(cur_slot) != slots_to_replace.end()) { index_slots.insert(i); } } if (slots_shuffle_original_data_.size() == 0) { // before first slots shuffle, instances could be in // input_channel, oupput_channel or consume_channel if (input_channel_ && input_channel_->Size() != 0) { slots_shuffle_original_data_.reserve(input_channel_->Size()); input_channel_->Close(); input_channel_->ReadAll(slots_shuffle_original_data_); } else { CHECK(out_channel_size > 0); // NOLINT if (cur_channel_ == 0) { for (size_t i = 0; i < multi_output_channel_.size(); ++i) { std::vector vec_data; multi_output_channel_[i]->Close(); multi_output_channel_[i]->ReadAll(vec_data); slots_shuffle_original_data_.reserve( slots_shuffle_original_data_.size() + vec_data.size()); slots_shuffle_original_data_.insert( slots_shuffle_original_data_.end(), std::make_move_iterator(vec_data.begin()), std::make_move_iterator(vec_data.end())); vec_data.clear(); vec_data.shrink_to_fit(); multi_output_channel_[i]->Clear(); } } else { for (size_t i = 0; i < multi_consume_channel_.size(); ++i) { std::vector vec_data; multi_consume_channel_[i]->Close(); multi_consume_channel_[i]->ReadAll(vec_data); slots_shuffle_original_data_.reserve( 
slots_shuffle_original_data_.size() + vec_data.size()); slots_shuffle_original_data_.insert( slots_shuffle_original_data_.end(), std::make_move_iterator(vec_data.begin()), std::make_move_iterator(vec_data.end())); vec_data.clear(); vec_data.shrink_to_fit(); multi_consume_channel_[i]->Clear(); } } } } else { // if already have original data for slots shuffle, clear channel input_channel_->Clear(); if (cur_channel_ == 0) { for (size_t i = 0; i < multi_output_channel_.size(); ++i) { if (!multi_output_channel_[i]) { continue; } multi_output_channel_[i]->Clear(); } } else { for (size_t i = 0; i < multi_consume_channel_.size(); ++i) { if (!multi_consume_channel_[i]) { continue; } multi_consume_channel_[i]->Clear(); } } } int end_size = 0; if (cur_channel_ == 0) { for (size_t i = 0; i < multi_output_channel_.size(); ++i) { if (!multi_output_channel_[i]) { continue; } end_size += multi_output_channel_[i]->Size(); } } else { for (size_t i = 0; i < multi_consume_channel_.size(); ++i) { if (!multi_consume_channel_[i]) { continue; } end_size += multi_consume_channel_[i]->Size(); } } CHECK(input_channel_->Size() == 0) << "input channel should be empty before slots shuffle"; std::vector random_data; random_data.clear(); // get slots shuffled random_data GetRandomData(index_slots, &random_data); input_channel_->Open(); input_channel_->Write(std::move(random_data)); random_data.clear(); random_data.shrink_to_fit(); input_channel_->Close(); timeline.Pause(); VLOG(2) << "DatasetImpl::SlotsShuffle() end" << ", memory data size for slots shuffle=" << input_channel_->Size() << ", cost time=" << timeline.ElapsedSec() << " seconds"; } } // end namespace framework } // end namespace paddle