Unverified Commit 5223e2bb authored by ShenLiang, committed by GitHub

Add a new DataFeed named PaddleBoxDataFeed (#23321)

* add paddleboxdatafeed
* add ifdef linux and boxps
* add unit test for datafeed
* fix unit test of test_paddlebox_datafeed
* fix unit test
* rename function
Parent 75bd3507
@@ -33,6 +33,7 @@ limitations under the License. */
#include "io/shell.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
...@@ -232,6 +233,9 @@ InMemoryDataFeed<T>::InMemoryDataFeed() { ...@@ -232,6 +233,9 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_num_ = 1; this->thread_num_ = 1;
this->parse_ins_id_ = false; this->parse_ins_id_ = false;
this->parse_content_ = false; this->parse_content_ = false;
this->parse_logkey_ = false;
this->enable_pv_merge_ = false;
this->current_phase_ = 1;  // 1: join; 0: update
this->input_channel_ = nullptr;
this->output_channel_ = nullptr;
this->consume_channel_ = nullptr;
@@ -305,6 +309,24 @@ void InMemoryDataFeed<T>::SetConsumeChannel(void* channel) {
consume_channel_ = static_cast<paddle::framework::ChannelObject<T>*>(channel);
}
template <typename T>
void InMemoryDataFeed<T>::SetInputPvChannel(void* channel) {
input_pv_channel_ =
static_cast<paddle::framework::ChannelObject<PvInstance>*>(channel);
}
template <typename T>
void InMemoryDataFeed<T>::SetOutputPvChannel(void* channel) {
output_pv_channel_ =
static_cast<paddle::framework::ChannelObject<PvInstance>*>(channel);
}
template <typename T>
void InMemoryDataFeed<T>::SetConsumePvChannel(void* channel) {
consume_pv_channel_ =
static_cast<paddle::framework::ChannelObject<PvInstance>*>(channel);
}
template <typename T>
void InMemoryDataFeed<T>::SetThreadId(int thread_id) {
thread_id_ = thread_id;
@@ -320,6 +342,21 @@ void InMemoryDataFeed<T>::SetParseContent(bool parse_content) {
parse_content_ = parse_content;
}
template <typename T>
void InMemoryDataFeed<T>::SetParseLogKey(bool parse_logkey) {
parse_logkey_ = parse_logkey;
}
template <typename T>
void InMemoryDataFeed<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge;
}
template <typename T>
void InMemoryDataFeed<T>::SetCurrentPhase(int current_phase) {
current_phase_ = current_phase;
}
template <typename T>
void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
@@ -756,6 +793,20 @@ void MultiSlotInMemoryDataFeed::Init(
finish_init_ = true;
}
void MultiSlotInMemoryDataFeed::GetMsgFromLogKey(const std::string& log_key,
uint64_t* search_id,
uint32_t* cmatch,
uint32_t* rank) {
std::string searchid_str = log_key.substr(16, 16);
*search_id = (uint64_t)strtoull(searchid_str.c_str(), NULL, 16);
std::string cmatch_str = log_key.substr(11, 3);
*cmatch = (uint32_t)strtoul(cmatch_str.c_str(), NULL, 16);
std::string rank_str = log_key.substr(14, 2);
*rank = (uint32_t)strtoul(rank_str.c_str(), NULL, 16);
}
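For reference, the substring offsets above imply a fixed-width, hex-encoded logkey layout: characters [11, 14) hold cmatch, [14, 16) hold rank, and [16, 32) hold search_id. A minimal Python sketch of the same decoding (the helper name is ours; the sample key is taken from the unit test at the end of this diff):

```python
def get_msg_from_log_key(log_key):
    # Mirrors MultiSlotInMemoryDataFeed::GetMsgFromLogKey: fixed hex fields.
    search_id = int(log_key[16:32], 16)
    cmatch = int(log_key[11:14], 16)
    rank = int(log_key[14:16], 16)
    return search_id, cmatch, rank

# Logkey borrowed from test_paddlebox_datafeed below.
print(get_msg_from_log_key("1702f830eee19501ad7429505f714c1d"))
```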
bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
#ifdef _LINUX
thread_local string::LineFileReader reader;
@@ -792,6 +843,26 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
pos += len + 1;
VLOG(3) << "content " << instance->content_;
}
if (parse_logkey_) {
int num = strtol(&str[pos], &endptr, 10);
CHECK(num == 1); // NOLINT
pos = endptr - str + 1;
size_t len = 0;
while (str[pos + len] != ' ') {
++len;
}
// parse_logkey
std::string log_key = std::string(str + pos, len);
uint64_t search_id;
uint32_t cmatch;
uint32_t rank;
GetMsgFromLogKey(log_key, &search_id, &cmatch, &rank);
instance->search_id = search_id;
instance->cmatch = cmatch;
instance->rank = rank;
pos += len + 1;
}
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
@@ -1186,5 +1257,242 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
}
#endif
bool PaddleBoxDataFeed::Start() {
#ifdef _LINUX
int phase = GetCurrentPhase(); // join: 1, update: 0
this->CheckSetFileList();
if (enable_pv_merge_ && phase == 1) {
// join phase : input_pv_channel to output_pv_channel
if (output_pv_channel_->Size() == 0 && input_pv_channel_->Size() != 0) {
std::vector<PvInstance> data;
input_pv_channel_->Read(data);
output_pv_channel_->Write(std::move(data));
}
} else {
// input_channel to output
if (output_channel_->Size() == 0 && input_channel_->Size() != 0) {
std::vector<Record> data;
input_channel_->Read(data);
output_channel_->Write(std::move(data));
}
}
#endif
this->finish_start_ = true;
return true;
}
int PaddleBoxDataFeed::Next() {
#ifdef _LINUX
int phase = GetCurrentPhase(); // join: 1, update: 0
this->CheckStart();
if (enable_pv_merge_ && phase == 1) {
// join phase : output_pv_channel to consume_pv_channel
CHECK(output_pv_channel_ != nullptr);
CHECK(consume_pv_channel_ != nullptr);
VLOG(3) << "output_pv_channel_ size=" << output_pv_channel_->Size()
<< ", consume_pv_channel_ size=" << consume_pv_channel_->Size()
<< ", thread_id=" << thread_id_;
int index = 0;
PvInstance pv_instance;
std::vector<PvInstance> pv_vec;
pv_vec.reserve(this->pv_batch_size_);
while (index < this->pv_batch_size_) {
if (output_pv_channel_->Size() == 0) {
break;
}
output_pv_channel_->Get(pv_instance);
pv_vec.push_back(pv_instance);
++index;
consume_pv_channel_->Put(std::move(pv_instance));
}
this->batch_size_ = index;
VLOG(3) << "pv_batch_size_=" << this->batch_size_
<< ", thread_id=" << thread_id_;
if (this->batch_size_ != 0) {
PutToFeedVec(pv_vec);
} else {
VLOG(3) << "finish reading, output_pv_channel_ size="
<< output_pv_channel_->Size()
<< ", consume_pv_channel_ size=" << consume_pv_channel_->Size()
<< ", thread_id=" << thread_id_;
}
return this->batch_size_;
} else {
this->batch_size_ = MultiSlotInMemoryDataFeed::Next();
return this->batch_size_;
}
#else
return 0;
#endif
}
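Start() and Next() route data through two parallel channel sets: in the join phase (phase == 1) with pv merge enabled, pv instances flow input_pv_channel_ -> output_pv_channel_ -> consume_pv_channel_ in batches of up to pv_batch_size_; otherwise the feed falls back to the plain Record channels via MultiSlotInMemoryDataFeed::Next(). A toy sketch of the draining loop in Next(), with Python lists standing in for ChannelObject (names are illustrative):

```python
def next_pv_batch(output_channel, consume_channel, pv_batch_size):
    # Move up to pv_batch_size pvs from output to consume; an empty
    # result corresponds to batch_size_ == 0, i.e. the pass is finished.
    batch = []
    while len(batch) < pv_batch_size and output_channel:
        pv = output_channel.pop(0)
        batch.append(pv)
        consume_channel.append(pv)
    return batch

output, consume = [["ad1"], ["ad2", "ad3"]], []
print(next_pv_batch(output, consume, pv_batch_size=32))  # both pvs, one batch
```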
void PaddleBoxDataFeed::Init(const DataFeedDesc& data_feed_desc) {
MultiSlotInMemoryDataFeed::Init(data_feed_desc);
rank_offset_name_ = data_feed_desc.rank_offset();
pv_batch_size_ = data_feed_desc.pv_batch_size();
}
void PaddleBoxDataFeed::GetRankOffset(const std::vector<PvInstance>& pv_vec,
int ins_number) {
int index = 0;
int max_rank = 3; // the value is a fixed setting for now
int row = ins_number;
int col = max_rank * 2 + 1;
int pv_num = pv_vec.size();
std::vector<int> rank_offset_mat(row * col, -1);
rank_offset_mat.shrink_to_fit();
for (int i = 0; i < pv_num; i++) {
auto pv_ins = pv_vec[i];
int ad_num = pv_ins->ads.size();
int index_start = index;
for (int j = 0; j < ad_num; ++j) {
auto ins = pv_ins->ads[j];
int rank = -1;
if ((ins->cmatch == 222 || ins->cmatch == 223) &&
ins->rank <= static_cast<uint32_t>(max_rank) && ins->rank != 0) {
rank = ins->rank;
}
rank_offset_mat[index * col] = rank;
if (rank > 0) {
for (int k = 0; k < ad_num; ++k) {
auto cur_ins = pv_ins->ads[k];
int fast_rank = -1;
if ((cur_ins->cmatch == 222 || cur_ins->cmatch == 223) &&
cur_ins->rank <= static_cast<uint32_t>(max_rank) &&
cur_ins->rank != 0) {
fast_rank = cur_ins->rank;
}
if (fast_rank > 0) {
int m = fast_rank - 1;
rank_offset_mat[index * col + 2 * m + 1] = cur_ins->rank;
rank_offset_mat[index * col + 2 * m + 2] = index_start + k;
}
}
}
index += 1;
}
}
int* rank_offset = rank_offset_mat.data();
int* tensor_ptr = rank_offset_->mutable_data<int>({row, col}, this->place_);
CopyToFeedTensor(tensor_ptr, rank_offset, row * col * sizeof(int));
}
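The rank_offset tensor built above has one row per ad instance and max_rank * 2 + 1 = 7 columns: column 0 holds the instance's own filtered rank (or -1), and for each rank m in [1, max_rank], columns 2m-1 and 2m hold that rank and the global index of the pv-mate occupying it. A toy Python re-implementation under the same cmatch 222/223 filter (ads are modeled as hypothetical (cmatch, rank) pairs):

```python
MAX_RANK = 3            # mirrors "int max_rank = 3" above
COL = MAX_RANK * 2 + 1  # own rank plus (rank, index) pairs

def valid_rank(ad):
    cmatch, rank = ad
    return rank if cmatch in (222, 223) and 1 <= rank <= MAX_RANK else -1

def rank_offset(ads, index_start=0):
    mat = []
    for ad in ads:
        row = [-1] * COL
        row[0] = valid_rank(ad)
        if row[0] > 0:
            for k, other in enumerate(ads):
                r = valid_rank(other)
                if r > 0:
                    row[2 * (r - 1) + 1] = r                # the mate's rank
                    row[2 * (r - 1) + 2] = index_start + k  # its global index
        mat.append(row)
    return mat

# Two ranked ads and one unranked ad in the same pv.
print(rank_offset([(222, 1), (223, 2), (0, 0)]))
```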
void PaddleBoxDataFeed::AssignFeedVar(const Scope& scope) {
MultiSlotInMemoryDataFeed::AssignFeedVar(scope);
// set rank offset memory
int phase = GetCurrentPhase(); // join: 1, update: 0
if (enable_pv_merge_ && phase == 1) {
rank_offset_ = scope.FindVar(rank_offset_name_)->GetMutable<LoDTensor>();
}
}
void PaddleBoxDataFeed::PutToFeedVec(const std::vector<PvInstance>& pv_vec) {
#ifdef _LINUX
int ins_number = 0;
std::vector<Record*> ins_vec;
for (auto& pv : pv_vec) {
ins_number += pv->ads.size();
for (auto ins : pv->ads) {
ins_vec.push_back(ins);
}
}
GetRankOffset(pv_vec, ins_number);
PutToFeedVec(ins_vec);
#endif
}
int PaddleBoxDataFeed::GetCurrentPhase() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
return box_ptr->PassFlag(); // join: 1, update: 0
#else
LOG(WARNING) << "It should be complied with BOX_PS...";
return current_phase_;
#endif
}
void PaddleBoxDataFeed::PutToFeedVec(const std::vector<Record*>& ins_vec) {
#ifdef _LINUX
std::vector<std::vector<float>> batch_float_feasigns(use_slots_.size(),
std::vector<float>());
std::vector<std::vector<uint64_t>> batch_uint64_feasigns(
use_slots_.size(), std::vector<uint64_t>());
std::vector<std::vector<size_t>> offset(use_slots_.size(),
std::vector<size_t>{0});
std::vector<bool> visit(use_slots_.size(), false);
ins_content_vec_.clear();
ins_content_vec_.reserve(ins_vec.size());
ins_id_vec_.clear();
ins_id_vec_.reserve(ins_vec.size());
for (size_t i = 0; i < ins_vec.size(); ++i) {
auto r = ins_vec[i];
ins_id_vec_.push_back(r->ins_id_);
ins_content_vec_.push_back(r->content_);
for (auto& item : r->float_feasigns_) {
batch_float_feasigns[item.slot()].push_back(item.sign().float_feasign_);
visit[item.slot()] = true;
}
for (auto& item : r->uint64_feasigns_) {
batch_uint64_feasigns[item.slot()].push_back(item.sign().uint64_feasign_);
visit[item.slot()] = true;
}
for (size_t j = 0; j < use_slots_.size(); ++j) {
const auto& type = all_slots_type_[j];
if (visit[j]) {
visit[j] = false;
} else {
// fill slot value with default value 0
if (type[0] == 'f') { // float
batch_float_feasigns[j].push_back(0.0);
} else if (type[0] == 'u') { // uint64
batch_uint64_feasigns[j].push_back(0);
}
}
// get offset of this ins in this slot
if (type[0] == 'f') { // float
offset[j].push_back(batch_float_feasigns[j].size());
} else if (type[0] == 'u') { // uint64
offset[j].push_back(batch_uint64_feasigns[j].size());
}
}
}
for (size_t i = 0; i < use_slots_.size(); ++i) {
if (feed_vec_[i] == nullptr) {
continue;
}
int total_instance = offset[i].back();
const auto& type = all_slots_type_[i];
if (type[0] == 'f') { // float
float* feasign = batch_float_feasigns[i].data();
float* tensor_ptr =
feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle
uint64_t* feasign = batch_uint64_feasigns[i].data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, this->place_);
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t));
}
auto& slot_offset = offset[i];
LoD data_lod{slot_offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
if (inductive_shape_index_[i] != -1) {
use_slots_shape_[i][inductive_shape_index_[i]] =
total_instance / total_dims_without_inductive_[i];
}
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
#endif
}
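The per-slot bookkeeping above guarantees that every instance contributes at least one value to every used slot (a default 0 is pushed when a slot is absent), so the running offset vector doubles as the slot tensor's LoD. A condensed Python sketch for a single float slot (the sample feasigns are made up):

```python
records = [[0.5, 1.5], [], [2.0]]  # per-instance feasigns for one float slot
values, offset = [], [0]
for feasigns in records:
    values.extend(feasigns if feasigns else [0.0])  # default-fill empty slot
    offset.append(len(values))                      # offset becomes the LoD
print(values)  # [0.5, 1.5, 0.0, 2.0]
print(offset)  # [0, 2, 3, 4]
```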
} // namespace framework
} // namespace paddle
@@ -58,6 +58,51 @@ namespace framework {
// while (reader->Next()) {
//   // trainer do something
// }
union FeatureKey {
uint64_t uint64_feasign_;
float float_feasign_;
};
struct FeatureItem {
FeatureItem() {}
FeatureItem(FeatureKey sign, uint16_t slot) {
this->sign() = sign;
this->slot() = slot;
}
FeatureKey& sign() { return *(reinterpret_cast<FeatureKey*>(sign_buffer())); }
const FeatureKey& sign() const {
const FeatureKey* ret = reinterpret_cast<FeatureKey*>(sign_buffer());
return *ret;
}
uint16_t& slot() { return slot_; }
const uint16_t& slot() const { return slot_; }
private:
char* sign_buffer() const { return const_cast<char*>(sign_); }
char sign_[sizeof(FeatureKey)];
uint16_t slot_;
};
// sizeof Record is much less than std::vector<MultiSlotType>
struct Record {
std::vector<FeatureItem> uint64_feasigns_;
std::vector<FeatureItem> float_feasigns_;
std::string ins_id_;
std::string content_;
uint64_t search_id;
uint32_t rank;
uint32_t cmatch;
};
struct PvInstanceObject {
std::vector<Record*> ads;
void merge_instance(Record* ins) { ads.push_back(ins); }
};
using PvInstance = PvInstanceObject*;
inline PvInstance make_pv_instance() { return new PvInstanceObject(); }
class DataFeed {
public:
DataFeed() {
@@ -93,6 +138,13 @@ class DataFeed {
// This function is used for binding feed_vec memory in a given scope
virtual void AssignFeedVar(const Scope& scope);
// This function will do nothing at default
virtual void SetInputPvChannel(void* channel) {}
// This function will do nothing at default
virtual void SetOutputPvChannel(void* channel) {}
// This function will do nothing at default
virtual void SetConsumePvChannel(void* channel) {}
// This function will do nothing at default
virtual void SetInputChannel(void* channel) {}
// This function will do nothing at default
@@ -106,6 +158,9 @@ class DataFeed {
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {}
virtual void SetCurrentPhase(int current_phase) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
@@ -163,6 +218,8 @@ class DataFeed {
// The data read by DataFeed will be stored here
std::vector<LoDTensor*> feed_vec_;
LoDTensor* rank_offset_;
// the batch size defined by user
int default_batch_size_;
// current batch size
@@ -226,6 +283,10 @@ class InMemoryDataFeed : public DataFeed {
virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
virtual void SetInputPvChannel(void* channel);
virtual void SetOutputPvChannel(void* channel);
virtual void SetConsumePvChannel(void* channel);
virtual void SetInputChannel(void* channel);
virtual void SetOutputChannel(void* channel);
virtual void SetConsumeChannel(void* channel);
@@ -233,6 +294,9 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetCurrentPhase(int current_phase);
virtual void LoadIntoMemory();
protected:
@@ -244,11 +308,18 @@ class InMemoryDataFeed : public DataFeed {
int thread_num_;
bool parse_ins_id_;
bool parse_content_;
bool parse_logkey_;
bool enable_pv_merge_;
int current_phase_{-1};  // only for unit tests
std::ifstream file_;
std::shared_ptr<FILE> fp_;
paddle::framework::ChannelObject<T>* input_channel_;
paddle::framework::ChannelObject<T>* output_channel_;
paddle::framework::ChannelObject<T>* consume_channel_;
paddle::framework::ChannelObject<PvInstance>* input_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* output_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* consume_pv_channel_;
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
@@ -408,39 +479,6 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
return ar;
}
union FeatureKey {
uint64_t uint64_feasign_;
float float_feasign_;
};
struct FeatureItem {
FeatureItem() {}
FeatureItem(FeatureKey sign, uint16_t slot) {
this->sign() = sign;
this->slot() = slot;
}
FeatureKey& sign() { return *(reinterpret_cast<FeatureKey*>(sign_buffer())); }
const FeatureKey& sign() const {
const FeatureKey* ret = reinterpret_cast<FeatureKey*>(sign_buffer());
return *ret;
}
uint16_t& slot() { return slot_; }
const uint16_t& slot() const { return slot_; }
private:
char* sign_buffer() const { return const_cast<char*>(sign_); }
char sign_[sizeof(FeatureKey)];
uint16_t slot_;
};
// sizeof Record is much less than std::vector<MultiSlotType>
struct Record {
std::vector<FeatureItem> uint64_feasigns_;
std::vector<FeatureItem> float_feasigns_;
std::string ins_id_;
std::string content_;
};
struct RecordCandidate {
std::string ins_id_;
std::unordered_multimap<uint16_t, FeatureKey> feas;
@@ -557,6 +595,27 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
virtual bool ParseOneInstance(Record* instance);
virtual bool ParseOneInstanceFromPipe(Record* instance);
virtual void PutToFeedVec(const std::vector<Record>& ins_vec);
virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id,
uint32_t* cmatch, uint32_t* rank);
};
class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
public:
PaddleBoxDataFeed() {}
virtual ~PaddleBoxDataFeed() {}
protected:
virtual void Init(const DataFeedDesc& data_feed_desc);
virtual bool Start();
virtual int Next();
virtual void AssignFeedVar(const Scope& scope);
virtual void PutToFeedVec(const std::vector<PvInstance>& pv_vec);
virtual void PutToFeedVec(const std::vector<Record*>& ins_vec);
virtual int GetCurrentPhase();
virtual void GetRankOffset(const std::vector<PvInstance>& pv_vec,
int ins_number);
std::string rank_offset_name_;
int pv_batch_size_;
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
@@ -30,4 +30,6 @@ message DataFeedDesc {
optional MultiSlotDesc multi_slot_desc = 3;
optional string pipe_command = 4;
optional int32 thread_num = 5;
optional string rank_offset = 6;
optional int32 pv_batch_size = 7 [ default = 32 ];
}
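Both new fields surface through the Python Dataset API added later in this diff; a minimal configuration sketch, assuming a BoxPS-enabled build:

```python
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
dataset.set_feed_type("PaddleBoxDataFeed")
dataset.set_parse_logkey(True)
dataset.set_enable_pv_merge(True)
dataset.set_rank_offset("rank_offset")  # fills DataFeedDesc.rank_offset
dataset.set_pv_batch_size(32)           # fills DataFeedDesc.pv_batch_size
```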
@@ -64,6 +64,7 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
REGISTER_DATAFEED_CLASS(PaddleBoxDataFeed);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed);
#endif
@@ -46,9 +46,12 @@ DatasetImpl<T>::DatasetImpl() {
fleet_send_batch_size_ = 1024;
fleet_send_sleep_seconds_ = 0;
merge_by_insid_ = false;
merge_by_sid_ = true;
enable_pv_merge_ = false;
merge_size_ = 2;
parse_ins_id_ = false;
parse_content_ = false;
parse_logkey_ = false;
preload_thread_num_ = 0;
global_index_ = 0;
}
@@ -126,6 +129,11 @@ void DatasetImpl<T>::SetParseContent(bool parse_content) {
parse_content_ = parse_content;
}
template <typename T>
void DatasetImpl<T>::SetParseLogKey(bool parse_logkey) {
parse_logkey_ = parse_logkey;
}
template <typename T>
void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
merge_by_insid_ = true;
@@ -133,6 +141,16 @@ void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
merge_size_ = merge_size;
}
template <typename T>
void DatasetImpl<T>::SetMergeBySid(bool is_merge) {
merge_by_sid_ = is_merge;
}
template <typename T>
void DatasetImpl<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge;
}
template <typename T>
void DatasetImpl<T>::SetGenerateUniqueFeasign(bool gen_uni_feasigns) {
gen_uni_feasigns_ = gen_uni_feasigns;
@@ -174,6 +192,21 @@ void DatasetImpl<T>::CreateChannel() {
multi_consume_channel_.push_back(paddle::framework::MakeChannel<T>());
}
}
if (input_pv_channel_ == nullptr) {
input_pv_channel_ = paddle::framework::MakeChannel<PvInstance>();
}
if (multi_pv_output_.size() == 0) {
multi_pv_output_.reserve(channel_num_);
for (int i = 0; i < channel_num_; ++i) {
multi_pv_output_.push_back(paddle::framework::MakeChannel<PvInstance>());
}
}
if (multi_pv_consume_.size() == 0) {
multi_pv_consume_.reserve(channel_num_);
for (int i = 0; i < channel_num_; ++i) {
multi_pv_consume_.push_back(paddle::framework::MakeChannel<PvInstance>());
}
}
}
// if sent message between workers, should first call this function
@@ -206,6 +239,7 @@ void DatasetImpl<T>::LoadIntoMemory() {
input_channel_->Close();
int64_t in_chan_size = input_channel_->Size();
input_channel_->SetBlockSize(in_chan_size / thread_num_ + 1);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end"
<< ", memory data size=" << input_channel_->Size()
@@ -270,6 +304,27 @@ void DatasetImpl<T>::ReleaseMemory() {
multi_consume_channel_[i] = nullptr;
}
std::vector<paddle::framework::Channel<T>>().swap(multi_consume_channel_);
if (input_pv_channel_) {
input_pv_channel_->Clear();
input_pv_channel_ = nullptr;
}
for (size_t i = 0; i < multi_pv_output_.size(); ++i) {
if (!multi_pv_output_[i]) {
continue;
}
multi_pv_output_[i]->Clear();
multi_pv_output_[i] = nullptr;
}
std::vector<paddle::framework::Channel<PvInstance>>().swap(multi_pv_output_);
for (size_t i = 0; i < multi_pv_consume_.size(); ++i) {
if (!multi_pv_consume_[i]) {
continue;
}
multi_pv_consume_[i]->Clear();
multi_pv_consume_[i] = nullptr;
}
std::vector<paddle::framework::Channel<PvInstance>>().swap(multi_pv_consume_);
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}
@@ -412,6 +467,11 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
channel_num_ = channel_num;
std::vector<paddle::framework::Channel<T>>* origin_channels = nullptr;
std::vector<paddle::framework::Channel<T>>* other_channels = nullptr;
std::vector<paddle::framework::Channel<PvInstance>>* origin_pv_channels =
nullptr;
std::vector<paddle::framework::Channel<PvInstance>>* other_pv_channels =
nullptr;
// find out which channel (output or consume) has data
int cur_channel = 0;
uint64_t output_channels_data_size = 0;
@@ -431,17 +491,26 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
if (cur_channel == 0) {
origin_channels = &multi_output_channel_;
other_channels = &multi_consume_channel_;
origin_pv_channels = &multi_pv_output_;
other_pv_channels = &multi_pv_consume_;
} else {
origin_channels = &multi_consume_channel_;
other_channels = &multi_output_channel_;
origin_pv_channels = &multi_pv_consume_;
other_pv_channels = &multi_pv_output_;
}
CHECK(origin_channels != nullptr); // NOLINT
CHECK(other_channels != nullptr); // NOLINT
CHECK(origin_pv_channels != nullptr); // NOLINT
CHECK(other_pv_channels != nullptr); // NOLINT
paddle::framework::Channel<T> total_data_channel =
paddle::framework::MakeChannel<T>();
std::vector<paddle::framework::Channel<T>> new_channels;
std::vector<paddle::framework::Channel<T>> new_other_channels;
std::vector<paddle::framework::Channel<PvInstance>> new_pv_channels;
std::vector<paddle::framework::Channel<PvInstance>> new_other_pv_channels;
std::vector<T> local_vec;
for (size_t i = 0; i < origin_channels->size(); ++i) {
local_vec.clear();
@@ -458,6 +527,12 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
input_channel_->SetBlockSize(input_channel_->Size() / channel_num +
(discard_remaining_ins ? 0 : 1));
}
if (static_cast<int>(input_pv_channel_->Size()) >= channel_num) {
input_pv_channel_->SetBlockSize(input_pv_channel_->Size() / channel_num +
(discard_remaining_ins ? 0 : 1));
VLOG(3) << "now input_pv_channle block size is "
<< input_pv_channel_->BlockSize();
}
for (int i = 0; i < channel_num; ++i) {
local_vec.clear();
@@ -465,6 +540,9 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
new_other_channels.push_back(paddle::framework::MakeChannel<T>());
new_channels.push_back(paddle::framework::MakeChannel<T>());
new_channels[i]->Write(std::move(local_vec));
new_other_pv_channels.push_back(
paddle::framework::MakeChannel<PvInstance>());
new_pv_channels.push_back(paddle::framework::MakeChannel<PvInstance>());
}
total_data_channel->Clear();
@@ -473,10 +551,22 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
*origin_channels = new_channels;
*other_channels = new_other_channels;
origin_pv_channels->clear();
other_pv_channels->clear();
*origin_pv_channels = new_pv_channels;
*other_pv_channels = new_other_pv_channels;
new_channels.clear();
new_other_channels.clear();
std::vector<paddle::framework::Channel<T>>().swap(new_channels);
std::vector<paddle::framework::Channel<T>>().swap(new_other_channels);
new_pv_channels.clear();
new_other_pv_channels.clear();
std::vector<paddle::framework::Channel<PvInstance>>().swap(new_pv_channels);
std::vector<paddle::framework::Channel<PvInstance>>().swap(
new_other_pv_channels);
local_vec.clear();
std::vector<T>().swap(local_vec);
VLOG(3) << "adjust channel num done";
@@ -528,17 +618,30 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseContent(parse_content_);
readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_);
// Notice: it is only valid for the unit test test_paddlebox_datafeed.
// In fact, it does not affect the train process when paddle is
// compiled with Box_PS.
readers_[i]->SetCurrentPhase(current_phase_);
if (input_channel_ != nullptr) {
readers_[i]->SetInputChannel(input_channel_.get());
}
if (input_pv_channel_ != nullptr) {
readers_[i]->SetInputPvChannel(input_pv_channel_.get());
}
if (cur_channel_ == 0 &&
static_cast<size_t>(channel_idx) < multi_output_channel_.size()) {
readers_[i]->SetOutputChannel(multi_output_channel_[channel_idx].get());
readers_[i]->SetConsumeChannel(multi_consume_channel_[channel_idx].get());
readers_[i]->SetOutputPvChannel(multi_pv_output_[channel_idx].get());
readers_[i]->SetConsumePvChannel(multi_pv_consume_[channel_idx].get());
} else if (static_cast<size_t>(channel_idx) <
multi_output_channel_.size()) {
readers_[i]->SetOutputChannel(multi_consume_channel_[channel_idx].get());
readers_[i]->SetConsumeChannel(multi_output_channel_[channel_idx].get());
readers_[i]->SetOutputPvChannel(multi_pv_consume_[channel_idx].get());
readers_[i]->SetConsumePvChannel(multi_pv_output_[channel_idx].get());
}
++channel_idx;
if (channel_idx >= channel_num_) {
@@ -583,9 +686,13 @@ void DatasetImpl<T>::CreatePreLoadReaders() {
preload_readers_[i]->SetFileList(filelist_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetParseLogKey(parse_logkey_);
preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_);
preload_readers_[i]->SetInputChannel(input_channel_.get());
preload_readers_[i]->SetOutputChannel(nullptr);
preload_readers_[i]->SetConsumeChannel(nullptr);
preload_readers_[i]->SetOutputPvChannel(nullptr);
preload_readers_[i]->SetConsumePvChannel(nullptr);
}
VLOG(3) << "End CreatePreLoadReaders";
}
@@ -605,6 +712,16 @@ int64_t DatasetImpl<T>::GetMemoryDataSize() {
return input_channel_->Size();
}
template <typename T>
int64_t DatasetImpl<T>::GetPvDataSize() {
if (enable_pv_merge_) {
return input_pv_channel_->Size();
} else {
VLOG(0) << "It does not merge pv..";
return 0;
}
}
template <typename T>
int64_t DatasetImpl<T>::GetShuffleDataSize() {
int64_t sum = 0;
@@ -657,6 +774,92 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
// explicit instantiation
template class DatasetImpl<Record>;
void MultiSlotDataset::PostprocessInstance() {
// divide pv instance, and merge to input_channel_
if (enable_pv_merge_) {
input_channel_->Open();
input_channel_->Write(std::move(input_records_));
for (size_t i = 0; i < multi_pv_consume_.size(); ++i) {
multi_pv_consume_[i]->Clear();
}
input_channel_->Close();
input_records_.clear();
input_records_.shrink_to_fit();
} else {
input_channel_->Open();
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
std::vector<Record> ins_data;
multi_consume_channel_[i]->Close();
multi_consume_channel_[i]->ReadAll(ins_data);
input_channel_->Write(std::move(ins_data));
ins_data.clear();
ins_data.shrink_to_fit();
multi_consume_channel_[i]->Clear();
}
input_channel_->Close();
}
this->LocalShuffle();
}
void MultiSlotDataset::SetCurrentPhase(int current_phase) {
current_phase_ = current_phase;
}
void MultiSlotDataset::PreprocessInstance() {
if (!input_channel_ || input_channel_->Size() == 0) {
return;
}
if (!enable_pv_merge_) { // means to use Record
this->LocalShuffle();
} else { // means to use Pv
auto fleet_ptr = FleetWrapper::GetInstance();
input_channel_->Close();
std::vector<PvInstance> pv_data;
input_channel_->ReadAll(input_records_);
int all_records_num = input_records_.size();
std::vector<Record*> all_records;
all_records.reserve(all_records_num);
for (int index = 0; index < all_records_num; ++index) {
all_records.push_back(&input_records_[index]);
}
std::sort(all_records.data(), all_records.data() + all_records_num,
[](const Record* lhs, const Record* rhs) {
return lhs->search_id < rhs->search_id;
});
if (merge_by_sid_) {
uint64_t last_search_id = 0;
for (int i = 0; i < all_records_num; ++i) {
Record* ins = all_records[i];
if (i == 0 || last_search_id != ins->search_id) {
PvInstance pv_instance = make_pv_instance();
pv_instance->merge_instance(ins);
pv_data.push_back(pv_instance);
last_search_id = ins->search_id;
continue;
}
pv_data.back()->merge_instance(ins);
}
} else {
for (int i = 0; i < all_records_num; ++i) {
Record* ins = all_records[i];
PvInstance pv_instance = make_pv_instance();
pv_instance->merge_instance(ins);
pv_data.push_back(pv_instance);
}
}
std::shuffle(pv_data.begin(), pv_data.end(),
fleet_ptr->LocalRandomEngine());
input_pv_channel_->Open();
input_pv_channel_->Write(std::move(pv_data));
pv_data.clear();
pv_data.shrink_to_fit();
input_pv_channel_->Close();
}
}
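PreprocessInstance therefore boils down to: sort all Records by search_id, then, when merge_by_sid_ is set, collapse consecutive records sharing a search_id into one pv. A compact Python equivalent of the grouping step (toy records):

```python
from itertools import groupby

records = [{"search_id": 9}, {"search_id": 7}, {"search_id": 7}]
records.sort(key=lambda r: r["search_id"])
# merge_by_sid: consecutive records with equal search_id form one pv
pv_data = [list(g) for _, g in groupby(records, key=lambda r: r["search_id"])]
print([len(pv) for pv in pv_data])  # [2, 1]
```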
void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
@@ -736,6 +939,7 @@ void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim,
consume_task_pool_.clear();
fleet_ptr_->PullSparseToLocal(table_id, feadim);
}
void MultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId begin";
if (!merge_by_insid_) {
@@ -65,6 +65,9 @@ class Dataset {
// set parse ins id
virtual void SetParseInsId(bool parse_ins_id) = 0;
virtual void SetParseContent(bool parse_content) = 0;
virtual void SetParseLogKey(bool parse_logkey) = 0;
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
@@ -115,10 +118,18 @@ class Dataset {
virtual void DestroyReaders() = 0;
// get memory data size
virtual int64_t GetMemoryDataSize() = 0;
// get memory data size in input_pv_channel_
virtual int64_t GetPvDataSize() = 0;
// get shuffle data size
virtual int64_t GetShuffleDataSize() = 0;
// merge by ins id
virtual void MergeByInsId() = 0;
// merge pv instance
virtual void PreprocessInstance() = 0;
// divide pv instance
virtual void PostprocessInstance() = 0;
// only for unit tests
virtual void SetCurrentPhase(int current_phase) = 0;
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
@@ -161,6 +172,10 @@ class DatasetImpl : public Dataset {
virtual void SetChannelNum(int channel_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetMergeBySid(bool is_merge);
virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
@@ -192,8 +207,12 @@ class DatasetImpl : public Dataset {
virtual void CreateReaders();
virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize();
virtual int64_t GetPvDataSize();
virtual int64_t GetShuffleDataSize();
virtual void MergeByInsId() {}
virtual void PreprocessInstance() {}
virtual void PostprocessInstance() {}
virtual void SetCurrentPhase(int current_phase) {}
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
@@ -213,6 +232,10 @@ class DatasetImpl : public Dataset {
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
paddle::framework::Channel<T> input_channel_;
paddle::framework::Channel<PvInstance> input_pv_channel_;
std::vector<paddle::framework::Channel<PvInstance>> multi_pv_output_;
std::vector<paddle::framework::Channel<PvInstance>> multi_pv_consume_;
int channel_num_;
std::vector<paddle::framework::Channel<T>> multi_output_channel_;
std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
@@ -238,6 +261,10 @@ class DatasetImpl : public Dataset {
bool merge_by_insid_;
bool parse_ins_id_;
bool parse_content_;
bool parse_logkey_;
bool merge_by_sid_;
bool enable_pv_merge_; // True means to merge pv
int current_phase_; // 1 join, 0 update
size_t merge_size_;
bool slots_shuffle_fea_eval_ = false;
bool gen_uni_feasigns_ = false;
@@ -252,6 +279,9 @@ class MultiSlotDataset : public DatasetImpl<Record> {
public:
MultiSlotDataset() {}
virtual void MergeByInsId();
virtual void PreprocessInstance();
virtual void PostprocessInstance();
virtual void SetCurrentPhase(int current_phase);
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num, int shard_num);
@@ -266,6 +296,9 @@ class MultiSlotDataset : public DatasetImpl<Record> {
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual ~MultiSlotDataset() {}
protected:
std::vector<Record> input_records_; // the real data
};
} // end namespace framework
@@ -239,6 +239,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("get_memory_data_size", &framework::Dataset::GetMemoryDataSize,
py::call_guard<py::gil_scoped_release>())
.def("get_pv_data_size", &framework::Dataset::GetPvDataSize,
py::call_guard<py::gil_scoped_release>())
.def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize, .def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("set_queue_num", &framework::Dataset::SetChannelNum, .def("set_queue_num", &framework::Dataset::SetChannelNum,
...@@ -247,6 +249,19 @@ void BindDataset(py::module *m) { ...@@ -247,6 +249,19 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("set_parse_content", &framework::Dataset::SetParseContent, .def("set_parse_content", &framework::Dataset::SetParseContent,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("set_parse_logkey", &framework::Dataset::SetParseLogKey,
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_sid", &framework::Dataset::SetMergeBySid,
py::call_guard<py::gil_scoped_release>())
.def("preprocess_instance", &framework::Dataset::PreprocessInstance,
py::call_guard<py::gil_scoped_release>())
.def("postprocess_instance", &framework::Dataset::PostprocessInstance,
py::call_guard<py::gil_scoped_release>())
.def("set_current_phase", &framework::Dataset::SetCurrentPhase,
py::call_guard<py::gil_scoped_release>())
.def("set_enable_pv_merge", &framework::Dataset::SetEnablePvMerge,
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_lineid", &framework::Dataset::SetMergeByInsId, .def("set_merge_by_lineid", &framework::Dataset::SetMergeByInsId,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("merge_by_lineid", &framework::Dataset::MergeByInsId, .def("merge_by_lineid", &framework::Dataset::MergeByInsId,
@@ -92,6 +92,23 @@ class DatasetBase(object):
"""
self.proto_desc.pipe_command = pipe_command
def set_rank_offset(self, rank_offset):
"""
Set rank_offset for merge_pv. It sets the name of the variable that carries the Pv information.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_rank_offset("rank_offset")
Args:
rank_offset(str): rank_offset's name
"""
self.proto_desc.rank_offset = rank_offset
def set_fea_eval(self, record_candidate_size, fea_eval=True):
"""
set fea eval mode for slots shuffle to debug the importance level of
@@ -154,6 +171,22 @@ class DatasetBase(object):
"""
self.proto_desc.batch_size = batch_size
def set_pv_batch_size(self, pv_batch_size):
"""
Set pv batch size. It only takes effect when enable_pv_merge is set.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_pv_batch_size(128)
Args:
pv_batch_size(int): pv batch size
"""
self.proto_desc.pv_batch_size = pv_batch_size
def set_thread(self, thread_num):
"""
Set thread num, it is the num of readers.
@@ -308,9 +341,18 @@ class InMemoryDataset(DatasetBase):
self.queue_num = None
self.parse_ins_id = False
self.parse_content = False
self.parse_logkey = False
self.merge_by_sid = True
self.enable_pv_merge = False
self.merge_by_lineid = False
self.fleet_send_sleep_seconds = None
def set_feed_type(self, data_feed_type):
"""
Set the DataFeed type used in data_feed_desc, e.g. "PaddleBoxDataFeed".
"""
self.proto_desc.name = data_feed_type
def _prepare_to_run(self):
"""
Set data_feed_desc before load or shuffle,
@@ -324,6 +366,9 @@ class InMemoryDataset(DatasetBase):
self.dataset.set_queue_num(self.queue_num)
self.dataset.set_parse_ins_id(self.parse_ins_id)
self.dataset.set_parse_content(self.parse_content)
self.dataset.set_parse_logkey(self.parse_logkey)
self.dataset.set_merge_by_sid(self.merge_by_sid)
self.dataset.set_enable_pv_merge(self.enable_pv_merge)
self.dataset.set_data_feed_desc(self.desc())
self.dataset.create_channel()
self.dataset.create_readers()
@@ -390,6 +435,112 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_content = parse_content
def set_parse_logkey(self, parse_logkey):
"""
Set if Dataset needs to parse the logkey
Args:
parse_logkey(bool): if parse logkey or not
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_parse_logkey(True)
"""
self.parse_logkey = parse_logkey
def set_merge_by_sid(self, merge_by_sid):
"""
Set if Dataset needs to merge instances by sid. If not, one ins means one Pv.
Args:
merge_by_sid(bool): if merge sid or not
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_sid(True)
"""
self.merge_by_sid = merge_by_sid
def set_enable_pv_merge(self, enable_pv_merge):
"""
Set if Dataset needs to merge pv.
Args:
enable_pv_merge(bool): if enable_pv_merge or not
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_enable_pv_merge(True)
"""
self.enable_pv_merge = enable_pv_merge
def preprocess_instance(self):
"""
Merge pv instances and move them from input_channel to input_pv_channel.
It only takes effect when enable_pv_merge_ is True.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.preprocess_instance()
"""
self.dataset.preprocess_instance()
def set_current_phase(self, current_phase):
"""
Set current phase in train. It is useful for unit tests.
current_phase : 1 for join, 0 for update.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.set_current_phase(1)
"""
self.dataset.set_current_phase(current_phase)
def postprocess_instance(self):
"""
Divide pv instances and move their records back to input_channel.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.preprocess_instance()
exe.train_from_dataset(dataset)
dataset.postprocess_instance()
"""
self.dataset.postprocess_instance()
def set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
"""
Set fleet send batch size, default is 1024
@@ -594,6 +745,30 @@ class InMemoryDataset(DatasetBase):
"""
self.dataset.release_memory()
def get_pv_data_size(self):
"""
Get memory data size of Pv. Users can call this function to know the
number of Pvs in all workers after loading into memory.
Note:
This function may cause bad performance, because it has a barrier
Returns:
The size of memory pv data.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
print(dataset.get_pv_data_size())
"""
return self.dataset.get_pv_data_size()
def get_memory_data_size(self, fleet=None):
"""
Get memory data size, user can call this function to know the num
@@ -808,6 +983,7 @@ class BoxPSDataset(InMemoryDataset):
""" """
super(BoxPSDataset, self).__init__() super(BoxPSDataset, self).__init__()
self.boxps = core.BoxPS(self.dataset) self.boxps = core.BoxPS(self.dataset)
self.proto_desc.name = "PaddleBoxDataFeed"
def set_date(self, date):
"""
@@ -895,3 +1071,6 @@ class BoxPSDataset(InMemoryDataset):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(thread_num, True)
self.dataset.dynamic_adjust_readers_num(thread_num)
def _dynamic_adjust_after_train(self):
pass
@@ -44,6 +44,7 @@ endif()
if(WIN32)
LIST(REMOVE_ITEM TEST_OPS test_boxps)
LIST(REMOVE_ITEM TEST_OPS test_paddlebox_datafeed)
LIST(REMOVE_ITEM TEST_OPS test_trainer_desc)
LIST(REMOVE_ITEM TEST_OPS test_multiprocess_reader_exception)
LIST(REMOVE_ITEM TEST_OPS test_avoid_twice_initialization)
@@ -59,6 +60,7 @@ endif()
if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_pipeline)
LIST(REMOVE_ITEM TEST_OPS test_boxps)
LIST(REMOVE_ITEM TEST_OPS test_paddlebox_datafeed)
endif()
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import paddle.fluid.core as core
import os
import unittest
import paddle.fluid.layers as layers
from paddle.fluid.layers.nn import _pull_box_sparse
class TestDataFeed(unittest.TestCase):
""" TestBaseCase(Merge PV) """
def setUp(self):
self.batch_size = 10
self.pv_batch_size = 10
self.enable_pv_merge = True
self.merge_by_sid = True
def set_data_config(self):
self.dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
self.dataset.set_feed_type("PaddleBoxDataFeed")
self.dataset.set_parse_logkey(True)
self.dataset.set_thread(1)
self.dataset.set_enable_pv_merge(self.enable_pv_merge)
self.dataset.set_batch_size(self.batch_size)
if self.enable_pv_merge:
self.dataset.set_merge_by_sid(self.merge_by_sid)
self.dataset.set_rank_offset("rank_offset")
self.dataset.set_pv_batch_size(self.pv_batch_size)
def test_pboxdatafeed_gpu(self):
self.run_dataset(False)
def test_pboxdatafeed_cpu(self):
self.run_dataset(True)
def run_dataset(self, is_cpu):
x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0)
rank_offset = fluid.layers.data(
name="rank_offset",
shape=[-1, 7],
dtype="int32",
lod_level=0,
append_batch_size=False)
emb_x, emb_y = _pull_box_sparse([x, y], size=2)
emb_xp = _pull_box_sparse(x, size=2)
concat = layers.concat([emb_x, emb_y], axis=1)
fc = layers.fc(input=concat,
name="fc",
size=1,
num_flatten_dims=1,
bias_attr=False)
loss = layers.reduce_mean(fc)
place = fluid.CPUPlace() if is_cpu or not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
with open("test_run_with_dump_a.txt", "w") as f:
data = "1 1702f830eee19501ad7429505f714c1d 1 1 1 9\n"
data += "1 1702f830eee19502ad7429505f714c1d 1 2 1 8\n"
data += "1 1702f830eee19503ad7429505f714c1d 1 3 1 7\n"
data += "1 1702f830eee0de01ad7429505f714c2d 1 4 1 6\n"
data += "1 1702f830eee0df01ad7429505f714c3d 1 5 1 5\n"
data += "1 1702f830eee0df02ad7429505f714c3d 1 6 1 4\n"
f.write(data)
with open("test_run_with_dump_b.txt", "w") as f:
data = "1 1702f830fff22201ad7429505f715c1d 1 1 1 1\n"
data += "1 1702f830fff22202ad7429505f715c1d 1 2 1 2\n"
data += "1 1702f830fff22203ad7429505f715c1d 1 3 1 3\n"
data += "1 1702f830fff22101ad7429505f714ccd 1 4 1 4\n"
data += "1 1702f830fff22102ad7429505f714ccd 1 5 1 5\n"
data += "1 1702f830fff22103ad7429505f714ccd 1 6 1 6\n"
data += "1 1702f830fff22104ad7429505f714ccd 1 6 1 7\n"
f.write(data)
self.set_data_config()
self.dataset.set_use_var([x, y])
self.dataset.set_filelist(
["test_run_with_dump_a.txt", "test_run_with_dump_b.txt"])
optimizer = fluid.optimizer.SGD(learning_rate=0.5)
optimizer = fluid.optimizer.PipelineOptimizer(
optimizer,
cut_list=[],
place_list=[place],
concurrency_list=[1],
queue_size=1,
sync_steps=-1)
optimizer.minimize(loss)
exe.run(fluid.default_startup_program())
self.dataset.set_current_phase(1)
self.dataset.load_into_memory()
self.dataset.preprocess_instance()
self.dataset.begin_pass()
pv_num = self.dataset.get_pv_data_size()
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=self.dataset,
print_period=1)
self.dataset.set_current_phase(0)
self.dataset.postprocess_instance()
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=self.dataset,
print_period=1)
self.dataset.end_pass(True)
os.remove("test_run_with_dump_a.txt")
os.remove("test_run_with_dump_b.txt")
class TestDataFeed2(TestDataFeed):
""" TestBaseCase(Merge PV not merge by sid) """
def setUp(self):
self.batch_size = 10
self.pv_batch_size = 10
self.enable_pv_merge = True
self.merge_by_sid = False
class TestDataFeed3(TestDataFeed):
""" TestBaseCase(Not Merge PV) """
def setUp(self):
self.batch_size = 10
self.pv_batch_size = 10
self.enable_pv_merge = False
if __name__ == '__main__':
unittest.main()
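For reference, each line written by the test follows the parse_logkey layout consumed by ParseOneInstanceFromPipe above: a logkey count (always 1), the 32-character hex logkey, then a "count value" group per used slot (x and y here). A small decoding sketch (variable names are ours):

```python
line = "1 1702f830eee19501ad7429505f714c1d 1 1 1 9"
tokens = line.split()
logkey_num, log_key = int(tokens[0]), tokens[1]  # logkey count is always 1
slot_x = tokens[3:3 + int(tokens[2])]            # -> ['1']
slot_y = tokens[5:5 + int(tokens[4])]            # -> ['9'], assumes one value in slot x
print(logkey_num, log_key[16:32], slot_x, slot_y)
```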