From 79bd5f90f304c239f2b51778c977648016174381 Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Wed, 29 Sep 2021 14:59:53 +0800 Subject: [PATCH] add slot record dataset (#36200) --- paddle/fluid/framework/channel.h | 20 +- paddle/fluid/framework/data_feed.cc | 112 +++++++- paddle/fluid/framework/data_feed.h | 317 +++++++++++++++++++++- paddle/fluid/framework/data_set.cc | 166 +++++++++-- paddle/fluid/framework/data_set.h | 40 ++- paddle/fluid/framework/dataset_factory.cc | 3 +- paddle/fluid/platform/flags.cc | 8 + paddle/fluid/pybind/data_set_py.cc | 2 - 8 files changed, 622 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h index 503f1513aad..80fee94f1c8 100644 --- a/paddle/fluid/framework/channel.h +++ b/paddle/fluid/framework/channel.h @@ -157,7 +157,19 @@ class ChannelObject { p.resize(finished); return finished; } + // read once only + size_t ReadOnce(std::vector& p, size_t size) { // NOLINT + if (size == 0) { + return 0; + } + std::unique_lock lock(mutex_); + p.resize(size); + size_t finished = Read(size, &p[0], lock, true); + p.resize(finished); + Notify(); + return finished; + } size_t ReadAll(std::vector& p) { // NOLINT p.clear(); size_t finished = 0; @@ -241,17 +253,21 @@ class ChannelObject { return !closed_; } - size_t Read(size_t n, T* p, std::unique_lock& lock) { // NOLINT + size_t Read(size_t n, T* p, std::unique_lock& lock, // NOLINT + bool once = false) { // NOLINT size_t finished = 0; CHECK(n <= MaxCapacity() - reading_count_); reading_count_ += n; while (finished < n && WaitForRead(lock)) { - size_t m = std::min(n - finished, data_.size()); + size_t m = (std::min)(n - finished, data_.size()); for (size_t i = 0; i < m; i++) { p[finished++] = std::move(data_.front()); data_.pop_front(); } reading_count_ -= m; + if (once && m > 0) { + break; + } } reading_count_ -= n - finished; return finished; diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index fdb24ee18ec..4463fd9fd53 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -36,6 +36,107 @@ DLManager& global_dlmanager_pool() { return manager; } +class BufferedLineFileReader { + typedef std::function SampleFunc; + static const int MAX_FILE_BUFF_SIZE = 4 * 1024 * 1024; + class FILEReader { + public: + explicit FILEReader(FILE* fp) : fp_(fp) {} + int read(char* buf, int len) { return fread(buf, sizeof(char), len, fp_); } + + private: + FILE* fp_; + }; + + public: + typedef std::function LineFunc; + + private: + template + int read_lines(T* reader, LineFunc func, int skip_lines) { + int lines = 0; + size_t ret = 0; + char* ptr = NULL; + char* eol = NULL; + total_len_ = 0; + error_line_ = 0; + + SampleFunc spfunc = get_sample_func(); + std::string x; + while (!is_error() && (ret = reader->read(buff_, MAX_FILE_BUFF_SIZE)) > 0) { + total_len_ += ret; + ptr = buff_; + eol = reinterpret_cast(memchr(ptr, '\n', ret)); + while (eol != NULL) { + int size = static_cast((eol - ptr) + 1); + x.append(ptr, size - 1); + ++lines; + if (lines > skip_lines && spfunc()) { + if (!func(x)) { + ++error_line_; + } + } + + x.clear(); + ptr += size; + ret -= size; + eol = reinterpret_cast(memchr(ptr, '\n', ret)); + } + if (ret > 0) { + x.append(ptr, ret); + } + } + if (!is_error() && !x.empty()) { + ++lines; + if (lines > skip_lines && spfunc()) { + if (!func(x)) { + ++error_line_; + } + } + } + return lines; + } + + public: + BufferedLineFileReader() + : random_engine_(std::random_device()()), + uniform_distribution_(0.0f, 1.0f) { + total_len_ = 0; + sample_line_ = 0; + buff_ = + reinterpret_cast(calloc(MAX_FILE_BUFF_SIZE + 1, sizeof(char))); + } + ~BufferedLineFileReader() { free(buff_); } + + int read_file(FILE* fp, LineFunc func, int skip_lines) { + FILEReader reader(fp); + return read_lines(&reader, func, skip_lines); + } + uint64_t file_size(void) { return total_len_; } + void set_sample_rate(float r) { sample_rate_ = r; } + size_t get_sample_line() { return sample_line_; } + bool is_error(void) { return (error_line_ > 10); } + + private: + SampleFunc get_sample_func() { + if (std::abs(sample_rate_ - 1.0f) < 1e-5f) { + return [this](void) { return true; }; + } + return [this](void) { + return (uniform_distribution_(random_engine_) < sample_rate_); + }; + } + + private: + char* buff_ = nullptr; + uint64_t total_len_ = 0; + + std::default_random_engine random_engine_; + std::uniform_real_distribution uniform_distribution_; + float sample_rate_ = 1.0f; + size_t sample_line_ = 0; + size_t error_line_ = 0; +}; void RecordCandidateList::ReSize(size_t length) { mutex_.lock(); capacity_ = length; @@ -301,7 +402,7 @@ int InMemoryDataFeed::Next() { << ", thread_id=" << thread_id_; } } else { - VLOG(3) << "enable heter NEXT: " << offset_index_ + VLOG(3) << "enable heter next: " << offset_index_ << " batch_offsets: " << batch_offsets_.size(); if (offset_index_ >= batch_offsets_.size()) { VLOG(3) << "offset_index: " << offset_index_ @@ -318,14 +419,7 @@ int InMemoryDataFeed::Next() { VLOG(3) << "finish reading for heterps, batch size zero, thread_id=" << thread_id_; } - /* - if (offset_index_ == batch_offsets_.size() - 1) { - std::vector data; - output_channel_->ReadAll(data); - consume_channel_->Write(std::move(data)); - } - */ - VLOG(3) << "#15 enable heter NEXT: " << offset_index_ + VLOG(3) << "enable heter next: " << offset_index_ << " batch_offsets: " << batch_offsets_.size() << " baych_size: " << this->batch_size_; } diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 198bc51463a..5527eaf1f6f 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -39,8 +39,14 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/string/string_helper.h" +DECLARE_int32(record_pool_max_size); +DECLARE_int32(slotpool_thread_num); +DECLARE_bool(enable_slotpool_wait_release); +DECLARE_bool(enable_slotrecord_reset_shrink); + namespace paddle { namespace framework { class DataFeedDesc; @@ -69,6 +75,50 @@ namespace framework { // while (reader->Next()) { // // trainer do something // } + +template +struct SlotValues { + std::vector slot_values; + std::vector slot_offsets; + + void add_values(const T* values, uint32_t num) { + if (slot_offsets.empty()) { + slot_offsets.push_back(0); + } + if (num > 0) { + slot_values.insert(slot_values.end(), values, values + num); + } + slot_offsets.push_back(static_cast(slot_values.size())); + } + T* get_values(int idx, size_t* size) { + uint32_t& offset = slot_offsets[idx]; + (*size) = slot_offsets[idx + 1] - offset; + return &slot_values[offset]; + } + void add_slot_feasigns(const std::vector>& slot_feasigns, + uint32_t fea_num) { + slot_values.reserve(fea_num); + int slot_num = static_cast(slot_feasigns.size()); + slot_offsets.resize(slot_num + 1); + for (int i = 0; i < slot_num; ++i) { + auto& slot_val = slot_feasigns[i]; + slot_offsets[i] = static_cast(slot_values.size()); + uint32_t num = static_cast(slot_val.size()); + if (num > 0) { + slot_values.insert(slot_values.end(), slot_val.begin(), slot_val.end()); + } + } + slot_offsets[slot_num] = slot_values.size(); + } + void clear(bool shrink) { + slot_offsets.clear(); + slot_values.clear(); + if (shrink) { + slot_values.shrink_to_fit(); + slot_offsets.shrink_to_fit(); + } + } +}; union FeatureFeasign { uint64_t uint64_feasign_; float float_feasign_; @@ -97,6 +147,38 @@ struct FeatureItem { uint16_t slot_; }; +struct AllSlotInfo { + std::string slot; + std::string type; + int used_idx; + int slot_value_idx; +}; +struct UsedSlotInfo { + int idx; + int slot_value_idx; + std::string slot; + std::string type; + bool dense; + std::vector local_shape; + int total_dims_without_inductive; + int inductive_shape_index; +}; +struct SlotRecordObject { + uint64_t search_id; + uint32_t rank; + uint32_t cmatch; + std::string ins_id_; + SlotValues slot_uint64_feasigns_; + SlotValues slot_float_feasigns_; + + ~SlotRecordObject() { clear(true); } + void reset(void) { clear(FLAGS_enable_slotrecord_reset_shrink); } + void clear(bool shrink) { + slot_uint64_feasigns_.clear(shrink); + slot_float_feasigns_.clear(shrink); + } +}; +using SlotRecord = SlotRecordObject*; // sizeof Record is much less than std::vector struct Record { std::vector uint64_feasigns_; @@ -108,6 +190,179 @@ struct Record { uint32_t cmatch; }; +inline SlotRecord make_slotrecord() { + static const size_t slot_record_byte_size = sizeof(SlotRecordObject); + void* p = malloc(slot_record_byte_size); + new (p) SlotRecordObject; + return reinterpret_cast(p); +} + +inline void free_slotrecord(SlotRecordObject* p) { + p->~SlotRecordObject(); + free(p); +} + +template +class SlotObjAllocator { + public: + explicit SlotObjAllocator(std::function deleter) + : free_nodes_(NULL), capacity_(0), deleter_(deleter) {} + ~SlotObjAllocator() { clear(); } + + void clear() { + T* tmp = NULL; + while (free_nodes_ != NULL) { + tmp = reinterpret_cast(reinterpret_cast(free_nodes_)); + free_nodes_ = free_nodes_->next; + deleter_(tmp); + --capacity_; + } + CHECK_EQ(capacity_, static_cast(0)); + } + T* acquire(void) { + T* x = NULL; + x = reinterpret_cast(reinterpret_cast(free_nodes_)); + free_nodes_ = free_nodes_->next; + --capacity_; + return x; + } + void release(T* x) { + Node* node = reinterpret_cast(reinterpret_cast(x)); + node->next = free_nodes_; + free_nodes_ = node; + ++capacity_; + } + size_t capacity(void) { return capacity_; } + + private: + struct alignas(T) Node { + union { + Node* next; + char data[sizeof(T)]; + }; + }; + Node* free_nodes_; // a list + size_t capacity_; + std::function deleter_ = nullptr; +}; +static const int OBJPOOL_BLOCK_SIZE = 10000; +class SlotObjPool { + public: + SlotObjPool() + : max_capacity_(FLAGS_record_pool_max_size), alloc_(free_slotrecord) { + ins_chan_ = MakeChannel(); + ins_chan_->SetBlockSize(OBJPOOL_BLOCK_SIZE); + for (int i = 0; i < FLAGS_slotpool_thread_num; ++i) { + threads_.push_back(std::thread([this]() { run(); })); + } + disable_pool_ = false; + count_ = 0; + } + ~SlotObjPool() { + ins_chan_->Close(); + for (auto& t : threads_) { + t.join(); + } + } + void disable_pool(bool disable) { disable_pool_ = disable; } + void set_max_capacity(size_t max_capacity) { max_capacity_ = max_capacity; } + void get(std::vector* output, int n) { + output->resize(n); + return get(&(*output)[0], n); + } + void get(SlotRecord* output, int n) { + int size = 0; + mutex_.lock(); + int left = static_cast(alloc_.capacity()); + if (left > 0) { + size = (left >= n) ? n : left; + for (int i = 0; i < size; ++i) { + output[i] = alloc_.acquire(); + } + } + mutex_.unlock(); + count_ += n; + if (size == n) { + return; + } + for (int i = size; i < n; ++i) { + output[i] = make_slotrecord(); + } + } + void put(std::vector* input) { + size_t size = input->size(); + if (size == 0) { + return; + } + put(&(*input)[0], size); + input->clear(); + } + void put(SlotRecord* input, size_t size) { + CHECK(ins_chan_->WriteMove(size, input) == size); + } + void run(void) { + std::vector input; + while (ins_chan_->ReadOnce(input, OBJPOOL_BLOCK_SIZE)) { + if (input.empty()) { + continue; + } + // over max capacity + size_t n = input.size(); + count_ -= n; + if (disable_pool_ || n + capacity() > max_capacity_) { + for (auto& t : input) { + free_slotrecord(t); + } + } else { + for (auto& t : input) { + t->reset(); + } + mutex_.lock(); + for (auto& t : input) { + alloc_.release(t); + } + mutex_.unlock(); + } + input.clear(); + } + } + void clear(void) { + platform::Timer timeline; + timeline.Start(); + mutex_.lock(); + alloc_.clear(); + mutex_.unlock(); + // wait release channel data + if (FLAGS_enable_slotpool_wait_release) { + while (!ins_chan_->Empty()) { + sleep(1); + } + } + timeline.Pause(); + VLOG(3) << "clear slot pool data size=" << count_.load() + << ", span=" << timeline.ElapsedSec(); + } + size_t capacity(void) { + mutex_.lock(); + size_t total = alloc_.capacity(); + mutex_.unlock(); + return total; + } + + private: + size_t max_capacity_; + Channel ins_chan_; + std::vector threads_; + std::mutex mutex_; + SlotObjAllocator alloc_; + bool disable_pool_; + std::atomic count_; // NOLINT +}; + +inline SlotObjPool& SlotRecordPool() { + static SlotObjPool pool; + return pool; +} struct PvInstanceObject { std::vector ads; void merge_instance(Record* ins) { ads.push_back(ins); } @@ -129,7 +384,21 @@ class CustomParser { CustomParser() {} virtual ~CustomParser() {} virtual void Init(const std::vector& slots) = 0; + virtual bool Init(const std::vector& slots) = 0; virtual void ParseOneInstance(const char* str, Record* instance) = 0; + virtual bool ParseOneInstance( + const std::string& line, + std::function&, int)> + GetInsFunc) { // NOLINT + return true; + } + virtual bool ParseFileInstance( + std::function ReadBuffFunc, + std::function&, int, int)> + PullRecordsFunc, // NOLINT + int& lines) { // NOLINT + return false; + } }; typedef paddle::framework::CustomParser* (*CreateParserObjectFunc)(); @@ -194,6 +463,34 @@ class DLManager { return nullptr; } + paddle::framework::CustomParser* Load(const std::string& name, + const std::vector& conf) { +#ifdef _LINUX + std::lock_guard lock(mutex_); + DLHandle handle; + std::map::iterator it = handle_map_.find(name); + if (it != handle_map_.end()) { + return it->second.parser; + } + handle.module = dlopen(name.c_str(), RTLD_NOW); + if (handle.module == nullptr) { + VLOG(0) << "Create so of " << name << " fail"; + exit(-1); + return nullptr; + } + + CreateParserObjectFunc create_parser_func = + (CreateParserObjectFunc)dlsym(handle.module, "CreateParserObject"); + handle.parser = create_parser_func(); + handle.parser->Init(conf); + handle_map_.insert({name, handle}); + + return handle.parser; +#endif + VLOG(0) << "Not implement in windows"; + return nullptr; + } + paddle::framework::CustomParser* ReLoad(const std::string& name, const std::vector& conf) { Close(name); @@ -415,6 +712,11 @@ class InMemoryDataFeed : public DataFeed { virtual void SetCurrentPhase(int current_phase); virtual void LoadIntoMemory(); virtual void LoadIntoMemoryFromSo(); + virtual void SetRecord(T* records) { records_ = records; } + int GetDefaultBatchSize() { return default_batch_size_; } + void AddBatchOffset(const std::pair& offset) { + batch_offsets_.push_back(offset); + } protected: virtual bool ParseOneInstance(T* instance) = 0; @@ -424,6 +726,11 @@ class InMemoryDataFeed : public DataFeed { virtual void PutToFeedVec(const std::vector& ins_vec) = 0; virtual void PutToFeedVec(const T* ins_vec, int num) = 0; + std::vector> batch_float_feasigns_; + std::vector> batch_uint64_feasigns_; + std::vector> offset_; + std::vector visit_; + int thread_id_; int thread_num_; bool parse_ins_id_; @@ -783,11 +1090,7 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed { MultiSlotInMemoryDataFeed() {} virtual ~MultiSlotInMemoryDataFeed() {} virtual void Init(const DataFeedDesc& data_feed_desc); - void SetRecord(Record* records) { records_ = records; } - int GetDefaultBatchSize() { return default_batch_size_; } - void AddBatchOffset(const std::pair& offset) { - batch_offsets_.push_back(offset); - } + // void SetRecord(Record* records) { records_ = records; } protected: virtual bool ParseOneInstance(Record* instance); @@ -798,10 +1101,6 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed { virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id, uint32_t* cmatch, uint32_t* rank); virtual void PutToFeedVec(const Record* ins_vec, int num); - std::vector> batch_float_feasigns_; - std::vector> batch_uint64_feasigns_; - std::vector> offset_; - std::vector visit_; }; class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed { diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 08c42a93d1f..82a39b206e6 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -351,10 +351,8 @@ static int compute_thread_batch_nccl( return thread_avg_batch_num; } -template -void DatasetImpl::SetHeterPs(bool enable_heterps) { +void MultiSlotDataset::PrepareTrain() { #ifdef PADDLE_WITH_GLOO - enable_heterps_ = enable_heterps; if (enable_heterps_) { if (input_records_.size() == 0 && input_channel_ != nullptr && input_channel_->Size() != 0) { @@ -541,22 +539,21 @@ void DatasetImpl::LocalShuffle() { << timeline.ElapsedSec() << " seconds"; } -template -void DatasetImpl::GlobalShuffle(int thread_num) { +void MultiSlotDataset::GlobalShuffle(int thread_num) { #ifdef PADDLE_WITH_PSLIB - VLOG(3) << "DatasetImpl::GlobalShuffle() begin"; + VLOG(3) << "MultiSlotDataset::GlobalShuffle() begin"; platform::Timer timeline; timeline.Start(); auto fleet_ptr = FleetWrapper::GetInstance(); if (!input_channel_ || input_channel_->Size() == 0) { - VLOG(3) << "DatasetImpl::GlobalShuffle() end, no data to shuffle"; + VLOG(3) << "MultiSlotDataset::GlobalShuffle() end, no data to shuffle"; return; } // local shuffle input_channel_->Close(); - std::vector data; + std::vector data; input_channel_->ReadAll(data); std::shuffle(data.begin(), data.end(), fleet_ptr->LocalRandomEngine()); input_channel_->Open(); @@ -566,10 +563,10 @@ void DatasetImpl::GlobalShuffle(int thread_num) { input_channel_->Close(); input_channel_->SetBlockSize(fleet_send_batch_size_); - VLOG(3) << "DatasetImpl::GlobalShuffle() input_channel_ size " + VLOG(3) << "MultiSlotDataset::GlobalShuffle() input_channel_ size " << input_channel_->Size(); - auto get_client_id = [this, fleet_ptr](const T& data) -> size_t { + auto get_client_id = [this, fleet_ptr](const Record& data) -> size_t { if (!this->merge_by_insid_) { return fleet_ptr->LocalRandomEngine()() % this->trainer_num_; } else { @@ -580,7 +577,7 @@ void DatasetImpl::GlobalShuffle(int thread_num) { auto global_shuffle_func = [this, get_client_id]() { auto fleet_ptr = FleetWrapper::GetInstance(); - std::vector data; + std::vector data; while (this->input_channel_->Read(data)) { std::vector ars(this->trainer_num_); for (auto& t : data) { @@ -835,9 +832,6 @@ void DatasetImpl::CreateReaders() { channel_idx = 0; } } - if (enable_heterps_) { - SetHeterPs(true); - } VLOG(3) << "readers size: " << readers_.size(); } @@ -923,9 +917,8 @@ int64_t DatasetImpl::GetShuffleDataSize() { return sum; } -template -int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, - const std::string& msg) { +int MultiSlotDataset::ReceiveFromClient(int msg_type, int client_id, + const std::string& msg) { #ifdef _LINUX VLOG(3) << "ReceiveFromClient msg_type=" << msg_type << ", client_id=" << client_id << ", msg length=" << msg.length(); @@ -937,9 +930,9 @@ int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, if (ar.Cursor() == ar.Finish()) { return 0; } - std::vector data; + std::vector data; while (ar.Cursor() < ar.Finish()) { - data.push_back(ar.Get()); + data.push_back(ar.Get()); } CHECK(ar.Cursor() == ar.Finish()); @@ -966,6 +959,20 @@ int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, // explicit instantiation template class DatasetImpl; +void MultiSlotDataset::DynamicAdjustReadersNum(int thread_num) { + if (thread_num_ == thread_num) { + VLOG(3) << "DatasetImpl::DynamicAdjustReadersNum thread_num_=" + << thread_num_ << ", thread_num_=thread_num, no need to adjust"; + return; + } + VLOG(3) << "adjust readers num from " << thread_num_ << " to " << thread_num; + thread_num_ = thread_num; + std::vector>().swap(readers_); + CreateReaders(); + VLOG(3) << "adjust readers num done"; + PrepareTrain(); +} + void MultiSlotDataset::PostprocessInstance() { // divide pv instance, and merge to input_channel_ if (enable_pv_merge_) { @@ -1503,5 +1510,126 @@ void MultiSlotDataset::SlotsShuffle( << ", cost time=" << timeline.ElapsedSec() << " seconds"; } +template class DatasetImpl; +void SlotRecordDataset::CreateChannel() { + if (input_channel_ == nullptr) { + input_channel_ = paddle::framework::MakeChannel(); + } +} +void SlotRecordDataset::CreateReaders() { + VLOG(3) << "Calling CreateReaders()"; + VLOG(3) << "thread num in Dataset: " << thread_num_; + VLOG(3) << "Filelist size in Dataset: " << filelist_.size(); + VLOG(3) << "channel num in Dataset: " << channel_num_; + CHECK(thread_num_ > 0) << "thread num should > 0"; + CHECK(channel_num_ > 0) << "channel num should > 0"; + CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num"; + VLOG(3) << "readers size: " << readers_.size(); + if (readers_.size() != 0) { + VLOG(3) << "readers_.size() = " << readers_.size() + << ", will not create again"; + return; + } + VLOG(3) << "data feed class name: " << data_feed_desc_.name(); + for (int i = 0; i < thread_num_; ++i) { + readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name())); + readers_[i]->Init(data_feed_desc_); + readers_[i]->SetThreadId(i); + readers_[i]->SetThreadNum(thread_num_); + readers_[i]->SetFileListMutex(&mutex_for_pick_file_); + readers_[i]->SetFileListIndex(&file_idx_); + readers_[i]->SetFeaNumMutex(&mutex_for_fea_num_); + readers_[i]->SetFeaNum(&total_fea_num_); + readers_[i]->SetFileList(filelist_); + readers_[i]->SetParseInsId(parse_ins_id_); + readers_[i]->SetParseContent(parse_content_); + readers_[i]->SetParseLogKey(parse_logkey_); + readers_[i]->SetEnablePvMerge(enable_pv_merge_); + readers_[i]->SetCurrentPhase(current_phase_); + if (input_channel_ != nullptr) { + readers_[i]->SetInputChannel(input_channel_.get()); + } + } + VLOG(3) << "readers size: " << readers_.size(); +} + +void SlotRecordDataset::ReleaseMemory() { + VLOG(3) << "SlotRecordDataset::ReleaseMemory() begin"; + platform::Timer timeline; + timeline.Start(); + + if (input_channel_) { + input_channel_->Clear(); + input_channel_ = nullptr; + } + if (enable_heterps_) { + VLOG(3) << "put pool records size: " << input_records_.size(); + SlotRecordPool().put(&input_records_); + input_records_.clear(); + input_records_.shrink_to_fit(); + VLOG(3) << "release heterps input records records size: " + << input_records_.size(); + } + + readers_.clear(); + readers_.shrink_to_fit(); + + std::vector>().swap(readers_); + + VLOG(3) << "SlotRecordDataset::ReleaseMemory() end"; + VLOG(3) << "total_feasign_num_(" << STAT_GET(STAT_total_feasign_num_in_mem) + << ") - current_fea_num_(" << total_fea_num_ << ") = (" + << STAT_GET(STAT_total_feasign_num_in_mem) - total_fea_num_ << ")" + << " object pool size=" << SlotRecordPool().capacity(); // For Debug + STAT_SUB(STAT_total_feasign_num_in_mem, total_fea_num_); +} +void SlotRecordDataset::GlobalShuffle(int thread_num) { + // TODO(yaoxuefeng) + return; +} + +void SlotRecordDataset::DynamicAdjustChannelNum(int channel_num, + bool discard_remaining_ins) { + if (channel_num_ == channel_num) { + VLOG(3) << "DatasetImpl::DynamicAdjustChannelNum channel_num_=" + << channel_num_ << ", channel_num_=channel_num, no need to adjust"; + return; + } + VLOG(3) << "adjust channel num from " << channel_num_ << " to " + << channel_num; + channel_num_ = channel_num; + + if (static_cast(input_channel_->Size()) >= channel_num) { + input_channel_->SetBlockSize(input_channel_->Size() / channel_num + + (discard_remaining_ins ? 0 : 1)); + } + + VLOG(3) << "adjust channel num done"; +} + +void SlotRecordDataset::PrepareTrain() { +#ifdef PADDLE_WITH_GLOO + return; +#else + PADDLE_THROW(platform::errors::Unavailable( + "dataset set heterps need compile with GLOO")); +#endif + return; +} + +void SlotRecordDataset::DynamicAdjustReadersNum(int thread_num) { + if (thread_num_ == thread_num) { + VLOG(3) << "DatasetImpl::DynamicAdjustReadersNum thread_num_=" + << thread_num_ << ", thread_num_=thread_num, no need to adjust"; + return; + } + VLOG(3) << "adjust readers num from " << thread_num_ << " to " << thread_num; + thread_num_ = thread_num; + std::vector>().swap(readers_); + CreateReaders(); + VLOG(3) << "adjust readers num done"; + PrepareTrain(); +} + } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index f3ee96fab82..981fb694e0f 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -149,7 +149,6 @@ class Dataset { virtual void DynamicAdjustReadersNum(int thread_num) = 0; // set fleet send sleep seconds virtual void SetFleetSendSleepSeconds(int seconds) = 0; - virtual void SetHeterPs(bool enable_heterps) = 0; protected: virtual int ReceiveFromClient(int msg_type, int client_id, @@ -207,7 +206,7 @@ class DatasetImpl : public Dataset { virtual void WaitPreLoadDone(); virtual void ReleaseMemory(); virtual void LocalShuffle(); - virtual void GlobalShuffle(int thread_num = -1); + virtual void GlobalShuffle(int thread_num = -1) {} virtual void SlotsShuffle(const std::set& slots_to_replace) {} virtual const std::vector& GetSlotsOriginalData() { return slots_shuffle_original_data_; @@ -233,7 +232,11 @@ class DatasetImpl : public Dataset { bool discard_remaining_ins = false); virtual void DynamicAdjustReadersNum(int thread_num); virtual void SetFleetSendSleepSeconds(int seconds); - virtual void SetHeterPs(bool enable_heterps); + /* for enable_heterps_ + virtual void EnableHeterps(bool enable_heterps) { + enable_heterps_ = enable_heterps; + } + */ std::vector>& GetMultiOutputChannel() { return multi_output_channel_; @@ -251,7 +254,10 @@ class DatasetImpl : public Dataset { protected: virtual int ReceiveFromClient(int msg_type, int client_id, - const std::string& msg); + const std::string& msg) { + // TODO(yaoxuefeng) for SlotRecordDataset + return -1; + } std::vector> readers_; std::vector> preload_readers_; paddle::framework::Channel input_channel_; @@ -327,6 +333,32 @@ class MultiSlotDataset : public DatasetImpl { const std::unordered_set& slots_to_replace, std::vector* result); virtual ~MultiSlotDataset() {} + virtual void GlobalShuffle(int thread_num = -1); + virtual void DynamicAdjustReadersNum(int thread_num); + virtual void PrepareTrain(); + + protected: + virtual int ReceiveFromClient(int msg_type, int client_id, + const std::string& msg); +}; +class SlotRecordDataset : public DatasetImpl { + public: + SlotRecordDataset() { SlotRecordPool(); } + virtual ~SlotRecordDataset() {} + // create input channel + virtual void CreateChannel(); + // create readers + virtual void CreateReaders(); + // release memory + virtual void ReleaseMemory(); + virtual void GlobalShuffle(int thread_num = -1); + virtual void DynamicAdjustChannelNum(int channel_num, + bool discard_remaining_ins); + virtual void PrepareTrain(); + virtual void DynamicAdjustReadersNum(int thread_num); + + protected: + bool enable_heterps_ = true; }; } // end namespace framework diff --git a/paddle/fluid/framework/dataset_factory.cc b/paddle/fluid/framework/dataset_factory.cc index aeaf9611853..38200927c55 100644 --- a/paddle/fluid/framework/dataset_factory.cc +++ b/paddle/fluid/framework/dataset_factory.cc @@ -53,7 +53,7 @@ std::unique_ptr DatasetFactory::CreateDataset( std::string dataset_class) { if (g_dataset_map.count(dataset_class) < 1) { LOG(WARNING) << "Your Dataset " << dataset_class - << "is not supported currently"; + << " is not supported currently"; LOG(WARNING) << "Supported Dataset: " << DatasetTypeList(); exit(-1); } @@ -61,5 +61,6 @@ std::unique_ptr DatasetFactory::CreateDataset( } REGISTER_DATASET_CLASS(MultiSlotDataset); +REGISTER_DATASET_CLASS(SlotRecordDataset); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 89a829f9490..72b95dcc153 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -680,3 +680,11 @@ PADDLE_DEFINE_EXPORTED_int32(get_host_by_name_time, 120, PADDLE_DEFINE_EXPORTED_bool( apply_pass_to_program, false, "It controls whether to apply IR pass to program when using Fleet APIs"); + +DEFINE_int32(record_pool_max_size, 2000000, + "SlotRecordDataset slot record pool max size"); +DEFINE_int32(slotpool_thread_num, 1, "SlotRecordDataset slot pool thread num"); +DEFINE_bool(enable_slotpool_wait_release, false, + "enable slotrecord obejct wait release, default false"); +DEFINE_bool(enable_slotrecord_reset_shrink, false, + "enable slotrecord obejct reset shrink memory, default false"); \ No newline at end of file diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 41cf0189d3d..7a32d8729fc 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -309,8 +309,6 @@ void BindDataset(py::module *m) { &framework::Dataset::SetFleetSendSleepSeconds, py::call_guard()) .def("enable_pv_merge", &framework::Dataset::EnablePvMerge, - py::call_guard()) - .def("set_heter_ps", &framework::Dataset::SetHeterPs, py::call_guard()); py::class_(*m, "IterableDatasetWrapper") -- GitLab