From 5223e2bbc4da204728cedbd2627c8073b2ef2edf Mon Sep 17 00:00:00 2001 From: ShenLiang <2282912238@qq.com> Date: Mon, 6 Apr 2020 15:36:36 +0800 Subject: [PATCH] Add a new DataFeed named PaddleBoxDataFeed (#23321) * add paddleboxdatafeed * add ifdef linux and boxps * add untest for datafeed * fix untest of test_paddlebox_datafeed * fix untest * rename function --- paddle/fluid/framework/data_feed.cc | 308 ++++++++++++++++++ paddle/fluid/framework/data_feed.h | 125 +++++-- paddle/fluid/framework/data_feed.proto | 2 + paddle/fluid/framework/data_feed_factory.cc | 1 + paddle/fluid/framework/data_set.cc | 208 +++++++++++- paddle/fluid/framework/data_set.h | 33 ++ paddle/fluid/pybind/data_set_py.cc | 15 + python/paddle/fluid/dataset.py | 179 ++++++++++ .../fluid/tests/unittests/CMakeLists.txt | 2 + .../unittests/test_paddlebox_datafeed.py | 147 +++++++++ 10 files changed, 985 insertions(+), 35 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_paddlebox_datafeed.py diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index b91fe8974f..9e31b581a0 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -33,6 +33,7 @@ limitations under the License. */ #include "io/shell.h" #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/fleet/box_wrapper.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/platform/timer.h" @@ -232,6 +233,9 @@ InMemoryDataFeed::InMemoryDataFeed() { this->thread_num_ = 1; this->parse_ins_id_ = false; this->parse_content_ = false; + this->parse_logkey_ = false; + this->enable_pv_merge_ = false; + this->current_phase_ = 1; // 1:join ;0:update this->input_channel_ = nullptr; this->output_channel_ = nullptr; this->consume_channel_ = nullptr; @@ -305,6 +309,24 @@ void InMemoryDataFeed::SetConsumeChannel(void* channel) { consume_channel_ = static_cast*>(channel); } +template +void InMemoryDataFeed::SetInputPvChannel(void* channel) { + input_pv_channel_ = + static_cast*>(channel); +} + +template +void InMemoryDataFeed::SetOutputPvChannel(void* channel) { + output_pv_channel_ = + static_cast*>(channel); +} + +template +void InMemoryDataFeed::SetConsumePvChannel(void* channel) { + consume_pv_channel_ = + static_cast*>(channel); +} + template void InMemoryDataFeed::SetThreadId(int thread_id) { thread_id_ = thread_id; @@ -320,6 +342,21 @@ void InMemoryDataFeed::SetParseContent(bool parse_content) { parse_content_ = parse_content; } +template +void InMemoryDataFeed::SetParseLogKey(bool parse_logkey) { + parse_logkey_ = parse_logkey; +} + +template +void InMemoryDataFeed::SetEnablePvMerge(bool enable_pv_merge) { + enable_pv_merge_ = enable_pv_merge; +} + +template +void InMemoryDataFeed::SetCurrentPhase(int current_phase) { + current_phase_ = current_phase; +} + template void InMemoryDataFeed::SetParseInsId(bool parse_ins_id) { parse_ins_id_ = parse_ins_id; @@ -756,6 +793,20 @@ void MultiSlotInMemoryDataFeed::Init( finish_init_ = true; } +void MultiSlotInMemoryDataFeed::GetMsgFromLogKey(const std::string& log_key, + uint64_t* search_id, + uint32_t* cmatch, + uint32_t* rank) { + std::string searchid_str = log_key.substr(16, 16); + *search_id = (uint64_t)strtoull(searchid_str.c_str(), NULL, 16); + + std::string cmatch_str = log_key.substr(11, 3); + *cmatch = (uint32_t)strtoul(cmatch_str.c_str(), NULL, 16); + + std::string rank_str = log_key.substr(14, 2); + *rank = 
(uint32_t)strtoul(rank_str.c_str(), NULL, 16); +} + bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { #ifdef _LINUX thread_local string::LineFileReader reader; @@ -792,6 +843,26 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { pos += len + 1; VLOG(3) << "content " << instance->content_; } + if (parse_logkey_) { + int num = strtol(&str[pos], &endptr, 10); + CHECK(num == 1); // NOLINT + pos = endptr - str + 1; + size_t len = 0; + while (str[pos + len] != ' ') { + ++len; + } + // parse_logkey + std::string log_key = std::string(str + pos, len); + uint64_t search_id; + uint32_t cmatch; + uint32_t rank; + GetMsgFromLogKey(log_key, &search_id, &cmatch, &rank); + + instance->search_id = search_id; + instance->cmatch = cmatch; + instance->rank = rank; + pos += len + 1; + } for (size_t i = 0; i < use_slots_index_.size(); ++i) { int idx = use_slots_index_[i]; int num = strtol(&str[pos], &endptr, 10); @@ -1186,5 +1257,242 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() { } #endif +bool PaddleBoxDataFeed::Start() { +#ifdef _LINUX + int phase = GetCurrentPhase(); // join: 1, update: 0 + this->CheckSetFileList(); + if (enable_pv_merge_ && phase == 1) { + // join phase : input_pv_channel to output_pv_channel + if (output_pv_channel_->Size() == 0 && input_pv_channel_->Size() != 0) { + std::vector data; + input_pv_channel_->Read(data); + output_pv_channel_->Write(std::move(data)); + } + } else { + // input_channel to output + if (output_channel_->Size() == 0 && input_channel_->Size() != 0) { + std::vector data; + input_channel_->Read(data); + output_channel_->Write(std::move(data)); + } + } +#endif + this->finish_start_ = true; + return true; +} + +int PaddleBoxDataFeed::Next() { +#ifdef _LINUX + int phase = GetCurrentPhase(); // join: 1, update: 0 + this->CheckStart(); + if (enable_pv_merge_ && phase == 1) { + // join phase : output_pv_channel to consume_pv_channel + CHECK(output_pv_channel_ != nullptr); + CHECK(consume_pv_channel_ != nullptr); + VLOG(3) << "output_pv_channel_ size=" << output_pv_channel_->Size() + << ", consume_pv_channel_ size=" << consume_pv_channel_->Size() + << ", thread_id=" << thread_id_; + int index = 0; + PvInstance pv_instance; + std::vector pv_vec; + pv_vec.reserve(this->pv_batch_size_); + while (index < this->pv_batch_size_) { + if (output_pv_channel_->Size() == 0) { + break; + } + output_pv_channel_->Get(pv_instance); + pv_vec.push_back(pv_instance); + ++index; + consume_pv_channel_->Put(std::move(pv_instance)); + } + this->batch_size_ = index; + VLOG(3) << "pv_batch_size_=" << this->batch_size_ + << ", thread_id=" << thread_id_; + if (this->batch_size_ != 0) { + PutToFeedVec(pv_vec); + } else { + VLOG(3) << "finish reading, output_pv_channel_ size=" + << output_pv_channel_->Size() + << ", consume_pv_channel_ size=" << consume_pv_channel_->Size() + << ", thread_id=" << thread_id_; + } + return this->batch_size_; + } else { + this->batch_size_ = MultiSlotInMemoryDataFeed::Next(); + return this->batch_size_; + } +#else + return 0; +#endif +} + +void PaddleBoxDataFeed::Init(const DataFeedDesc& data_feed_desc) { + MultiSlotInMemoryDataFeed::Init(data_feed_desc); + rank_offset_name_ = data_feed_desc.rank_offset(); + pv_batch_size_ = data_feed_desc.pv_batch_size(); +} + +void PaddleBoxDataFeed::GetRankOffset(const std::vector& pv_vec, + int ins_number) { + int index = 0; + int max_rank = 3; // the value is setting + int row = ins_number; + int col = max_rank * 2 + 1; + int pv_num = pv_vec.size(); + + std::vector 
rank_offset_mat(row * col, -1); + rank_offset_mat.shrink_to_fit(); + + for (int i = 0; i < pv_num; i++) { + auto pv_ins = pv_vec[i]; + int ad_num = pv_ins->ads.size(); + int index_start = index; + for (int j = 0; j < ad_num; ++j) { + auto ins = pv_ins->ads[j]; + int rank = -1; + if ((ins->cmatch == 222 || ins->cmatch == 223) && + ins->rank <= static_cast(max_rank) && ins->rank != 0) { + rank = ins->rank; + } + + rank_offset_mat[index * col] = rank; + if (rank > 0) { + for (int k = 0; k < ad_num; ++k) { + auto cur_ins = pv_ins->ads[k]; + int fast_rank = -1; + if ((cur_ins->cmatch == 222 || cur_ins->cmatch == 223) && + cur_ins->rank <= static_cast(max_rank) && + cur_ins->rank != 0) { + fast_rank = cur_ins->rank; + } + + if (fast_rank > 0) { + int m = fast_rank - 1; + rank_offset_mat[index * col + 2 * m + 1] = cur_ins->rank; + rank_offset_mat[index * col + 2 * m + 2] = index_start + k; + } + } + } + index += 1; + } + } + + int* rank_offset = rank_offset_mat.data(); + int* tensor_ptr = rank_offset_->mutable_data({row, col}, this->place_); + CopyToFeedTensor(tensor_ptr, rank_offset, row * col * sizeof(int)); +} + +void PaddleBoxDataFeed::AssignFeedVar(const Scope& scope) { + MultiSlotInMemoryDataFeed::AssignFeedVar(scope); + // set rank offset memory + int phase = GetCurrentPhase(); // join: 1, update: 0 + if (enable_pv_merge_ && phase == 1) { + rank_offset_ = scope.FindVar(rank_offset_name_)->GetMutable(); + } +} + +void PaddleBoxDataFeed::PutToFeedVec(const std::vector& pv_vec) { +#ifdef _LINUX + int ins_number = 0; + std::vector ins_vec; + for (auto& pv : pv_vec) { + ins_number += pv->ads.size(); + for (auto ins : pv->ads) { + ins_vec.push_back(ins); + } + } + GetRankOffset(pv_vec, ins_number); + PutToFeedVec(ins_vec); +#endif +} + +int PaddleBoxDataFeed::GetCurrentPhase() { +#ifdef PADDLE_WITH_BOX_PS + auto box_ptr = paddle::framework::BoxWrapper::GetInstance(); + return box_ptr->PassFlag(); // join: 1, update: 0 +#else + LOG(WARNING) << "It should be complied with BOX_PS..."; + return current_phase_; +#endif +} + +void PaddleBoxDataFeed::PutToFeedVec(const std::vector& ins_vec) { +#ifdef _LINUX + std::vector> batch_float_feasigns(use_slots_.size(), + std::vector()); + std::vector> batch_uint64_feasigns( + use_slots_.size(), std::vector()); + std::vector> offset(use_slots_.size(), + std::vector{0}); + std::vector visit(use_slots_.size(), false); + ins_content_vec_.clear(); + ins_content_vec_.reserve(ins_vec.size()); + ins_id_vec_.clear(); + ins_id_vec_.reserve(ins_vec.size()); + for (size_t i = 0; i < ins_vec.size(); ++i) { + auto r = ins_vec[i]; + ins_id_vec_.push_back(r->ins_id_); + ins_content_vec_.push_back(r->content_); + for (auto& item : r->float_feasigns_) { + batch_float_feasigns[item.slot()].push_back(item.sign().float_feasign_); + visit[item.slot()] = true; + } + for (auto& item : r->uint64_feasigns_) { + batch_uint64_feasigns[item.slot()].push_back(item.sign().uint64_feasign_); + visit[item.slot()] = true; + } + for (size_t j = 0; j < use_slots_.size(); ++j) { + const auto& type = all_slots_type_[j]; + if (visit[j]) { + visit[j] = false; + } else { + // fill slot value with default value 0 + if (type[0] == 'f') { // float + batch_float_feasigns[j].push_back(0.0); + } else if (type[0] == 'u') { // uint64 + batch_uint64_feasigns[j].push_back(0); + } + } + // get offset of this ins in this slot + if (type[0] == 'f') { // float + offset[j].push_back(batch_float_feasigns[j].size()); + } else if (type[0] == 'u') { // uint64 + offset[j].push_back(batch_uint64_feasigns[j].size()); + } + 
} + } + + for (size_t i = 0; i < use_slots_.size(); ++i) { + if (feed_vec_[i] == nullptr) { + continue; + } + int total_instance = offset[i].back(); + const auto& type = all_slots_type_[i]; + if (type[0] == 'f') { // float + float* feasign = batch_float_feasigns[i].data(); + float* tensor_ptr = + feed_vec_[i]->mutable_data({total_instance, 1}, this->place_); + CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(float)); + } else if (type[0] == 'u') { // uint64 + // no uint64_t type in paddlepaddle + uint64_t* feasign = batch_uint64_feasigns[i].data(); + int64_t* tensor_ptr = feed_vec_[i]->mutable_data( + {total_instance, 1}, this->place_); + CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t)); + } + auto& slot_offset = offset[i]; + LoD data_lod{slot_offset}; + feed_vec_[i]->set_lod(data_lod); + if (use_slots_is_dense_[i]) { + if (inductive_shape_index_[i] != -1) { + use_slots_shape_[i][inductive_shape_index_[i]] = + total_instance / total_dims_without_inductive_[i]; + } + feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i])); + } + } +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 9ea9be4199..a52cadcbc0 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -58,6 +58,51 @@ namespace framework { // while (reader->Next()) { // // trainer do something // } +union FeatureKey { + uint64_t uint64_feasign_; + float float_feasign_; +}; + +struct FeatureItem { + FeatureItem() {} + FeatureItem(FeatureKey sign, uint16_t slot) { + this->sign() = sign; + this->slot() = slot; + } + FeatureKey& sign() { return *(reinterpret_cast(sign_buffer())); } + const FeatureKey& sign() const { + const FeatureKey* ret = reinterpret_cast(sign_buffer()); + return *ret; + } + uint16_t& slot() { return slot_; } + const uint16_t& slot() const { return slot_; } + + private: + char* sign_buffer() const { return const_cast(sign_); } + char sign_[sizeof(FeatureKey)]; + uint16_t slot_; +}; + +// sizeof Record is much less than std::vector +struct Record { + std::vector uint64_feasigns_; + std::vector float_feasigns_; + std::string ins_id_; + std::string content_; + uint64_t search_id; + uint32_t rank; + uint32_t cmatch; +}; + +struct PvInstanceObject { + std::vector ads; + void merge_instance(Record* ins) { ads.push_back(ins); } +}; + +using PvInstance = PvInstanceObject*; + +inline PvInstance make_pv_instance() { return new PvInstanceObject(); } + class DataFeed { public: DataFeed() { @@ -93,6 +138,13 @@ class DataFeed { // This function is used for binding feed_vec memory in a given scope virtual void AssignFeedVar(const Scope& scope); + // This function will do nothing at default + virtual void SetInputPvChannel(void* channel) {} + // This function will do nothing at default + virtual void SetOutputPvChannel(void* channel) {} + // This function will do nothing at default + virtual void SetConsumePvChannel(void* channel) {} + // This function will do nothing at default virtual void SetInputChannel(void* channel) {} // This function will do nothing at default @@ -106,6 +158,9 @@ class DataFeed { // This function will do nothing at default virtual void SetParseInsId(bool parse_ins_id) {} virtual void SetParseContent(bool parse_content) {} + virtual void SetParseLogKey(bool parse_logkey) {} + virtual void SetEnablePvMerge(bool enable_pv_merge) {} + virtual void SetCurrentPhase(int current_phase) {} virtual void SetFileListMutex(std::mutex* mutex) { 
mutex_for_pick_file_ = mutex; } @@ -163,6 +218,8 @@ class DataFeed { // The data read by DataFeed will be stored here std::vector feed_vec_; + LoDTensor* rank_offset_; + // the batch size defined by user int default_batch_size_; // current batch size @@ -226,6 +283,10 @@ class InMemoryDataFeed : public DataFeed { virtual void Init(const DataFeedDesc& data_feed_desc) = 0; virtual bool Start(); virtual int Next(); + virtual void SetInputPvChannel(void* channel); + virtual void SetOutputPvChannel(void* channel); + virtual void SetConsumePvChannel(void* channel); + virtual void SetInputChannel(void* channel); virtual void SetOutputChannel(void* channel); virtual void SetConsumeChannel(void* channel); @@ -233,6 +294,9 @@ class InMemoryDataFeed : public DataFeed { virtual void SetThreadNum(int thread_num); virtual void SetParseInsId(bool parse_ins_id); virtual void SetParseContent(bool parse_content); + virtual void SetParseLogKey(bool parse_logkey); + virtual void SetEnablePvMerge(bool enable_pv_merge); + virtual void SetCurrentPhase(int current_phase); virtual void LoadIntoMemory(); protected: @@ -244,11 +308,18 @@ class InMemoryDataFeed : public DataFeed { int thread_num_; bool parse_ins_id_; bool parse_content_; + bool parse_logkey_; + bool enable_pv_merge_; + int current_phase_{-1}; // only for untest std::ifstream file_; std::shared_ptr fp_; paddle::framework::ChannelObject* input_channel_; paddle::framework::ChannelObject* output_channel_; paddle::framework::ChannelObject* consume_channel_; + + paddle::framework::ChannelObject* input_pv_channel_; + paddle::framework::ChannelObject* output_pv_channel_; + paddle::framework::ChannelObject* consume_pv_channel_; }; // This class define the data type of instance(ins_vec) in MultiSlotDataFeed @@ -408,39 +479,6 @@ paddle::framework::Archive& operator>>(paddle::framework::Archive& ar, return ar; } -union FeatureKey { - uint64_t uint64_feasign_; - float float_feasign_; -}; - -struct FeatureItem { - FeatureItem() {} - FeatureItem(FeatureKey sign, uint16_t slot) { - this->sign() = sign; - this->slot() = slot; - } - FeatureKey& sign() { return *(reinterpret_cast(sign_buffer())); } - const FeatureKey& sign() const { - const FeatureKey* ret = reinterpret_cast(sign_buffer()); - return *ret; - } - uint16_t& slot() { return slot_; } - const uint16_t& slot() const { return slot_; } - - private: - char* sign_buffer() const { return const_cast(sign_); } - char sign_[sizeof(FeatureKey)]; - uint16_t slot_; -}; - -// sizeof Record is much less than std::vector -struct Record { - std::vector uint64_feasigns_; - std::vector float_feasigns_; - std::string ins_id_; - std::string content_; -}; - struct RecordCandidate { std::string ins_id_; std::unordered_multimap feas; @@ -557,6 +595,27 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed { virtual bool ParseOneInstance(Record* instance); virtual bool ParseOneInstanceFromPipe(Record* instance); virtual void PutToFeedVec(const std::vector& ins_vec); + virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id, + uint32_t* cmatch, uint32_t* rank); +}; + +class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed { + public: + PaddleBoxDataFeed() {} + virtual ~PaddleBoxDataFeed() {} + + protected: + virtual void Init(const DataFeedDesc& data_feed_desc); + virtual bool Start(); + virtual int Next(); + virtual void AssignFeedVar(const Scope& scope); + virtual void PutToFeedVec(const std::vector& pv_vec); + virtual void PutToFeedVec(const std::vector& ins_vec); + virtual int GetCurrentPhase(); + 
virtual void GetRankOffset(const std::vector& pv_vec, + int ins_number); + std::string rank_offset_name_; + int pv_batch_size_; }; #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto index 03996e0e20..a0429a84d4 100644 --- a/paddle/fluid/framework/data_feed.proto +++ b/paddle/fluid/framework/data_feed.proto @@ -30,4 +30,6 @@ message DataFeedDesc { optional MultiSlotDesc multi_slot_desc = 3; optional string pipe_command = 4; optional int32 thread_num = 5; + optional string rank_offset = 6; + optional int32 pv_batch_size = 7 [ default = 32 ]; } diff --git a/paddle/fluid/framework/data_feed_factory.cc b/paddle/fluid/framework/data_feed_factory.cc index ec1acad99b..1d8aec7624 100644 --- a/paddle/fluid/framework/data_feed_factory.cc +++ b/paddle/fluid/framework/data_feed_factory.cc @@ -64,6 +64,7 @@ std::shared_ptr DataFeedFactory::CreateDataFeed( REGISTER_DATAFEED_CLASS(MultiSlotDataFeed); REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed); +REGISTER_DATAFEED_CLASS(PaddleBoxDataFeed); #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed); #endif diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 5621770535..0684d5674a 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -46,9 +46,12 @@ DatasetImpl::DatasetImpl() { fleet_send_batch_size_ = 1024; fleet_send_sleep_seconds_ = 0; merge_by_insid_ = false; + merge_by_sid_ = true; + enable_pv_merge_ = false; merge_size_ = 2; parse_ins_id_ = false; parse_content_ = false; + parse_logkey_ = false; preload_thread_num_ = 0; global_index_ = 0; } @@ -126,6 +129,11 @@ void DatasetImpl::SetParseContent(bool parse_content) { parse_content_ = parse_content; } +template +void DatasetImpl::SetParseLogKey(bool parse_logkey) { + parse_logkey_ = parse_logkey; +} + template void DatasetImpl::SetMergeByInsId(int merge_size) { merge_by_insid_ = true; @@ -133,6 +141,16 @@ void DatasetImpl::SetMergeByInsId(int merge_size) { merge_size_ = merge_size; } +template +void DatasetImpl::SetMergeBySid(bool is_merge) { + merge_by_sid_ = is_merge; +} + +template +void DatasetImpl::SetEnablePvMerge(bool enable_pv_merge) { + enable_pv_merge_ = enable_pv_merge; +} + template void DatasetImpl::SetGenerateUniqueFeasign(bool gen_uni_feasigns) { gen_uni_feasigns_ = gen_uni_feasigns; @@ -174,6 +192,21 @@ void DatasetImpl::CreateChannel() { multi_consume_channel_.push_back(paddle::framework::MakeChannel()); } } + if (input_pv_channel_ == nullptr) { + input_pv_channel_ = paddle::framework::MakeChannel(); + } + if (multi_pv_output_.size() == 0) { + multi_pv_output_.reserve(channel_num_); + for (int i = 0; i < channel_num_; ++i) { + multi_pv_output_.push_back(paddle::framework::MakeChannel()); + } + } + if (multi_pv_consume_.size() == 0) { + multi_pv_consume_.reserve(channel_num_); + for (int i = 0; i < channel_num_; ++i) { + multi_pv_consume_.push_back(paddle::framework::MakeChannel()); + } + } } // if sent message between workers, should first call this function @@ -206,6 +239,7 @@ void DatasetImpl::LoadIntoMemory() { input_channel_->Close(); int64_t in_chan_size = input_channel_->Size(); input_channel_->SetBlockSize(in_chan_size / thread_num_ + 1); + timeline.Pause(); VLOG(3) << "DatasetImpl::LoadIntoMemory() end" << ", memory data size=" << input_channel_->Size() @@ -270,6 +304,27 @@ void DatasetImpl::ReleaseMemory() { multi_consume_channel_[i] = nullptr; } 
std::vector>().swap(multi_consume_channel_); + if (input_pv_channel_) { + input_pv_channel_->Clear(); + input_pv_channel_ = nullptr; + } + for (size_t i = 0; i < multi_pv_output_.size(); ++i) { + if (!multi_pv_output_[i]) { + continue; + } + multi_pv_output_[i]->Clear(); + multi_pv_output_[i] = nullptr; + } + std::vector>().swap(multi_pv_output_); + for (size_t i = 0; i < multi_pv_consume_.size(); ++i) { + if (!multi_pv_consume_[i]) { + continue; + } + multi_pv_consume_[i]->Clear(); + multi_pv_consume_[i] = nullptr; + } + std::vector>().swap(multi_pv_consume_); + std::vector>().swap(readers_); VLOG(3) << "DatasetImpl::ReleaseMemory() end"; } @@ -412,6 +467,11 @@ void DatasetImpl::DynamicAdjustChannelNum(int channel_num, channel_num_ = channel_num; std::vector>* origin_channels = nullptr; std::vector>* other_channels = nullptr; + std::vector>* origin_pv_channels = + nullptr; + std::vector>* other_pv_channels = + nullptr; + // find out which channel (output or consume) has data int cur_channel = 0; uint64_t output_channels_data_size = 0; @@ -431,17 +491,26 @@ void DatasetImpl::DynamicAdjustChannelNum(int channel_num, if (cur_channel == 0) { origin_channels = &multi_output_channel_; other_channels = &multi_consume_channel_; + origin_pv_channels = &multi_pv_output_; + other_pv_channels = &multi_pv_consume_; } else { origin_channels = &multi_consume_channel_; other_channels = &multi_output_channel_; + origin_pv_channels = &multi_pv_consume_; + other_pv_channels = &multi_pv_output_; } - CHECK(origin_channels != nullptr); // NOLINT - CHECK(other_channels != nullptr); // NOLINT + CHECK(origin_channels != nullptr); // NOLINT + CHECK(other_channels != nullptr); // NOLINT + CHECK(origin_pv_channels != nullptr); // NOLINT + CHECK(other_pv_channels != nullptr); // NOLINT paddle::framework::Channel total_data_channel = paddle::framework::MakeChannel(); std::vector> new_channels; std::vector> new_other_channels; + std::vector> new_pv_channels; + std::vector> new_other_pv_channels; + std::vector local_vec; for (size_t i = 0; i < origin_channels->size(); ++i) { local_vec.clear(); @@ -458,6 +527,12 @@ void DatasetImpl::DynamicAdjustChannelNum(int channel_num, input_channel_->SetBlockSize(input_channel_->Size() / channel_num + (discard_remaining_ins ? 0 : 1)); } + if (static_cast(input_pv_channel_->Size()) >= channel_num) { + input_pv_channel_->SetBlockSize(input_pv_channel_->Size() / channel_num + + (discard_remaining_ins ? 
0 : 1)); + VLOG(3) << "now input_pv_channle block size is " + << input_pv_channel_->BlockSize(); + } for (int i = 0; i < channel_num; ++i) { local_vec.clear(); @@ -465,6 +540,9 @@ void DatasetImpl::DynamicAdjustChannelNum(int channel_num, new_other_channels.push_back(paddle::framework::MakeChannel()); new_channels.push_back(paddle::framework::MakeChannel()); new_channels[i]->Write(std::move(local_vec)); + new_other_pv_channels.push_back( + paddle::framework::MakeChannel()); + new_pv_channels.push_back(paddle::framework::MakeChannel()); } total_data_channel->Clear(); @@ -473,10 +551,22 @@ void DatasetImpl::DynamicAdjustChannelNum(int channel_num, *origin_channels = new_channels; *other_channels = new_other_channels; + origin_pv_channels->clear(); + other_pv_channels->clear(); + *origin_pv_channels = new_pv_channels; + *other_pv_channels = new_other_pv_channels; + new_channels.clear(); new_other_channels.clear(); std::vector>().swap(new_channels); std::vector>().swap(new_other_channels); + + new_pv_channels.clear(); + new_other_pv_channels.clear(); + std::vector>().swap(new_pv_channels); + std::vector>().swap( + new_other_pv_channels); + local_vec.clear(); std::vector().swap(local_vec); VLOG(3) << "adjust channel num done"; @@ -528,17 +618,30 @@ void DatasetImpl::CreateReaders() { readers_[i]->SetFileList(filelist_); readers_[i]->SetParseInsId(parse_ins_id_); readers_[i]->SetParseContent(parse_content_); + readers_[i]->SetParseLogKey(parse_logkey_); + readers_[i]->SetEnablePvMerge(enable_pv_merge_); + // Notice: it is only valid for untest of test_paddlebox_datafeed. + // In fact, it does not affect the train process when paddle is + // complied with Box_Ps. + readers_[i]->SetCurrentPhase(current_phase_); if (input_channel_ != nullptr) { readers_[i]->SetInputChannel(input_channel_.get()); } + if (input_pv_channel_ != nullptr) { + readers_[i]->SetInputPvChannel(input_pv_channel_.get()); + } if (cur_channel_ == 0 && static_cast(channel_idx) < multi_output_channel_.size()) { readers_[i]->SetOutputChannel(multi_output_channel_[channel_idx].get()); readers_[i]->SetConsumeChannel(multi_consume_channel_[channel_idx].get()); + readers_[i]->SetOutputPvChannel(multi_pv_output_[channel_idx].get()); + readers_[i]->SetConsumePvChannel(multi_pv_consume_[channel_idx].get()); } else if (static_cast(channel_idx) < multi_output_channel_.size()) { readers_[i]->SetOutputChannel(multi_consume_channel_[channel_idx].get()); readers_[i]->SetConsumeChannel(multi_output_channel_[channel_idx].get()); + readers_[i]->SetOutputPvChannel(multi_pv_consume_[channel_idx].get()); + readers_[i]->SetConsumePvChannel(multi_pv_output_[channel_idx].get()); } ++channel_idx; if (channel_idx >= channel_num_) { @@ -583,9 +686,13 @@ void DatasetImpl::CreatePreLoadReaders() { preload_readers_[i]->SetFileList(filelist_); preload_readers_[i]->SetParseInsId(parse_ins_id_); preload_readers_[i]->SetParseContent(parse_content_); + preload_readers_[i]->SetParseLogKey(parse_logkey_); + preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_); preload_readers_[i]->SetInputChannel(input_channel_.get()); preload_readers_[i]->SetOutputChannel(nullptr); preload_readers_[i]->SetConsumeChannel(nullptr); + preload_readers_[i]->SetOutputPvChannel(nullptr); + preload_readers_[i]->SetConsumePvChannel(nullptr); } VLOG(3) << "End CreatePreLoadReaders"; } @@ -605,6 +712,16 @@ int64_t DatasetImpl::GetMemoryDataSize() { return input_channel_->Size(); } +template +int64_t DatasetImpl::GetPvDataSize() { + if (enable_pv_merge_) { + return 
input_pv_channel_->Size(); + } else { + VLOG(0) << "It does not merge pv.."; + return 0; + } +} + template int64_t DatasetImpl::GetShuffleDataSize() { int64_t sum = 0; @@ -657,6 +774,92 @@ int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, // explicit instantiation template class DatasetImpl; +void MultiSlotDataset::PostprocessInstance() { + // divide pv instance, and merge to input_channel_ + if (enable_pv_merge_) { + input_channel_->Open(); + input_channel_->Write(std::move(input_records_)); + for (size_t i = 0; i < multi_pv_consume_.size(); ++i) { + multi_pv_consume_[i]->Clear(); + } + input_channel_->Close(); + input_records_.clear(); + input_records_.shrink_to_fit(); + } else { + input_channel_->Open(); + for (size_t i = 0; i < multi_consume_channel_.size(); ++i) { + std::vector ins_data; + multi_consume_channel_[i]->Close(); + multi_consume_channel_[i]->ReadAll(ins_data); + input_channel_->Write(std::move(ins_data)); + ins_data.clear(); + ins_data.shrink_to_fit(); + multi_consume_channel_[i]->Clear(); + } + input_channel_->Close(); + } + this->LocalShuffle(); +} + +void MultiSlotDataset::SetCurrentPhase(int current_phase) { + current_phase_ = current_phase; +} + +void MultiSlotDataset::PreprocessInstance() { + if (!input_channel_ || input_channel_->Size() == 0) { + return; + } + if (!enable_pv_merge_) { // means to use Record + this->LocalShuffle(); + } else { // means to use Pv + auto fleet_ptr = FleetWrapper::GetInstance(); + input_channel_->Close(); + std::vector pv_data; + input_channel_->ReadAll(input_records_); + int all_records_num = input_records_.size(); + std::vector all_records; + all_records.reserve(all_records_num); + for (int index = 0; index < all_records_num; ++index) { + all_records.push_back(&input_records_[index]); + } + + std::sort(all_records.data(), all_records.data() + all_records_num, + [](const Record* lhs, const Record* rhs) { + return lhs->search_id < rhs->search_id; + }); + if (merge_by_sid_) { + uint64_t last_search_id = 0; + for (int i = 0; i < all_records_num; ++i) { + Record* ins = all_records[i]; + if (i == 0 || last_search_id != ins->search_id) { + PvInstance pv_instance = make_pv_instance(); + pv_instance->merge_instance(ins); + pv_data.push_back(pv_instance); + last_search_id = ins->search_id; + continue; + } + pv_data.back()->merge_instance(ins); + } + } else { + for (int i = 0; i < all_records_num; ++i) { + Record* ins = all_records[i]; + PvInstance pv_instance = make_pv_instance(); + pv_instance->merge_instance(ins); + pv_data.push_back(pv_instance); + } + } + + std::shuffle(pv_data.begin(), pv_data.end(), + fleet_ptr->LocalRandomEngine()); + input_pv_channel_->Open(); + input_pv_channel_->Write(std::move(pv_data)); + + pv_data.clear(); + pv_data.shrink_to_fit(); + input_pv_channel_->Close(); + } +} + void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim, int read_thread_num, int consume_thread_num, @@ -736,6 +939,7 @@ void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim, consume_task_pool_.clear(); fleet_ptr_->PullSparseToLocal(table_id, feadim); } + void MultiSlotDataset::MergeByInsId() { VLOG(3) << "MultiSlotDataset::MergeByInsId begin"; if (!merge_by_insid_) { diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index df8bbc33e7..7adef69a44 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -65,6 +65,9 @@ class Dataset { // set parse ins id virtual void SetParseInsId(bool parse_ins_id) = 0; virtual void 
SetParseContent(bool parse_content) = 0; + virtual void SetParseLogKey(bool parse_logkey) = 0; + virtual void SetEnablePvMerge(bool enable_pv_merge) = 0; + virtual void SetMergeBySid(bool is_merge) = 0; // set merge by ins id virtual void SetMergeByInsId(int merge_size) = 0; virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0; @@ -115,10 +118,18 @@ class Dataset { virtual void DestroyReaders() = 0; // get memory data size virtual int64_t GetMemoryDataSize() = 0; + // get memory data size in input_pv_channel_ + virtual int64_t GetPvDataSize() = 0; // get shuffle data size virtual int64_t GetShuffleDataSize() = 0; // merge by ins id virtual void MergeByInsId() = 0; + // merge pv instance + virtual void PreprocessInstance() = 0; + // divide pv instance + virtual void PostprocessInstance() = 0; + // only for untest + virtual void SetCurrentPhase(int current_phase) = 0; virtual void GenerateLocalTablesUnlock(int table_id, int feadim, int read_thread_num, int consume_thread_num, @@ -161,6 +172,10 @@ class DatasetImpl : public Dataset { virtual void SetChannelNum(int channel_num); virtual void SetParseInsId(bool parse_ins_id); virtual void SetParseContent(bool parse_content); + virtual void SetParseLogKey(bool parse_logkey); + virtual void SetEnablePvMerge(bool enable_pv_merge); + virtual void SetMergeBySid(bool is_merge); + virtual void SetMergeByInsId(int merge_size); virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns); virtual void SetFeaEval(bool fea_eval, int record_candidate_size); @@ -192,8 +207,12 @@ class DatasetImpl : public Dataset { virtual void CreateReaders(); virtual void DestroyReaders(); virtual int64_t GetMemoryDataSize(); + virtual int64_t GetPvDataSize(); virtual int64_t GetShuffleDataSize(); virtual void MergeByInsId() {} + virtual void PreprocessInstance() {} + virtual void PostprocessInstance() {} + virtual void SetCurrentPhase(int current_phase) {} virtual void GenerateLocalTablesUnlock(int table_id, int feadim, int read_thread_num, int consume_thread_num, @@ -213,6 +232,10 @@ class DatasetImpl : public Dataset { std::vector> readers_; std::vector> preload_readers_; paddle::framework::Channel input_channel_; + paddle::framework::Channel input_pv_channel_; + std::vector> multi_pv_output_; + std::vector> multi_pv_consume_; + int channel_num_; std::vector> multi_output_channel_; std::vector> multi_consume_channel_; @@ -238,6 +261,10 @@ class DatasetImpl : public Dataset { bool merge_by_insid_; bool parse_ins_id_; bool parse_content_; + bool parse_logkey_; + bool merge_by_sid_; + bool enable_pv_merge_; // True means to merge pv + int current_phase_; // 1 join, 0 update size_t merge_size_; bool slots_shuffle_fea_eval_ = false; bool gen_uni_feasigns_ = false; @@ -252,6 +279,9 @@ class MultiSlotDataset : public DatasetImpl { public: MultiSlotDataset() {} virtual void MergeByInsId(); + virtual void PreprocessInstance(); + virtual void PostprocessInstance(); + virtual void SetCurrentPhase(int current_phase); virtual void GenerateLocalTablesUnlock(int table_id, int feadim, int read_thread_num, int consume_thread_num, int shard_num); @@ -266,6 +296,9 @@ class MultiSlotDataset : public DatasetImpl { virtual void GetRandomData(const std::set& slots_to_replace, std::vector* result); virtual ~MultiSlotDataset() {} + + protected: + std::vector input_records_; // the real data }; } // end namespace framework diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index bd3aa4e498..4b12f66c61 100644 --- a/paddle/fluid/pybind/data_set_py.cc 
+++ b/paddle/fluid/pybind/data_set_py.cc
@@ -239,6 +239,8 @@ void BindDataset(py::module *m) {
            py::call_guard<py::gil_scoped_release>())
       .def("get_memory_data_size", &framework::Dataset::GetMemoryDataSize,
            py::call_guard<py::gil_scoped_release>())
+      .def("get_pv_data_size", &framework::Dataset::GetPvDataSize,
+           py::call_guard<py::gil_scoped_release>())
       .def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize,
            py::call_guard<py::gil_scoped_release>())
       .def("set_queue_num", &framework::Dataset::SetChannelNum,
@@ -247,6 +249,19 @@ void BindDataset(py::module *m) {
            py::call_guard<py::gil_scoped_release>())
       .def("set_parse_content", &framework::Dataset::SetParseContent,
            py::call_guard<py::gil_scoped_release>())
+      .def("set_parse_logkey", &framework::Dataset::SetParseLogKey,
+           py::call_guard<py::gil_scoped_release>())
+      .def("set_merge_by_sid", &framework::Dataset::SetMergeBySid,
+           py::call_guard<py::gil_scoped_release>())
+      .def("preprocess_instance", &framework::Dataset::PreprocessInstance,
+           py::call_guard<py::gil_scoped_release>())
+      .def("postprocess_instance", &framework::Dataset::PostprocessInstance,
+           py::call_guard<py::gil_scoped_release>())
+      .def("set_current_phase", &framework::Dataset::SetCurrentPhase,
+           py::call_guard<py::gil_scoped_release>())
+      .def("set_enable_pv_merge", &framework::Dataset::SetEnablePvMerge,
+           py::call_guard<py::gil_scoped_release>())
+
       .def("set_merge_by_lineid", &framework::Dataset::SetMergeByInsId,
            py::call_guard<py::gil_scoped_release>())
       .def("merge_by_lineid", &framework::Dataset::MergeByInsId,
diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 97900d02cb..a125cd4013 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -92,6 +92,23 @@ class DatasetBase(object):
         """
         self.proto_desc.pipe_command = pipe_command
 
+    def set_rank_offset(self, rank_offset):
+        """
+        Set rank_offset for merge_pv. It sets the rank information of each Pv.
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset()
+              dataset.set_rank_offset("rank_offset")
+
+        Args:
+            rank_offset(str): rank_offset's name
+
+        """
+        self.proto_desc.rank_offset = rank_offset
+
     def set_fea_eval(self, record_candidate_size, fea_eval=True):
         """
         set fea eval mode for slots shuffle to debug the importance level of
@@ -154,6 +171,22 @@ class DatasetBase(object):
         """
         self.proto_desc.batch_size = batch_size
 
+    def set_pv_batch_size(self, pv_batch_size):
+        """
+        Set pv batch size. It only takes effect when enable_pv_merge is set.
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset()
+              dataset.set_pv_batch_size(128)
+        Args:
+            pv_batch_size(int): pv batch size
+
+        """
+        self.proto_desc.pv_batch_size = pv_batch_size
+
     def set_thread(self, thread_num):
         """
         Set thread num, it is the num of readers.
@@ -308,9 +341,18 @@ class InMemoryDataset(DatasetBase):
         self.queue_num = None
         self.parse_ins_id = False
         self.parse_content = False
+        self.parse_logkey = False
+        self.merge_by_sid = True
+        self.enable_pv_merge = False
         self.merge_by_lineid = False
         self.fleet_send_sleep_seconds = None
 
+    def set_feed_type(self, data_feed_type):
+        """
+        Set the DataFeed type (data_feed_desc.name), e.g. "PaddleBoxDataFeed".
+        """
+        self.proto_desc.name = data_feed_type
+
     def _prepare_to_run(self):
         """
         Set data_feed_desc before load or shuffle,
@@ -324,6 +366,9 @@ class InMemoryDataset(DatasetBase):
         self.dataset.set_queue_num(self.queue_num)
         self.dataset.set_parse_ins_id(self.parse_ins_id)
         self.dataset.set_parse_content(self.parse_content)
+        self.dataset.set_parse_logkey(self.parse_logkey)
+        self.dataset.set_merge_by_sid(self.merge_by_sid)
+        self.dataset.set_enable_pv_merge(self.enable_pv_merge)
         self.dataset.set_data_feed_desc(self.desc())
         self.dataset.create_channel()
         self.dataset.create_readers()
@@ -390,6 +435,112 @@ class InMemoryDataset(DatasetBase):
         """
         self.parse_content = parse_content
 
+    def set_parse_logkey(self, parse_logkey):
+        """
+        Set whether the Dataset needs to parse the logkey.
+
+        Args:
+            parse_logkey(bool): whether to parse the logkey or not
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+              dataset.set_parse_logkey(True)
+
+        """
+        self.parse_logkey = parse_logkey
+
+    def set_merge_by_sid(self, merge_by_sid):
+        """
+        Set whether the Dataset merges instances by sid. If not, one ins means one Pv.
+
+        Args:
+            merge_by_sid(bool): whether to merge by sid or not
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+              dataset.set_merge_by_sid(True)
+
+        """
+        self.merge_by_sid = merge_by_sid
+
+    def set_enable_pv_merge(self, enable_pv_merge):
+        """
+        Set whether the Dataset needs to merge instances into Pvs.
+
+        Args:
+            enable_pv_merge(bool): whether to enable pv merge or not
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+              dataset.set_enable_pv_merge(True)
+
+        """
+        self.enable_pv_merge = enable_pv_merge
+
+    def preprocess_instance(self):
+        """
+        Merge instances into Pvs and move them from input_channel to input_pv_channel.
+        It only takes effect when enable_pv_merge is True.
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+              filelist = ["a.txt", "b.txt"]
+              dataset.set_filelist(filelist)
+              dataset.load_into_memory()
+              dataset.preprocess_instance()
+
+        """
+        self.dataset.preprocess_instance()
+
+    def set_current_phase(self, current_phase):
+        """
+        Set the current training phase; it is mainly used by the unit test.
+        current_phase: 1 for join, 0 for update.
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+              filelist = ["a.txt", "b.txt"]
+              dataset.set_filelist(filelist)
+              dataset.load_into_memory()
+              dataset.set_current_phase(1)
+
+        """
+        self.dataset.set_current_phase(current_phase)
+
+    def postprocess_instance(self):
+        """
+        Split the Pvs back into instances and move them to input_channel.
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+              filelist = ["a.txt", "b.txt"]
+              dataset.set_filelist(filelist)
+              dataset.load_into_memory()
+              dataset.preprocess_instance()
+              exe.train_from_dataset(dataset)
+              dataset.postprocess_instance()
+
+        """
+        self.dataset.postprocess_instance()
+
     def set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
         """
         Set fleet send batch size, default is 1024
@@ -594,6 +745,30 @@ class InMemoryDataset(DatasetBase):
         """
         self.dataset.release_memory()
 
+    def get_pv_data_size(self):
+        """
+        Get memory data size of Pv. Users can call this function to know the
+        number of Pvs in all workers after the data is loaded into memory.
+
+        Note:
+            This function may cause bad performance, because it contains a barrier.
+
+        Returns:
+            The number of Pvs in memory.
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+              filelist = ["a.txt", "b.txt"]
+              dataset.set_filelist(filelist)
+              dataset.load_into_memory()
+              print dataset.get_pv_data_size()
+
+        """
+        return self.dataset.get_pv_data_size()
+
     def get_memory_data_size(self, fleet=None):
         """
         Get memory data size, user can call this function to know the num
@@ -808,6 +983,7 @@ class BoxPSDataset(InMemoryDataset):
         """
         super(BoxPSDataset, self).__init__()
         self.boxps = core.BoxPS(self.dataset)
+        self.proto_desc.name = "PaddleBoxDataFeed"
 
     def set_date(self, date):
         """
@@ -895,3 +1071,6 @@
         if not self.is_user_set_queue_num:
             self.dataset.dynamic_adjust_channel_num(thread_num, True)
         self.dataset.dynamic_adjust_readers_num(thread_num)
+
+    def _dynamic_adjust_after_train(self):
+        pass
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 6e50108383..7c191cd950 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -44,6 +44,7 @@ endif()
 
 if(WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_boxps)
+    LIST(REMOVE_ITEM TEST_OPS test_paddlebox_datafeed)
     LIST(REMOVE_ITEM TEST_OPS test_trainer_desc)
     LIST(REMOVE_ITEM TEST_OPS test_multiprocess_reader_exception)
     LIST(REMOVE_ITEM TEST_OPS test_avoid_twice_initialization)
@@ -59,6 +60,7 @@ endif()
 if(NOT WITH_GPU OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_pipeline)
     LIST(REMOVE_ITEM TEST_OPS test_boxps)
+    LIST(REMOVE_ITEM TEST_OPS test_paddlebox_datafeed)
 endif()
 list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
 list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
diff --git a/python/paddle/fluid/tests/unittests/test_paddlebox_datafeed.py b/python/paddle/fluid/tests/unittests/test_paddlebox_datafeed.py
new file mode 100644
index 0000000000..35bc144989
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_paddlebox_datafeed.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import os
+import unittest
+import paddle.fluid.layers as layers
+from paddle.fluid.layers.nn import _pull_box_sparse
+
+
+class TestDataFeed(unittest.TestCase):
+    """ TestBaseCase(Merge PV) """
+
+    def setUp(self):
+        self.batch_size = 10
+        self.pv_batch_size = 10
+        self.enable_pv_merge = True
+        self.merge_by_sid = True
+
+    def set_data_config(self):
+        self.dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
+        self.dataset.set_feed_type("PaddleBoxDataFeed")
+        self.dataset.set_parse_logkey(True)
+        self.dataset.set_thread(1)
+        self.dataset.set_enable_pv_merge(self.enable_pv_merge)
+        self.dataset.set_batch_size(self.batch_size)
+        if self.enable_pv_merge:
+            self.dataset.set_merge_by_sid(self.merge_by_sid)
+            self.dataset.set_rank_offset("rank_offset")
+            self.dataset.set_pv_batch_size(self.pv_batch_size)
+
+    def test_pboxdatafeed_cpu(self):
+        self.run_dataset(False)
+
+    def test_pboxdatafeed_gpu(self):
+        self.run_dataset(True)
+
+    def run_dataset(self, is_cpu):
+        x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
+        y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0)
+        rank_offset = fluid.layers.data(
+            name="rank_offset",
+            shape=[-1, 7],
+            dtype="int32",
+            lod_level=0,
+            append_batch_size=False)
+
+        emb_x, emb_y = _pull_box_sparse([x, y], size=2)
+        emb_xp = _pull_box_sparse(x, size=2)
+        concat = layers.concat([emb_x, emb_y], axis=1)
+        fc = layers.fc(input=concat,
+                       name="fc",
+                       size=1,
+                       num_flatten_dims=1,
+                       bias_attr=False)
+        loss = layers.reduce_mean(fc)
+        place = fluid.CPUPlace() if is_cpu or not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0)
+        exe = fluid.Executor(place)
+
+        with open("test_run_with_dump_a.txt", "w") as f:
+            data = "1 1702f830eee19501ad7429505f714c1d 1 1 1 9\n"
+            data += "1 1702f830eee19502ad7429505f714c1d 1 2 1 8\n"
+            data += "1 1702f830eee19503ad7429505f714c1d 1 3 1 7\n"
+            data += "1 1702f830eee0de01ad7429505f714c2d 1 4 1 6\n"
+            data += "1 1702f830eee0df01ad7429505f714c3d 1 5 1 5\n"
+            data += "1 1702f830eee0df02ad7429505f714c3d 1 6 1 4\n"
+            f.write(data)
+        with open("test_run_with_dump_b.txt", "w") as f:
+            data = "1 1702f830fff22201ad7429505f715c1d 1 1 1 1\n"
+            data += "1 1702f830fff22202ad7429505f715c1d 1 2 1 2\n"
+            data += "1 1702f830fff22203ad7429505f715c1d 1 3 1 3\n"
+            data += "1 1702f830fff22101ad7429505f714ccd 1 4 1 4\n"
+            data += "1 1702f830fff22102ad7429505f714ccd 1 5 1 5\n"
+            data += "1 1702f830fff22103ad7429505f714ccd 1 6 1 6\n"
+            data += "1 1702f830fff22104ad7429505f714ccd 1 6 1 7\n"
+            f.write(data)
+
+        self.set_data_config()
+        self.dataset.set_use_var([x, y])
+        self.dataset.set_filelist(
+            ["test_run_with_dump_a.txt", "test_run_with_dump_b.txt"])
+
+        optimizer = fluid.optimizer.SGD(learning_rate=0.5)
+        optimizer = fluid.optimizer.PipelineOptimizer(
+            optimizer,
+            cut_list=[],
+            place_list=[place],
+            concurrency_list=[1],
+            queue_size=1,
+            sync_steps=-1)
+        optimizer.minimize(loss)
+        exe.run(fluid.default_startup_program())
+        self.dataset.set_current_phase(1)
+        self.dataset.load_into_memory()
+        self.dataset.preprocess_instance()
+        self.dataset.begin_pass()
+        pv_num = self.dataset.get_pv_data_size()
+
+        exe.train_from_dataset(
+            program=fluid.default_main_program(),
+            dataset=self.dataset,
+            print_period=1)
+        self.dataset.set_current_phase(0)
+        self.dataset.postprocess_instance()
+        exe.train_from_dataset(
+            program=fluid.default_main_program(),
+            dataset=self.dataset,
+            print_period=1)
+        self.dataset.end_pass(True)
+        os.remove("test_run_with_dump_a.txt")
+        os.remove("test_run_with_dump_b.txt")
+
+
+class TestDataFeed2(TestDataFeed):
+    """ TestBaseCase(Merge PV not merge by sid) """
+
+    def setUp(self):
+        self.batch_size = 10
+        self.pv_batch_size = 10
+        self.enable_pv_merge = True
+        self.merge_by_sid = False
+
+
+class TestDataFeed3(TestDataFeed):
+    """ TestBaseCase(Not Merge PV) """
+
+    def setUp(self):
+        self.batch_size = 10
+        self.pv_batch_size = 10
+        self.enable_pv_merge = False
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
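
For quick reference, the end-to-end usage of the Pv-merge APIs introduced by this patch looks roughly as follows. This is a minimal sketch, not part of the patch itself: it assumes Paddle is compiled with BOX_PS, and that the network, feed variables (feed_vars), file list (filelist), and executor (exe) are set up as in test_paddlebox_datafeed.py above.

# Minimal usage sketch of the new Pv-merge workflow (assumptions noted above).
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
dataset.set_feed_type("PaddleBoxDataFeed")   # use the new DataFeed
dataset.set_parse_logkey(True)               # parse search_id/cmatch/rank from the logkey
dataset.set_enable_pv_merge(True)            # merge instances into Pvs
dataset.set_merge_by_sid(True)               # merge instances that share a search_id
dataset.set_rank_offset("rank_offset")       # name of the rank_offset feed variable
dataset.set_batch_size(10)
dataset.set_pv_batch_size(10)
dataset.set_thread(1)
dataset.set_use_var(feed_vars)               # placeholder: variables used by the network
dataset.set_filelist(filelist)               # placeholder: list of input files

dataset.set_current_phase(1)                 # join phase
dataset.load_into_memory()
dataset.preprocess_instance()                # Records -> Pvs (input_pv_channel)
dataset.begin_pass()
print(dataset.get_pv_data_size())            # number of merged Pvs in memory
exe.train_from_dataset(program=fluid.default_main_program(), dataset=dataset)

dataset.set_current_phase(0)                 # update phase
dataset.postprocess_instance()               # Pvs -> Records (input_channel)
exe.train_from_dataset(program=fluid.default_main_program(), dataset=dataset)
dataset.end_pass(True)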