Unverified  Commit 5223e2bb authored by ShenLiang, committed by GitHub

Add a new DataFeed named PaddleBoxDataFeed (#23321)

* add paddleboxdatafeed
* add ifdef linux and boxps
* add unit test for datafeed
* fix unit test test_paddlebox_datafeed
* fix unit test
* rename function
Parent 75bd3507
......@@ -33,6 +33,7 @@ limitations under the License. */
#include "io/shell.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/platform/timer.h"
......@@ -232,6 +233,9 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_num_ = 1;
this->parse_ins_id_ = false;
this->parse_content_ = false;
this->parse_logkey_ = false;
this->enable_pv_merge_ = false;
this->current_phase_ = 1;  // 1: join, 0: update
this->input_channel_ = nullptr;
this->output_channel_ = nullptr;
this->consume_channel_ = nullptr;
......@@ -305,6 +309,24 @@ void InMemoryDataFeed<T>::SetConsumeChannel(void* channel) {
consume_channel_ = static_cast<paddle::framework::ChannelObject<T>*>(channel);
}
template <typename T>
void InMemoryDataFeed<T>::SetInputPvChannel(void* channel) {
input_pv_channel_ =
static_cast<paddle::framework::ChannelObject<PvInstance>*>(channel);
}
template <typename T>
void InMemoryDataFeed<T>::SetOutputPvChannel(void* channel) {
output_pv_channel_ =
static_cast<paddle::framework::ChannelObject<PvInstance>*>(channel);
}
template <typename T>
void InMemoryDataFeed<T>::SetConsumePvChannel(void* channel) {
consume_pv_channel_ =
static_cast<paddle::framework::ChannelObject<PvInstance>*>(channel);
}
template <typename T>
void InMemoryDataFeed<T>::SetThreadId(int thread_id) {
thread_id_ = thread_id;
......@@ -320,6 +342,21 @@ void InMemoryDataFeed<T>::SetParseContent(bool parse_content) {
parse_content_ = parse_content;
}
template <typename T>
void InMemoryDataFeed<T>::SetParseLogKey(bool parse_logkey) {
parse_logkey_ = parse_logkey;
}
template <typename T>
void InMemoryDataFeed<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge;
}
template <typename T>
void InMemoryDataFeed<T>::SetCurrentPhase(int current_phase) {
current_phase_ = current_phase;
}
template <typename T>
void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
......@@ -756,6 +793,20 @@ void MultiSlotInMemoryDataFeed::Init(
finish_init_ = true;
}
void MultiSlotInMemoryDataFeed::GetMsgFromLogKey(const std::string& log_key,
uint64_t* search_id,
uint32_t* cmatch,
uint32_t* rank) {
std::string searchid_str = log_key.substr(16, 16);
*search_id = (uint64_t)strtoull(searchid_str.c_str(), NULL, 16);
std::string cmatch_str = log_key.substr(11, 3);
*cmatch = (uint32_t)strtoul(cmatch_str.c_str(), NULL, 16);
std::string rank_str = log_key.substr(14, 2);
*rank = (uint32_t)strtoul(rank_str.c_str(), NULL, 16);
}
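// A sketch of the logkey layout implied by the substr offsets above (an
// assumption inferred from this function, not an official format spec):
// the logkey is a hex string in which
//   chars [11, 14) encode cmatch     (3 hex digits)
//   chars [14, 16) encode rank       (2 hex digits)
//   chars [16, 32) encode search_id  (16 hex digits)
// e.g. for the logkey "1702f830eee19501ad7429505f714c1d" used in the unit
// test, cmatch = 0x195, rank = 0x01, search_id = 0xad7429505f714c1d.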
bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
#ifdef _LINUX
thread_local string::LineFileReader reader;
......@@ -792,6 +843,26 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
pos += len + 1;
VLOG(3) << "content " << instance->content_;
}
if (parse_logkey_) {
int num = strtol(&str[pos], &endptr, 10);
CHECK(num == 1); // NOLINT
pos = endptr - str + 1;
size_t len = 0;
while (str[pos + len] != ' ') {
++len;
}
// parse_logkey
std::string log_key = std::string(str + pos, len);
uint64_t search_id;
uint32_t cmatch;
uint32_t rank;
GetMsgFromLogKey(log_key, &search_id, &cmatch, &rank);
instance->search_id = search_id;
instance->cmatch = cmatch;
instance->rank = rank;
pos += len + 1;
}
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
......@@ -1186,5 +1257,242 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
}
#endif
bool PaddleBoxDataFeed::Start() {
#ifdef _LINUX
int phase = GetCurrentPhase(); // join: 1, update: 0
this->CheckSetFileList();
if (enable_pv_merge_ && phase == 1) {
// join phase : input_pv_channel to output_pv_channel
if (output_pv_channel_->Size() == 0 && input_pv_channel_->Size() != 0) {
std::vector<PvInstance> data;
input_pv_channel_->Read(data);
output_pv_channel_->Write(std::move(data));
}
} else {
// input_channel to output
if (output_channel_->Size() == 0 && input_channel_->Size() != 0) {
std::vector<Record> data;
input_channel_->Read(data);
output_channel_->Write(std::move(data));
}
}
#endif
this->finish_start_ = true;
return true;
}
int PaddleBoxDataFeed::Next() {
#ifdef _LINUX
int phase = GetCurrentPhase(); // join: 1, update: 0
this->CheckStart();
if (enable_pv_merge_ && phase == 1) {
// join phase : output_pv_channel to consume_pv_channel
CHECK(output_pv_channel_ != nullptr);
CHECK(consume_pv_channel_ != nullptr);
VLOG(3) << "output_pv_channel_ size=" << output_pv_channel_->Size()
<< ", consume_pv_channel_ size=" << consume_pv_channel_->Size()
<< ", thread_id=" << thread_id_;
int index = 0;
PvInstance pv_instance;
std::vector<PvInstance> pv_vec;
pv_vec.reserve(this->pv_batch_size_);
while (index < this->pv_batch_size_) {
if (output_pv_channel_->Size() == 0) {
break;
}
output_pv_channel_->Get(pv_instance);
pv_vec.push_back(pv_instance);
++index;
consume_pv_channel_->Put(std::move(pv_instance));
}
this->batch_size_ = index;
VLOG(3) << "pv_batch_size_=" << this->batch_size_
<< ", thread_id=" << thread_id_;
if (this->batch_size_ != 0) {
PutToFeedVec(pv_vec);
} else {
VLOG(3) << "finish reading, output_pv_channel_ size="
<< output_pv_channel_->Size()
<< ", consume_pv_channel_ size=" << consume_pv_channel_->Size()
<< ", thread_id=" << thread_id_;
}
return this->batch_size_;
} else {
this->batch_size_ = MultiSlotInMemoryDataFeed::Next();
return this->batch_size_;
}
#else
return 0;
#endif
}
void PaddleBoxDataFeed::Init(const DataFeedDesc& data_feed_desc) {
MultiSlotInMemoryDataFeed::Init(data_feed_desc);
rank_offset_name_ = data_feed_desc.rank_offset();
pv_batch_size_ = data_feed_desc.pv_batch_size();
}
void PaddleBoxDataFeed::GetRankOffset(const std::vector<PvInstance>& pv_vec,
int ins_number) {
int index = 0;
int max_rank = 3;  // this value is a fixed setting for now
int row = ins_number;
int col = max_rank * 2 + 1;
int pv_num = pv_vec.size();
std::vector<int> rank_offset_mat(row * col, -1);
rank_offset_mat.shrink_to_fit();
for (int i = 0; i < pv_num; i++) {
auto pv_ins = pv_vec[i];
int ad_num = pv_ins->ads.size();
int index_start = index;
for (int j = 0; j < ad_num; ++j) {
auto ins = pv_ins->ads[j];
int rank = -1;
if ((ins->cmatch == 222 || ins->cmatch == 223) &&
ins->rank <= static_cast<uint32_t>(max_rank) && ins->rank != 0) {
rank = ins->rank;
}
rank_offset_mat[index * col] = rank;
if (rank > 0) {
for (int k = 0; k < ad_num; ++k) {
auto cur_ins = pv_ins->ads[k];
int fast_rank = -1;
if ((cur_ins->cmatch == 222 || cur_ins->cmatch == 223) &&
cur_ins->rank <= static_cast<uint32_t>(max_rank) &&
cur_ins->rank != 0) {
fast_rank = cur_ins->rank;
}
if (fast_rank > 0) {
int m = fast_rank - 1;
rank_offset_mat[index * col + 2 * m + 1] = cur_ins->rank;
rank_offset_mat[index * col + 2 * m + 2] = index_start + k;
}
}
}
index += 1;
}
}
int* rank_offset = rank_offset_mat.data();
int* tensor_ptr = rank_offset_->mutable_data<int>({row, col}, this->place_);
CopyToFeedTensor(tensor_ptr, rank_offset, row * col * sizeof(int));
}
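// A sketch of the rank_offset layout produced above (derived from this
// function; how the consuming op interprets it is an assumption): one row per
// ad instance, col = 2 * max_rank + 1 = 7, all entries initialized to -1.
//   column 0       : this ad's own rank (only when cmatch is 222/223 and
//                    0 < rank <= max_rank, otherwise -1)
//   column 2*r - 1 : rank value r of a sibling ad in the same pv
//   column 2*r     : row index (within this batch) of that sibling ad
// For example, a pv whose two valid ads hold ranks 1 and 2 and occupy rows 4
// and 5 yields row 4 = [1, 1, 4, 2, 5, -1, -1] and row 5 = [2, 1, 4, 2, 5, -1, -1].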
void PaddleBoxDataFeed::AssignFeedVar(const Scope& scope) {
MultiSlotInMemoryDataFeed::AssignFeedVar(scope);
// set rank offset memory
int phase = GetCurrentPhase(); // join: 1, update: 0
if (enable_pv_merge_ && phase == 1) {
rank_offset_ = scope.FindVar(rank_offset_name_)->GetMutable<LoDTensor>();
}
}
void PaddleBoxDataFeed::PutToFeedVec(const std::vector<PvInstance>& pv_vec) {
#ifdef _LINUX
int ins_number = 0;
std::vector<Record*> ins_vec;
for (auto& pv : pv_vec) {
ins_number += pv->ads.size();
for (auto ins : pv->ads) {
ins_vec.push_back(ins);
}
}
GetRankOffset(pv_vec, ins_number);
PutToFeedVec(ins_vec);
#endif
}
int PaddleBoxDataFeed::GetCurrentPhase() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
return box_ptr->PassFlag(); // join: 1, update: 0
#else
LOG(WARNING) << "It should be compiled with BOX_PS...";
return current_phase_;
#endif
}
void PaddleBoxDataFeed::PutToFeedVec(const std::vector<Record*>& ins_vec) {
#ifdef _LINUX
std::vector<std::vector<float>> batch_float_feasigns(use_slots_.size(),
std::vector<float>());
std::vector<std::vector<uint64_t>> batch_uint64_feasigns(
use_slots_.size(), std::vector<uint64_t>());
std::vector<std::vector<size_t>> offset(use_slots_.size(),
std::vector<size_t>{0});
std::vector<bool> visit(use_slots_.size(), false);
ins_content_vec_.clear();
ins_content_vec_.reserve(ins_vec.size());
ins_id_vec_.clear();
ins_id_vec_.reserve(ins_vec.size());
for (size_t i = 0; i < ins_vec.size(); ++i) {
auto r = ins_vec[i];
ins_id_vec_.push_back(r->ins_id_);
ins_content_vec_.push_back(r->content_);
for (auto& item : r->float_feasigns_) {
batch_float_feasigns[item.slot()].push_back(item.sign().float_feasign_);
visit[item.slot()] = true;
}
for (auto& item : r->uint64_feasigns_) {
batch_uint64_feasigns[item.slot()].push_back(item.sign().uint64_feasign_);
visit[item.slot()] = true;
}
for (size_t j = 0; j < use_slots_.size(); ++j) {
const auto& type = all_slots_type_[j];
if (visit[j]) {
visit[j] = false;
} else {
// fill slot value with default value 0
if (type[0] == 'f') { // float
batch_float_feasigns[j].push_back(0.0);
} else if (type[0] == 'u') { // uint64
batch_uint64_feasigns[j].push_back(0);
}
}
// get offset of this ins in this slot
if (type[0] == 'f') { // float
offset[j].push_back(batch_float_feasigns[j].size());
} else if (type[0] == 'u') { // uint64
offset[j].push_back(batch_uint64_feasigns[j].size());
}
}
}
for (size_t i = 0; i < use_slots_.size(); ++i) {
if (feed_vec_[i] == nullptr) {
continue;
}
int total_instance = offset[i].back();
const auto& type = all_slots_type_[i];
if (type[0] == 'f') { // float
float* feasign = batch_float_feasigns[i].data();
float* tensor_ptr =
feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle
uint64_t* feasign = batch_uint64_feasigns[i].data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, this->place_);
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t));
}
auto& slot_offset = offset[i];
LoD data_lod{slot_offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
if (inductive_shape_index_[i] != -1) {
use_slots_shape_[i][inductive_shape_index_[i]] =
total_instance / total_dims_without_inductive_[i];
}
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
#endif
}
} // namespace framework
} // namespace paddle
......@@ -58,6 +58,51 @@ namespace framework {
// while (reader->Next()) {
// // trainer do something
// }
union FeatureKey {
uint64_t uint64_feasign_;
float float_feasign_;
};
struct FeatureItem {
FeatureItem() {}
FeatureItem(FeatureKey sign, uint16_t slot) {
this->sign() = sign;
this->slot() = slot;
}
FeatureKey& sign() { return *(reinterpret_cast<FeatureKey*>(sign_buffer())); }
const FeatureKey& sign() const {
const FeatureKey* ret = reinterpret_cast<FeatureKey*>(sign_buffer());
return *ret;
}
uint16_t& slot() { return slot_; }
const uint16_t& slot() const { return slot_; }
private:
char* sign_buffer() const { return const_cast<char*>(sign_); }
char sign_[sizeof(FeatureKey)];
uint16_t slot_;
};
// sizeof Record is much less than std::vector<MultiSlotType>
struct Record {
std::vector<FeatureItem> uint64_feasigns_;
std::vector<FeatureItem> float_feasigns_;
std::string ins_id_;
std::string content_;
uint64_t search_id;
uint32_t rank;
uint32_t cmatch;
};
struct PvInstanceObject {
std::vector<Record*> ads;
void merge_instance(Record* ins) { ads.push_back(ins); }
};
using PvInstance = PvInstanceObject*;
inline PvInstance make_pv_instance() { return new PvInstanceObject(); }
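// Note: a PvInstance only holds raw pointers to Records; the Records
// themselves are owned elsewhere (e.g. MultiSlotDataset::input_records_,
// see PreprocessInstance in data_set.cc), so a PvInstance does not own
// its ads.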
class DataFeed {
public:
DataFeed() {
......@@ -93,6 +138,13 @@ class DataFeed {
// This function is used for binding feed_vec memory in a given scope
virtual void AssignFeedVar(const Scope& scope);
// This function will do nothing at default
virtual void SetInputPvChannel(void* channel) {}
// This function will do nothing at default
virtual void SetOutputPvChannel(void* channel) {}
// This function will do nothing at default
virtual void SetConsumePvChannel(void* channel) {}
// This function will do nothing at default
virtual void SetInputChannel(void* channel) {}
// This function will do nothing at default
......@@ -106,6 +158,9 @@ class DataFeed {
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {}
virtual void SetCurrentPhase(int current_phase) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
......@@ -163,6 +218,8 @@ class DataFeed {
// The data read by DataFeed will be stored here
std::vector<LoDTensor*> feed_vec_;
LoDTensor* rank_offset_;
// the batch size defined by user
int default_batch_size_;
// current batch size
......@@ -226,6 +283,10 @@ class InMemoryDataFeed : public DataFeed {
virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
virtual void SetInputPvChannel(void* channel);
virtual void SetOutputPvChannel(void* channel);
virtual void SetConsumePvChannel(void* channel);
virtual void SetInputChannel(void* channel);
virtual void SetOutputChannel(void* channel);
virtual void SetConsumeChannel(void* channel);
......@@ -233,6 +294,9 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetCurrentPhase(int current_phase);
virtual void LoadIntoMemory();
protected:
......@@ -244,11 +308,18 @@ class InMemoryDataFeed : public DataFeed {
int thread_num_;
bool parse_ins_id_;
bool parse_content_;
bool parse_logkey_;
bool enable_pv_merge_;
int current_phase_{-1};  // only for unit test
std::ifstream file_;
std::shared_ptr<FILE> fp_;
paddle::framework::ChannelObject<T>* input_channel_;
paddle::framework::ChannelObject<T>* output_channel_;
paddle::framework::ChannelObject<T>* consume_channel_;
paddle::framework::ChannelObject<PvInstance>* input_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* output_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* consume_pv_channel_;
};
// This class defines the data type of an instance (ins_vec) in MultiSlotDataFeed
......@@ -408,39 +479,6 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
return ar;
}
union FeatureKey {
uint64_t uint64_feasign_;
float float_feasign_;
};
struct FeatureItem {
FeatureItem() {}
FeatureItem(FeatureKey sign, uint16_t slot) {
this->sign() = sign;
this->slot() = slot;
}
FeatureKey& sign() { return *(reinterpret_cast<FeatureKey*>(sign_buffer())); }
const FeatureKey& sign() const {
const FeatureKey* ret = reinterpret_cast<FeatureKey*>(sign_buffer());
return *ret;
}
uint16_t& slot() { return slot_; }
const uint16_t& slot() const { return slot_; }
private:
char* sign_buffer() const { return const_cast<char*>(sign_); }
char sign_[sizeof(FeatureKey)];
uint16_t slot_;
};
// sizeof Record is much less than std::vector<MultiSlotType>
struct Record {
std::vector<FeatureItem> uint64_feasigns_;
std::vector<FeatureItem> float_feasigns_;
std::string ins_id_;
std::string content_;
};
struct RecordCandidate {
std::string ins_id_;
std::unordered_multimap<uint16_t, FeatureKey> feas;
......@@ -557,6 +595,27 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
virtual bool ParseOneInstance(Record* instance);
virtual bool ParseOneInstanceFromPipe(Record* instance);
virtual void PutToFeedVec(const std::vector<Record>& ins_vec);
virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id,
uint32_t* cmatch, uint32_t* rank);
};
class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
public:
PaddleBoxDataFeed() {}
virtual ~PaddleBoxDataFeed() {}
protected:
virtual void Init(const DataFeedDesc& data_feed_desc);
virtual bool Start();
virtual int Next();
virtual void AssignFeedVar(const Scope& scope);
virtual void PutToFeedVec(const std::vector<PvInstance>& pv_vec);
virtual void PutToFeedVec(const std::vector<Record*>& ins_vec);
virtual int GetCurrentPhase();
virtual void GetRankOffset(const std::vector<PvInstance>& pv_vec,
int ins_number);
std::string rank_offset_name_;
int pv_batch_size_;
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......
......@@ -30,4 +30,6 @@ message DataFeedDesc {
optional MultiSlotDesc multi_slot_desc = 3;
optional string pipe_command = 4;
optional int32 thread_num = 5;
optional string rank_offset = 6;
optional int32 pv_batch_size = 7 [ default = 32 ];
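// rank_offset names the variable that receives the pv rank-offset matrix
// built by PaddleBoxDataFeed::GetRankOffset (the unit test declares a data
// layer called "rank_offset"); pv_batch_size is the number of pv instances
// fed per batch during the join phase.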
}
......@@ -64,6 +64,7 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
REGISTER_DATAFEED_CLASS(PaddleBoxDataFeed);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed);
#endif
......
......@@ -46,9 +46,12 @@ DatasetImpl<T>::DatasetImpl() {
fleet_send_batch_size_ = 1024;
fleet_send_sleep_seconds_ = 0;
merge_by_insid_ = false;
merge_by_sid_ = true;
enable_pv_merge_ = false;
merge_size_ = 2;
parse_ins_id_ = false;
parse_content_ = false;
parse_logkey_ = false;
preload_thread_num_ = 0;
global_index_ = 0;
}
......@@ -126,6 +129,11 @@ void DatasetImpl<T>::SetParseContent(bool parse_content) {
parse_content_ = parse_content;
}
template <typename T>
void DatasetImpl<T>::SetParseLogKey(bool parse_logkey) {
parse_logkey_ = parse_logkey;
}
template <typename T>
void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
merge_by_insid_ = true;
......@@ -133,6 +141,16 @@ void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
merge_size_ = merge_size;
}
template <typename T>
void DatasetImpl<T>::SetMergeBySid(bool is_merge) {
merge_by_sid_ = is_merge;
}
template <typename T>
void DatasetImpl<T>::SetEnablePvMerge(bool enable_pv_merge) {
enable_pv_merge_ = enable_pv_merge;
}
template <typename T>
void DatasetImpl<T>::SetGenerateUniqueFeasign(bool gen_uni_feasigns) {
gen_uni_feasigns_ = gen_uni_feasigns;
......@@ -174,6 +192,21 @@ void DatasetImpl<T>::CreateChannel() {
multi_consume_channel_.push_back(paddle::framework::MakeChannel<T>());
}
}
if (input_pv_channel_ == nullptr) {
input_pv_channel_ = paddle::framework::MakeChannel<PvInstance>();
}
if (multi_pv_output_.size() == 0) {
multi_pv_output_.reserve(channel_num_);
for (int i = 0; i < channel_num_; ++i) {
multi_pv_output_.push_back(paddle::framework::MakeChannel<PvInstance>());
}
}
if (multi_pv_consume_.size() == 0) {
multi_pv_consume_.reserve(channel_num_);
for (int i = 0; i < channel_num_; ++i) {
multi_pv_consume_.push_back(paddle::framework::MakeChannel<PvInstance>());
}
}
}
// if messages are sent between workers, this function should be called first
......@@ -206,6 +239,7 @@ void DatasetImpl<T>::LoadIntoMemory() {
input_channel_->Close();
int64_t in_chan_size = input_channel_->Size();
input_channel_->SetBlockSize(in_chan_size / thread_num_ + 1);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end"
<< ", memory data size=" << input_channel_->Size()
......@@ -270,6 +304,27 @@ void DatasetImpl<T>::ReleaseMemory() {
multi_consume_channel_[i] = nullptr;
}
std::vector<paddle::framework::Channel<T>>().swap(multi_consume_channel_);
if (input_pv_channel_) {
input_pv_channel_->Clear();
input_pv_channel_ = nullptr;
}
for (size_t i = 0; i < multi_pv_output_.size(); ++i) {
if (!multi_pv_output_[i]) {
continue;
}
multi_pv_output_[i]->Clear();
multi_pv_output_[i] = nullptr;
}
std::vector<paddle::framework::Channel<PvInstance>>().swap(multi_pv_output_);
for (size_t i = 0; i < multi_pv_consume_.size(); ++i) {
if (!multi_pv_consume_[i]) {
continue;
}
multi_pv_consume_[i]->Clear();
multi_pv_consume_[i] = nullptr;
}
std::vector<paddle::framework::Channel<PvInstance>>().swap(multi_pv_consume_);
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}
......@@ -412,6 +467,11 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
channel_num_ = channel_num;
std::vector<paddle::framework::Channel<T>>* origin_channels = nullptr;
std::vector<paddle::framework::Channel<T>>* other_channels = nullptr;
std::vector<paddle::framework::Channel<PvInstance>>* origin_pv_channels =
nullptr;
std::vector<paddle::framework::Channel<PvInstance>>* other_pv_channels =
nullptr;
// find out which channel (output or consume) has data
int cur_channel = 0;
uint64_t output_channels_data_size = 0;
......@@ -431,17 +491,26 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
if (cur_channel == 0) {
origin_channels = &multi_output_channel_;
other_channels = &multi_consume_channel_;
origin_pv_channels = &multi_pv_output_;
other_pv_channels = &multi_pv_consume_;
} else {
origin_channels = &multi_consume_channel_;
other_channels = &multi_output_channel_;
origin_pv_channels = &multi_pv_consume_;
other_pv_channels = &multi_pv_output_;
}
CHECK(origin_channels != nullptr); // NOLINT
CHECK(other_channels != nullptr); // NOLINT
CHECK(origin_channels != nullptr); // NOLINT
CHECK(other_channels != nullptr); // NOLINT
CHECK(origin_pv_channels != nullptr); // NOLINT
CHECK(other_pv_channels != nullptr); // NOLINT
paddle::framework::Channel<T> total_data_channel =
paddle::framework::MakeChannel<T>();
std::vector<paddle::framework::Channel<T>> new_channels;
std::vector<paddle::framework::Channel<T>> new_other_channels;
std::vector<paddle::framework::Channel<PvInstance>> new_pv_channels;
std::vector<paddle::framework::Channel<PvInstance>> new_other_pv_channels;
std::vector<T> local_vec;
for (size_t i = 0; i < origin_channels->size(); ++i) {
local_vec.clear();
......@@ -458,6 +527,12 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
input_channel_->SetBlockSize(input_channel_->Size() / channel_num +
(discard_remaining_ins ? 0 : 1));
}
if (static_cast<int>(input_pv_channel_->Size()) >= channel_num) {
input_pv_channel_->SetBlockSize(input_pv_channel_->Size() / channel_num +
(discard_remaining_ins ? 0 : 1));
VLOG(3) << "now input_pv_channle block size is "
<< input_pv_channel_->BlockSize();
}
for (int i = 0; i < channel_num; ++i) {
local_vec.clear();
......@@ -465,6 +540,9 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
new_other_channels.push_back(paddle::framework::MakeChannel<T>());
new_channels.push_back(paddle::framework::MakeChannel<T>());
new_channels[i]->Write(std::move(local_vec));
new_other_pv_channels.push_back(
paddle::framework::MakeChannel<PvInstance>());
new_pv_channels.push_back(paddle::framework::MakeChannel<PvInstance>());
}
total_data_channel->Clear();
......@@ -473,10 +551,22 @@ void DatasetImpl<T>::DynamicAdjustChannelNum(int channel_num,
*origin_channels = new_channels;
*other_channels = new_other_channels;
origin_pv_channels->clear();
other_pv_channels->clear();
*origin_pv_channels = new_pv_channels;
*other_pv_channels = new_other_pv_channels;
new_channels.clear();
new_other_channels.clear();
std::vector<paddle::framework::Channel<T>>().swap(new_channels);
std::vector<paddle::framework::Channel<T>>().swap(new_other_channels);
new_pv_channels.clear();
new_other_pv_channels.clear();
std::vector<paddle::framework::Channel<PvInstance>>().swap(new_pv_channels);
std::vector<paddle::framework::Channel<PvInstance>>().swap(
new_other_pv_channels);
local_vec.clear();
std::vector<T>().swap(local_vec);
VLOG(3) << "adjust channel num done";
......@@ -528,17 +618,30 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseContent(parse_content_);
readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_);
// Notice: this is only meaningful for the unit test test_paddlebox_datafeed.
// It does not affect the training process when Paddle is compiled with
// BOX_PS.
readers_[i]->SetCurrentPhase(current_phase_);
if (input_channel_ != nullptr) {
readers_[i]->SetInputChannel(input_channel_.get());
}
if (input_pv_channel_ != nullptr) {
readers_[i]->SetInputPvChannel(input_pv_channel_.get());
}
if (cur_channel_ == 0 &&
static_cast<size_t>(channel_idx) < multi_output_channel_.size()) {
readers_[i]->SetOutputChannel(multi_output_channel_[channel_idx].get());
readers_[i]->SetConsumeChannel(multi_consume_channel_[channel_idx].get());
readers_[i]->SetOutputPvChannel(multi_pv_output_[channel_idx].get());
readers_[i]->SetConsumePvChannel(multi_pv_consume_[channel_idx].get());
} else if (static_cast<size_t>(channel_idx) <
multi_output_channel_.size()) {
readers_[i]->SetOutputChannel(multi_consume_channel_[channel_idx].get());
readers_[i]->SetConsumeChannel(multi_output_channel_[channel_idx].get());
readers_[i]->SetOutputPvChannel(multi_pv_consume_[channel_idx].get());
readers_[i]->SetConsumePvChannel(multi_pv_output_[channel_idx].get());
}
++channel_idx;
if (channel_idx >= channel_num_) {
......@@ -583,9 +686,13 @@ void DatasetImpl<T>::CreatePreLoadReaders() {
preload_readers_[i]->SetFileList(filelist_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetParseLogKey(parse_logkey_);
preload_readers_[i]->SetEnablePvMerge(enable_pv_merge_);
preload_readers_[i]->SetInputChannel(input_channel_.get());
preload_readers_[i]->SetOutputChannel(nullptr);
preload_readers_[i]->SetConsumeChannel(nullptr);
preload_readers_[i]->SetOutputPvChannel(nullptr);
preload_readers_[i]->SetConsumePvChannel(nullptr);
}
VLOG(3) << "End CreatePreLoadReaders";
}
......@@ -605,6 +712,16 @@ int64_t DatasetImpl<T>::GetMemoryDataSize() {
return input_channel_->Size();
}
template <typename T>
int64_t DatasetImpl<T>::GetPvDataSize() {
if (enable_pv_merge_) {
return input_pv_channel_->Size();
} else {
VLOG(0) << "It does not merge pv..";
return 0;
}
}
template <typename T>
int64_t DatasetImpl<T>::GetShuffleDataSize() {
int64_t sum = 0;
......@@ -657,6 +774,92 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
// explicit instantiation
template class DatasetImpl<Record>;
void MultiSlotDataset::PostprocessInstance() {
// divide pv instance, and merge to input_channel_
if (enable_pv_merge_) {
input_channel_->Open();
input_channel_->Write(std::move(input_records_));
for (size_t i = 0; i < multi_pv_consume_.size(); ++i) {
multi_pv_consume_[i]->Clear();
}
input_channel_->Close();
input_records_.clear();
input_records_.shrink_to_fit();
} else {
input_channel_->Open();
for (size_t i = 0; i < multi_consume_channel_.size(); ++i) {
std::vector<Record> ins_data;
multi_consume_channel_[i]->Close();
multi_consume_channel_[i]->ReadAll(ins_data);
input_channel_->Write(std::move(ins_data));
ins_data.clear();
ins_data.shrink_to_fit();
multi_consume_channel_[i]->Clear();
}
input_channel_->Close();
}
this->LocalShuffle();
}
void MultiSlotDataset::SetCurrentPhase(int current_phase) {
current_phase_ = current_phase;
}
void MultiSlotDataset::PreprocessInstance() {
if (!input_channel_ || input_channel_->Size() == 0) {
return;
}
if (!enable_pv_merge_) { // means to use Record
this->LocalShuffle();
} else { // means to use Pv
auto fleet_ptr = FleetWrapper::GetInstance();
input_channel_->Close();
std::vector<PvInstance> pv_data;
input_channel_->ReadAll(input_records_);
int all_records_num = input_records_.size();
std::vector<Record*> all_records;
all_records.reserve(all_records_num);
for (int index = 0; index < all_records_num; ++index) {
all_records.push_back(&input_records_[index]);
}
std::sort(all_records.data(), all_records.data() + all_records_num,
[](const Record* lhs, const Record* rhs) {
return lhs->search_id < rhs->search_id;
});
if (merge_by_sid_) {
uint64_t last_search_id = 0;
for (int i = 0; i < all_records_num; ++i) {
Record* ins = all_records[i];
if (i == 0 || last_search_id != ins->search_id) {
PvInstance pv_instance = make_pv_instance();
pv_instance->merge_instance(ins);
pv_data.push_back(pv_instance);
last_search_id = ins->search_id;
continue;
}
pv_data.back()->merge_instance(ins);
}
} else {
for (int i = 0; i < all_records_num; ++i) {
Record* ins = all_records[i];
PvInstance pv_instance = make_pv_instance();
pv_instance->merge_instance(ins);
pv_data.push_back(pv_instance);
}
}
std::shuffle(pv_data.begin(), pv_data.end(),
fleet_ptr->LocalRandomEngine());
input_pv_channel_->Open();
input_pv_channel_->Write(std::move(pv_data));
pv_data.clear();
pv_data.shrink_to_fit();
input_pv_channel_->Close();
}
}
void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
......@@ -736,6 +939,7 @@ void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim,
consume_task_pool_.clear();
fleet_ptr_->PullSparseToLocal(table_id, feadim);
}
void MultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId begin";
if (!merge_by_insid_) {
......
......@@ -65,6 +65,9 @@ class Dataset {
// set parse ins id
virtual void SetParseInsId(bool parse_ins_id) = 0;
virtual void SetParseContent(bool parse_content) = 0;
virtual void SetParseLogKey(bool parse_logkey) = 0;
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
......@@ -115,10 +118,18 @@ class Dataset {
virtual void DestroyReaders() = 0;
// get memory data size
virtual int64_t GetMemoryDataSize() = 0;
// get memory data size in input_pv_channel_
virtual int64_t GetPvDataSize() = 0;
// get shuffle data size
virtual int64_t GetShuffleDataSize() = 0;
// merge by ins id
virtual void MergeByInsId() = 0;
// merge pv instance
virtual void PreprocessInstance() = 0;
// divide pv instance
virtual void PostprocessInstance() = 0;
// only for unit test
virtual void SetCurrentPhase(int current_phase) = 0;
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
......@@ -161,6 +172,10 @@ class DatasetImpl : public Dataset {
virtual void SetChannelNum(int channel_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetMergeBySid(bool is_merge);
virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
......@@ -192,8 +207,12 @@ class DatasetImpl : public Dataset {
virtual void CreateReaders();
virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize();
virtual int64_t GetPvDataSize();
virtual int64_t GetShuffleDataSize();
virtual void MergeByInsId() {}
virtual void PreprocessInstance() {}
virtual void PostprocessInstance() {}
virtual void SetCurrentPhase(int current_phase) {}
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
......@@ -213,6 +232,10 @@ class DatasetImpl : public Dataset {
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
paddle::framework::Channel<T> input_channel_;
paddle::framework::Channel<PvInstance> input_pv_channel_;
std::vector<paddle::framework::Channel<PvInstance>> multi_pv_output_;
std::vector<paddle::framework::Channel<PvInstance>> multi_pv_consume_;
int channel_num_;
std::vector<paddle::framework::Channel<T>> multi_output_channel_;
std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
......@@ -238,6 +261,10 @@ class DatasetImpl : public Dataset {
bool merge_by_insid_;
bool parse_ins_id_;
bool parse_content_;
bool parse_logkey_;
bool merge_by_sid_;
bool enable_pv_merge_; // True means to merge pv
int current_phase_; // 1 join, 0 update
size_t merge_size_;
bool slots_shuffle_fea_eval_ = false;
bool gen_uni_feasigns_ = false;
......@@ -252,6 +279,9 @@ class MultiSlotDataset : public DatasetImpl<Record> {
public:
MultiSlotDataset() {}
virtual void MergeByInsId();
virtual void PreprocessInstance();
virtual void PostprocessInstance();
virtual void SetCurrentPhase(int current_phase);
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num, int shard_num);
......@@ -266,6 +296,9 @@ class MultiSlotDataset : public DatasetImpl<Record> {
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual ~MultiSlotDataset() {}
protected:
std::vector<Record> input_records_; // the real data
};
} // end namespace framework
......
......@@ -239,6 +239,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("get_memory_data_size", &framework::Dataset::GetMemoryDataSize,
py::call_guard<py::gil_scoped_release>())
.def("get_pv_data_size", &framework::Dataset::GetPvDataSize,
py::call_guard<py::gil_scoped_release>())
.def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize,
py::call_guard<py::gil_scoped_release>())
.def("set_queue_num", &framework::Dataset::SetChannelNum,
......@@ -247,6 +249,19 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("set_parse_content", &framework::Dataset::SetParseContent,
py::call_guard<py::gil_scoped_release>())
.def("set_parse_logkey", &framework::Dataset::SetParseLogKey,
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_sid", &framework::Dataset::SetMergeBySid,
py::call_guard<py::gil_scoped_release>())
.def("preprocess_instance", &framework::Dataset::PreprocessInstance,
py::call_guard<py::gil_scoped_release>())
.def("postprocess_instance", &framework::Dataset::PostprocessInstance,
py::call_guard<py::gil_scoped_release>())
.def("set_current_phase", &framework::Dataset::SetCurrentPhase,
py::call_guard<py::gil_scoped_release>())
.def("set_enable_pv_merge", &framework::Dataset::SetEnablePvMerge,
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_lineid", &framework::Dataset::SetMergeByInsId,
py::call_guard<py::gil_scoped_release>())
.def("merge_by_lineid", &framework::Dataset::MergeByInsId,
......
......@@ -92,6 +92,23 @@ class DatasetBase(object):
"""
self.proto_desc.pipe_command = pipe_command
def set_rank_offset(self, rank_offset):
"""
Set rank_offset for merge_pv. It sets the name of the variable that holds the Pv rank-offset information.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_rank_offset("rank_offset")
Args:
rank_offset(str): rank_offset's name
"""
self.proto_desc.rank_offset = rank_offset
def set_fea_eval(self, record_candidate_size, fea_eval=True):
"""
set fea eval mode for slots shuffle to debug the importance level of
......@@ -154,6 +171,22 @@ class DatasetBase(object):
"""
self.proto_desc.batch_size = batch_size
def set_pv_batch_size(self, pv_batch_size):
"""
Set pv batch size. It only takes effect when enable_pv_merge is set.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_pv_batch_size(128)
Args:
pv_batch_size(int): pv batch size
"""
self.proto_desc.pv_batch_size = pv_batch_size
def set_thread(self, thread_num):
"""
Set thread num, it is the num of readers.
......@@ -308,9 +341,18 @@ class InMemoryDataset(DatasetBase):
self.queue_num = None
self.parse_ins_id = False
self.parse_content = False
self.parse_logkey = False
self.merge_by_sid = True
self.enable_pv_merge = False
self.merge_by_lineid = False
self.fleet_send_sleep_seconds = None
def set_feed_type(self, data_feed_type):
"""
Set the data feed type (data_feed_desc.name), e.g. "PaddleBoxDataFeed".
"""
self.proto_desc.name = data_feed_type
def _prepare_to_run(self):
"""
Set data_feed_desc before load or shuffle,
......@@ -324,6 +366,9 @@ class InMemoryDataset(DatasetBase):
self.dataset.set_queue_num(self.queue_num)
self.dataset.set_parse_ins_id(self.parse_ins_id)
self.dataset.set_parse_content(self.parse_content)
self.dataset.set_parse_logkey(self.parse_logkey)
self.dataset.set_merge_by_sid(self.merge_by_sid)
self.dataset.set_enable_pv_merge(self.enable_pv_merge)
self.dataset.set_data_feed_desc(self.desc())
self.dataset.create_channel()
self.dataset.create_readers()
......@@ -390,6 +435,112 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_content = parse_content
def set_parse_logkey(self, parse_logkey):
"""
Set whether the Dataset needs to parse the logkey
Args:
parse_logkey(bool): whether to parse the logkey
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_parse_logkey(True)
"""
self.parse_logkey = parse_logkey
def set_merge_by_sid(self, merge_by_sid):
"""
Set whether the Dataset needs to merge instances by sid. If not, one instance is treated as one Pv.
Args:
merge_by_sid(bool): whether to merge by sid
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_sid(True)
"""
self.merge_by_sid = merge_by_sid
def set_enable_pv_merge(self, enable_pv_merge):
"""
Set whether the Dataset needs to merge instances into Pv.
Args:
enable_pv_merge(bool): whether to enable pv merge
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_enable_pv_merge(True)
"""
self.enable_pv_merge = enable_pv_merge
def preprocess_instance(self):
"""
Merge instances into Pv instances and convey them from input_channel to input_pv_channel.
It only takes effect when enable_pv_merge is True.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.preprocess_instance()
"""
self.dataset.preprocess_instance()
def set_current_phase(self, current_phase):
"""
Set the current phase in training. It is mainly useful for unit tests.
current_phase: 1 for join, 0 for update.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.set_current_phase(1)
"""
self.dataset.set_current_phase(current_phase)
def postprocess_instance(self):
"""
Divide Pv instances back into records and convey them to input_channel.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.preprocess_instance()
exe.train_from_dataset(dataset)
dataset.postprocess_instance()
"""
self.dataset.postprocess_instance()
def set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
"""
Set fleet send batch size, default is 1024
......@@ -594,6 +745,30 @@ class InMemoryDataset(DatasetBase):
"""
self.dataset.release_memory()
def get_pv_data_size(self):
"""
Get the size of Pv data in memory. Users can call this function to know the
number of Pv instances across all workers after load_into_memory.
Note:
This function may cause bad performance, because it has a barrier.
Returns:
The size of memory pv data.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
print(dataset.get_pv_data_size())
"""
return self.dataset.get_pv_data_size()
def get_memory_data_size(self, fleet=None):
"""
Get memory data size, user can call this function to know the num
......@@ -808,6 +983,7 @@ class BoxPSDataset(InMemoryDataset):
"""
super(BoxPSDataset, self).__init__()
self.boxps = core.BoxPS(self.dataset)
self.proto_desc.name = "PaddleBoxDataFeed"
def set_date(self, date):
"""
......@@ -895,3 +1071,6 @@ class BoxPSDataset(InMemoryDataset):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(thread_num, True)
self.dataset.dynamic_adjust_readers_num(thread_num)
def _dynamic_adjust_after_train(self):
pass
......@@ -44,6 +44,7 @@ endif()
if(WIN32)
LIST(REMOVE_ITEM TEST_OPS test_boxps)
LIST(REMOVE_ITEM TEST_OPS test_paddlebox_datafeed)
LIST(REMOVE_ITEM TEST_OPS test_trainer_desc)
LIST(REMOVE_ITEM TEST_OPS test_multiprocess_reader_exception)
LIST(REMOVE_ITEM TEST_OPS test_avoid_twice_initialization)
......@@ -59,6 +60,7 @@ endif()
if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_pipeline)
LIST(REMOVE_ITEM TEST_OPS test_boxps)
LIST(REMOVE_ITEM TEST_OPS test_paddlebox_datafeed)
endif()
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import paddle.fluid.core as core
import os
import unittest
import paddle.fluid.layers as layers
from paddle.fluid.layers.nn import _pull_box_sparse
class TestDataFeed(unittest.TestCase):
""" TestBaseCase(Merge PV) """
def setUp(self):
self.batch_size = 10
self.pv_batch_size = 10
self.enable_pv_merge = True
self.merge_by_sid = True
def set_data_config(self):
self.dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
self.dataset.set_feed_type("PaddleBoxDataFeed")
self.dataset.set_parse_logkey(True)
self.dataset.set_thread(1)
self.dataset.set_enable_pv_merge(self.enable_pv_merge)
self.dataset.set_batch_size(self.batch_size)
if self.enable_pv_merge:
self.dataset.set_merge_by_sid(self.merge_by_sid)
self.dataset.set_rank_offset("rank_offset")
self.dataset.set_pv_batch_size(self.pv_batch_size)
def test_pboxdatafeed_gpu(self):
self.run_dataset(False)
def test_pboxdatafeed_cpu(self):
self.run_dataset(True)
def run_dataset(self, is_cpu):
x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0)
rank_offset = fluid.layers.data(
name="rank_offset",
shape=[-1, 7],
dtype="int32",
lod_level=0,
append_batch_size=False)
emb_x, emb_y = _pull_box_sparse([x, y], size=2)
emb_xp = _pull_box_sparse(x, size=2)
concat = layers.concat([emb_x, emb_y], axis=1)
fc = layers.fc(input=concat,
name="fc",
size=1,
num_flatten_dims=1,
bias_attr=False)
loss = layers.reduce_mean(fc)
place = fluid.CPUPlace() if is_cpu or not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
with open("test_run_with_dump_a.txt", "w") as f:
data = "1 1702f830eee19501ad7429505f714c1d 1 1 1 9\n"
data += "1 1702f830eee19502ad7429505f714c1d 1 2 1 8\n"
data += "1 1702f830eee19503ad7429505f714c1d 1 3 1 7\n"
data += "1 1702f830eee0de01ad7429505f714c2d 1 4 1 6\n"
data += "1 1702f830eee0df01ad7429505f714c3d 1 5 1 5\n"
data += "1 1702f830eee0df02ad7429505f714c3d 1 6 1 4\n"
f.write(data)
with open("test_run_with_dump_b.txt", "w") as f:
data = "1 1702f830fff22201ad7429505f715c1d 1 1 1 1\n"
data += "1 1702f830fff22202ad7429505f715c1d 1 2 1 2\n"
data += "1 1702f830fff22203ad7429505f715c1d 1 3 1 3\n"
data += "1 1702f830fff22101ad7429505f714ccd 1 4 1 4\n"
data += "1 1702f830fff22102ad7429505f714ccd 1 5 1 5\n"
data += "1 1702f830fff22103ad7429505f714ccd 1 6 1 6\n"
data += "1 1702f830fff22104ad7429505f714ccd 1 6 1 7\n"
f.write(data)
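# A sketch of the assumed line format of the data written above (inferred from
# MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe with parse_logkey on):
#   <logkey_num=1> <32-hex-char logkey> <num_x> <x feasigns...> <num_y> <y feasigns...>
# e.g. "1 1702f830fff22201ad7429505f715c1d 1 1 1 1" carries one logkey, one
# feasign for slot x and one feasign for slot y.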
self.set_data_config()
self.dataset.set_use_var([x, y])
self.dataset.set_filelist(
["test_run_with_dump_a.txt", "test_run_with_dump_b.txt"])
optimizer = fluid.optimizer.SGD(learning_rate=0.5)
optimizer = fluid.optimizer.PipelineOptimizer(
optimizer,
cut_list=[],
place_list=[place],
concurrency_list=[1],
queue_size=1,
sync_steps=-1)
optimizer.minimize(loss)
exe.run(fluid.default_startup_program())
self.dataset.set_current_phase(1)
self.dataset.load_into_memory()
self.dataset.preprocess_instance()
self.dataset.begin_pass()
pv_num = self.dataset.get_pv_data_size()
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=self.dataset,
print_period=1)
self.dataset.set_current_phase(0)
self.dataset.postprocess_instance()
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=self.dataset,
print_period=1)
self.dataset.end_pass(True)
os.remove("test_run_with_dump_a.txt")
os.remove("test_run_with_dump_b.txt")
class TestDataFeed2(TestDataFeed):
""" TestBaseCase(Merge PV not merge by sid) """
def setUp(self):
self.batch_size = 10
self.pv_batch_size = 10
self.enable_pv_merge = True
self.merge_by_sid = False
class TestDataFeed3(TestDataFeed):
""" TestBaseCase(Not Merge PV) """
def setUp(self):
self.batch_size = 10
self.pv_batch_size = 10
self.enable_pv_merge = False
if __name__ == '__main__':
unittest.main()