未验证 提交 79bd5f90 编写于 作者: Y yaoxuefeng 提交者: GitHub

add slot record dataset (#36200)

上级 83578cfa
......@@ -157,7 +157,19 @@ class ChannelObject {
p.resize(finished);
return finished;
}
// read once only
size_t ReadOnce(std::vector<T>& p, size_t size) { // NOLINT
if (size == 0) {
return 0;
}
std::unique_lock<std::mutex> lock(mutex_);
p.resize(size);
size_t finished = Read(size, &p[0], lock, true);
p.resize(finished);
Notify();
return finished;
}
size_t ReadAll(std::vector<T>& p) { // NOLINT
p.clear();
size_t finished = 0;
......@@ -241,17 +253,21 @@ class ChannelObject {
return !closed_;
}
size_t Read(size_t n, T* p, std::unique_lock<std::mutex>& lock) { // NOLINT
size_t Read(size_t n, T* p, std::unique_lock<std::mutex>& lock, // NOLINT
bool once = false) { // NOLINT
size_t finished = 0;
CHECK(n <= MaxCapacity() - reading_count_);
reading_count_ += n;
while (finished < n && WaitForRead(lock)) {
size_t m = std::min(n - finished, data_.size());
size_t m = (std::min)(n - finished, data_.size());
for (size_t i = 0; i < m; i++) {
p[finished++] = std::move(data_.front());
data_.pop_front();
}
reading_count_ -= m;
if (once && m > 0) {
break;
}
}
reading_count_ -= n - finished;
return finished;
......
......@@ -36,6 +36,107 @@ DLManager& global_dlmanager_pool() {
return manager;
}
class BufferedLineFileReader {
typedef std::function<bool()> SampleFunc;
static const int MAX_FILE_BUFF_SIZE = 4 * 1024 * 1024;
class FILEReader {
public:
explicit FILEReader(FILE* fp) : fp_(fp) {}
int read(char* buf, int len) { return fread(buf, sizeof(char), len, fp_); }
private:
FILE* fp_;
};
public:
typedef std::function<bool(const std::string&)> LineFunc;
private:
template <typename T>
int read_lines(T* reader, LineFunc func, int skip_lines) {
int lines = 0;
size_t ret = 0;
char* ptr = NULL;
char* eol = NULL;
total_len_ = 0;
error_line_ = 0;
SampleFunc spfunc = get_sample_func();
std::string x;
while (!is_error() && (ret = reader->read(buff_, MAX_FILE_BUFF_SIZE)) > 0) {
total_len_ += ret;
ptr = buff_;
eol = reinterpret_cast<char*>(memchr(ptr, '\n', ret));
while (eol != NULL) {
int size = static_cast<int>((eol - ptr) + 1);
x.append(ptr, size - 1);
++lines;
if (lines > skip_lines && spfunc()) {
if (!func(x)) {
++error_line_;
}
}
x.clear();
ptr += size;
ret -= size;
eol = reinterpret_cast<char*>(memchr(ptr, '\n', ret));
}
if (ret > 0) {
x.append(ptr, ret);
}
}
if (!is_error() && !x.empty()) {
++lines;
if (lines > skip_lines && spfunc()) {
if (!func(x)) {
++error_line_;
}
}
}
return lines;
}
public:
BufferedLineFileReader()
: random_engine_(std::random_device()()),
uniform_distribution_(0.0f, 1.0f) {
total_len_ = 0;
sample_line_ = 0;
buff_ =
reinterpret_cast<char*>(calloc(MAX_FILE_BUFF_SIZE + 1, sizeof(char)));
}
~BufferedLineFileReader() { free(buff_); }
int read_file(FILE* fp, LineFunc func, int skip_lines) {
FILEReader reader(fp);
return read_lines<FILEReader>(&reader, func, skip_lines);
}
uint64_t file_size(void) { return total_len_; }
void set_sample_rate(float r) { sample_rate_ = r; }
size_t get_sample_line() { return sample_line_; }
bool is_error(void) { return (error_line_ > 10); }
private:
SampleFunc get_sample_func() {
if (std::abs(sample_rate_ - 1.0f) < 1e-5f) {
return [this](void) { return true; };
}
return [this](void) {
return (uniform_distribution_(random_engine_) < sample_rate_);
};
}
private:
char* buff_ = nullptr;
uint64_t total_len_ = 0;
std::default_random_engine random_engine_;
std::uniform_real_distribution<float> uniform_distribution_;
float sample_rate_ = 1.0f;
size_t sample_line_ = 0;
size_t error_line_ = 0;
};
void RecordCandidateList::ReSize(size_t length) {
mutex_.lock();
capacity_ = length;
......@@ -301,7 +402,7 @@ int InMemoryDataFeed<T>::Next() {
<< ", thread_id=" << thread_id_;
}
} else {
VLOG(3) << "enable heter NEXT: " << offset_index_
VLOG(3) << "enable heter next: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size();
if (offset_index_ >= batch_offsets_.size()) {
VLOG(3) << "offset_index: " << offset_index_
......@@ -318,14 +419,7 @@ int InMemoryDataFeed<T>::Next() {
VLOG(3) << "finish reading for heterps, batch size zero, thread_id="
<< thread_id_;
}
/*
if (offset_index_ == batch_offsets_.size() - 1) {
std::vector<Record> data;
output_channel_->ReadAll(data);
consume_channel_->Write(std::move(data));
}
*/
VLOG(3) << "#15 enable heter NEXT: " << offset_index_
VLOG(3) << "enable heter next: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size()
<< " baych_size: " << this->batch_size_;
}
......
......@@ -39,8 +39,14 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/string/string_helper.h"
DECLARE_int32(record_pool_max_size);
DECLARE_int32(slotpool_thread_num);
DECLARE_bool(enable_slotpool_wait_release);
DECLARE_bool(enable_slotrecord_reset_shrink);
namespace paddle {
namespace framework {
class DataFeedDesc;
......@@ -69,6 +75,50 @@ namespace framework {
// while (reader->Next()) {
// // trainer do something
// }
template <typename T>
struct SlotValues {
std::vector<T> slot_values;
std::vector<uint32_t> slot_offsets;
void add_values(const T* values, uint32_t num) {
if (slot_offsets.empty()) {
slot_offsets.push_back(0);
}
if (num > 0) {
slot_values.insert(slot_values.end(), values, values + num);
}
slot_offsets.push_back(static_cast<uint32_t>(slot_values.size()));
}
T* get_values(int idx, size_t* size) {
uint32_t& offset = slot_offsets[idx];
(*size) = slot_offsets[idx + 1] - offset;
return &slot_values[offset];
}
void add_slot_feasigns(const std::vector<std::vector<T>>& slot_feasigns,
uint32_t fea_num) {
slot_values.reserve(fea_num);
int slot_num = static_cast<int>(slot_feasigns.size());
slot_offsets.resize(slot_num + 1);
for (int i = 0; i < slot_num; ++i) {
auto& slot_val = slot_feasigns[i];
slot_offsets[i] = static_cast<uint32_t>(slot_values.size());
uint32_t num = static_cast<uint32_t>(slot_val.size());
if (num > 0) {
slot_values.insert(slot_values.end(), slot_val.begin(), slot_val.end());
}
}
slot_offsets[slot_num] = slot_values.size();
}
void clear(bool shrink) {
slot_offsets.clear();
slot_values.clear();
if (shrink) {
slot_values.shrink_to_fit();
slot_offsets.shrink_to_fit();
}
}
};
union FeatureFeasign {
uint64_t uint64_feasign_;
float float_feasign_;
......@@ -97,6 +147,38 @@ struct FeatureItem {
uint16_t slot_;
};
struct AllSlotInfo {
std::string slot;
std::string type;
int used_idx;
int slot_value_idx;
};
struct UsedSlotInfo {
int idx;
int slot_value_idx;
std::string slot;
std::string type;
bool dense;
std::vector<int> local_shape;
int total_dims_without_inductive;
int inductive_shape_index;
};
struct SlotRecordObject {
uint64_t search_id;
uint32_t rank;
uint32_t cmatch;
std::string ins_id_;
SlotValues<uint64_t> slot_uint64_feasigns_;
SlotValues<float> slot_float_feasigns_;
~SlotRecordObject() { clear(true); }
void reset(void) { clear(FLAGS_enable_slotrecord_reset_shrink); }
void clear(bool shrink) {
slot_uint64_feasigns_.clear(shrink);
slot_float_feasigns_.clear(shrink);
}
};
using SlotRecord = SlotRecordObject*;
// sizeof Record is much less than std::vector<MultiSlotType>
struct Record {
std::vector<FeatureItem> uint64_feasigns_;
......@@ -108,6 +190,179 @@ struct Record {
uint32_t cmatch;
};
inline SlotRecord make_slotrecord() {
static const size_t slot_record_byte_size = sizeof(SlotRecordObject);
void* p = malloc(slot_record_byte_size);
new (p) SlotRecordObject;
return reinterpret_cast<SlotRecordObject*>(p);
}
inline void free_slotrecord(SlotRecordObject* p) {
p->~SlotRecordObject();
free(p);
}
template <class T>
class SlotObjAllocator {
public:
explicit SlotObjAllocator(std::function<void(T*)> deleter)
: free_nodes_(NULL), capacity_(0), deleter_(deleter) {}
~SlotObjAllocator() { clear(); }
void clear() {
T* tmp = NULL;
while (free_nodes_ != NULL) {
tmp = reinterpret_cast<T*>(reinterpret_cast<void*>(free_nodes_));
free_nodes_ = free_nodes_->next;
deleter_(tmp);
--capacity_;
}
CHECK_EQ(capacity_, static_cast<size_t>(0));
}
T* acquire(void) {
T* x = NULL;
x = reinterpret_cast<T*>(reinterpret_cast<void*>(free_nodes_));
free_nodes_ = free_nodes_->next;
--capacity_;
return x;
}
void release(T* x) {
Node* node = reinterpret_cast<Node*>(reinterpret_cast<void*>(x));
node->next = free_nodes_;
free_nodes_ = node;
++capacity_;
}
size_t capacity(void) { return capacity_; }
private:
struct alignas(T) Node {
union {
Node* next;
char data[sizeof(T)];
};
};
Node* free_nodes_; // a list
size_t capacity_;
std::function<void(T*)> deleter_ = nullptr;
};
static const int OBJPOOL_BLOCK_SIZE = 10000;
class SlotObjPool {
public:
SlotObjPool()
: max_capacity_(FLAGS_record_pool_max_size), alloc_(free_slotrecord) {
ins_chan_ = MakeChannel<SlotRecord>();
ins_chan_->SetBlockSize(OBJPOOL_BLOCK_SIZE);
for (int i = 0; i < FLAGS_slotpool_thread_num; ++i) {
threads_.push_back(std::thread([this]() { run(); }));
}
disable_pool_ = false;
count_ = 0;
}
~SlotObjPool() {
ins_chan_->Close();
for (auto& t : threads_) {
t.join();
}
}
void disable_pool(bool disable) { disable_pool_ = disable; }
void set_max_capacity(size_t max_capacity) { max_capacity_ = max_capacity; }
void get(std::vector<SlotRecord>* output, int n) {
output->resize(n);
return get(&(*output)[0], n);
}
void get(SlotRecord* output, int n) {
int size = 0;
mutex_.lock();
int left = static_cast<int>(alloc_.capacity());
if (left > 0) {
size = (left >= n) ? n : left;
for (int i = 0; i < size; ++i) {
output[i] = alloc_.acquire();
}
}
mutex_.unlock();
count_ += n;
if (size == n) {
return;
}
for (int i = size; i < n; ++i) {
output[i] = make_slotrecord();
}
}
void put(std::vector<SlotRecord>* input) {
size_t size = input->size();
if (size == 0) {
return;
}
put(&(*input)[0], size);
input->clear();
}
void put(SlotRecord* input, size_t size) {
CHECK(ins_chan_->WriteMove(size, input) == size);
}
void run(void) {
std::vector<SlotRecord> input;
while (ins_chan_->ReadOnce(input, OBJPOOL_BLOCK_SIZE)) {
if (input.empty()) {
continue;
}
// over max capacity
size_t n = input.size();
count_ -= n;
if (disable_pool_ || n + capacity() > max_capacity_) {
for (auto& t : input) {
free_slotrecord(t);
}
} else {
for (auto& t : input) {
t->reset();
}
mutex_.lock();
for (auto& t : input) {
alloc_.release(t);
}
mutex_.unlock();
}
input.clear();
}
}
void clear(void) {
platform::Timer timeline;
timeline.Start();
mutex_.lock();
alloc_.clear();
mutex_.unlock();
// wait release channel data
if (FLAGS_enable_slotpool_wait_release) {
while (!ins_chan_->Empty()) {
sleep(1);
}
}
timeline.Pause();
VLOG(3) << "clear slot pool data size=" << count_.load()
<< ", span=" << timeline.ElapsedSec();
}
size_t capacity(void) {
mutex_.lock();
size_t total = alloc_.capacity();
mutex_.unlock();
return total;
}
private:
size_t max_capacity_;
Channel<SlotRecord> ins_chan_;
std::vector<std::thread> threads_;
std::mutex mutex_;
SlotObjAllocator<SlotRecordObject> alloc_;
bool disable_pool_;
std::atomic<long> count_; // NOLINT
};
inline SlotObjPool& SlotRecordPool() {
static SlotObjPool pool;
return pool;
}
struct PvInstanceObject {
std::vector<Record*> ads;
void merge_instance(Record* ins) { ads.push_back(ins); }
......@@ -129,7 +384,21 @@ class CustomParser {
CustomParser() {}
virtual ~CustomParser() {}
virtual void Init(const std::vector<SlotConf>& slots) = 0;
virtual bool Init(const std::vector<AllSlotInfo>& slots) = 0;
virtual void ParseOneInstance(const char* str, Record* instance) = 0;
virtual bool ParseOneInstance(
const std::string& line,
std::function<void(std::vector<SlotRecord>&, int)>
GetInsFunc) { // NOLINT
return true;
}
virtual bool ParseFileInstance(
std::function<int(char* buf, int len)> ReadBuffFunc,
std::function<void(std::vector<SlotRecord>&, int, int)>
PullRecordsFunc, // NOLINT
int& lines) { // NOLINT
return false;
}
};
typedef paddle::framework::CustomParser* (*CreateParserObjectFunc)();
......@@ -194,6 +463,34 @@ class DLManager {
return nullptr;
}
paddle::framework::CustomParser* Load(const std::string& name,
const std::vector<AllSlotInfo>& conf) {
#ifdef _LINUX
std::lock_guard<std::mutex> lock(mutex_);
DLHandle handle;
std::map<std::string, DLHandle>::iterator it = handle_map_.find(name);
if (it != handle_map_.end()) {
return it->second.parser;
}
handle.module = dlopen(name.c_str(), RTLD_NOW);
if (handle.module == nullptr) {
VLOG(0) << "Create so of " << name << " fail";
exit(-1);
return nullptr;
}
CreateParserObjectFunc create_parser_func =
(CreateParserObjectFunc)dlsym(handle.module, "CreateParserObject");
handle.parser = create_parser_func();
handle.parser->Init(conf);
handle_map_.insert({name, handle});
return handle.parser;
#endif
VLOG(0) << "Not implement in windows";
return nullptr;
}
paddle::framework::CustomParser* ReLoad(const std::string& name,
const std::vector<SlotConf>& conf) {
Close(name);
......@@ -415,6 +712,11 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetCurrentPhase(int current_phase);
virtual void LoadIntoMemory();
virtual void LoadIntoMemoryFromSo();
virtual void SetRecord(T* records) { records_ = records; }
int GetDefaultBatchSize() { return default_batch_size_; }
void AddBatchOffset(const std::pair<int, int>& offset) {
batch_offsets_.push_back(offset);
}
protected:
virtual bool ParseOneInstance(T* instance) = 0;
......@@ -424,6 +726,11 @@ class InMemoryDataFeed : public DataFeed {
virtual void PutToFeedVec(const std::vector<T>& ins_vec) = 0;
virtual void PutToFeedVec(const T* ins_vec, int num) = 0;
std::vector<std::vector<float>> batch_float_feasigns_;
std::vector<std::vector<uint64_t>> batch_uint64_feasigns_;
std::vector<std::vector<size_t>> offset_;
std::vector<bool> visit_;
int thread_id_;
int thread_num_;
bool parse_ins_id_;
......@@ -783,11 +1090,7 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc);
void SetRecord(Record* records) { records_ = records; }
int GetDefaultBatchSize() { return default_batch_size_; }
void AddBatchOffset(const std::pair<int, int>& offset) {
batch_offsets_.push_back(offset);
}
// void SetRecord(Record* records) { records_ = records; }
protected:
virtual bool ParseOneInstance(Record* instance);
......@@ -798,10 +1101,6 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id,
uint32_t* cmatch, uint32_t* rank);
virtual void PutToFeedVec(const Record* ins_vec, int num);
std::vector<std::vector<float>> batch_float_feasigns_;
std::vector<std::vector<uint64_t>> batch_uint64_feasigns_;
std::vector<std::vector<size_t>> offset_;
std::vector<bool> visit_;
};
class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
......
......@@ -351,10 +351,8 @@ static int compute_thread_batch_nccl(
return thread_avg_batch_num;
}
template <typename T>
void DatasetImpl<T>::SetHeterPs(bool enable_heterps) {
void MultiSlotDataset::PrepareTrain() {
#ifdef PADDLE_WITH_GLOO
enable_heterps_ = enable_heterps;
if (enable_heterps_) {
if (input_records_.size() == 0 && input_channel_ != nullptr &&
input_channel_->Size() != 0) {
......@@ -541,22 +539,21 @@ void DatasetImpl<T>::LocalShuffle() {
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::GlobalShuffle(int thread_num) {
void MultiSlotDataset::GlobalShuffle(int thread_num) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
VLOG(3) << "MultiSlotDataset::GlobalShuffle() begin";
platform::Timer timeline;
timeline.Start();
auto fleet_ptr = FleetWrapper::GetInstance();
if (!input_channel_ || input_channel_->Size() == 0) {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, no data to shuffle";
VLOG(3) << "MultiSlotDataset::GlobalShuffle() end, no data to shuffle";
return;
}
// local shuffle
input_channel_->Close();
std::vector<T> data;
std::vector<Record> data;
input_channel_->ReadAll(data);
std::shuffle(data.begin(), data.end(), fleet_ptr->LocalRandomEngine());
input_channel_->Open();
......@@ -566,10 +563,10 @@ void DatasetImpl<T>::GlobalShuffle(int thread_num) {
input_channel_->Close();
input_channel_->SetBlockSize(fleet_send_batch_size_);
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() input_channel_ size "
VLOG(3) << "MultiSlotDataset::GlobalShuffle() input_channel_ size "
<< input_channel_->Size();
auto get_client_id = [this, fleet_ptr](const T& data) -> size_t {
auto get_client_id = [this, fleet_ptr](const Record& data) -> size_t {
if (!this->merge_by_insid_) {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
} else {
......@@ -580,7 +577,7 @@ void DatasetImpl<T>::GlobalShuffle(int thread_num) {
auto global_shuffle_func = [this, get_client_id]() {
auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<T> data;
std::vector<Record> data;
while (this->input_channel_->Read(data)) {
std::vector<paddle::framework::BinaryArchive> ars(this->trainer_num_);
for (auto& t : data) {
......@@ -835,9 +832,6 @@ void DatasetImpl<T>::CreateReaders() {
channel_idx = 0;
}
}
if (enable_heterps_) {
SetHeterPs(true);
}
VLOG(3) << "readers size: " << readers_.size();
}
......@@ -923,8 +917,7 @@ int64_t DatasetImpl<T>::GetShuffleDataSize() {
return sum;
}
template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
int MultiSlotDataset::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) {
#ifdef _LINUX
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
......@@ -937,9 +930,9 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
if (ar.Cursor() == ar.Finish()) {
return 0;
}
std::vector<T> data;
std::vector<Record> data;
while (ar.Cursor() < ar.Finish()) {
data.push_back(ar.Get<T>());
data.push_back(ar.Get<Record>());
}
CHECK(ar.Cursor() == ar.Finish());
......@@ -966,6 +959,20 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
// explicit instantiation
template class DatasetImpl<Record>;
void MultiSlotDataset::DynamicAdjustReadersNum(int thread_num) {
if (thread_num_ == thread_num) {
VLOG(3) << "DatasetImpl<T>::DynamicAdjustReadersNum thread_num_="
<< thread_num_ << ", thread_num_=thread_num, no need to adjust";
return;
}
VLOG(3) << "adjust readers num from " << thread_num_ << " to " << thread_num;
thread_num_ = thread_num;
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
CreateReaders();
VLOG(3) << "adjust readers num done";
PrepareTrain();
}
void MultiSlotDataset::PostprocessInstance() {
// divide pv instance, and merge to input_channel_
if (enable_pv_merge_) {
......@@ -1503,5 +1510,126 @@ void MultiSlotDataset::SlotsShuffle(
<< ", cost time=" << timeline.ElapsedSec() << " seconds";
}
template class DatasetImpl<SlotRecord>;
void SlotRecordDataset::CreateChannel() {
if (input_channel_ == nullptr) {
input_channel_ = paddle::framework::MakeChannel<SlotRecord>();
}
}
void SlotRecordDataset::CreateReaders() {
VLOG(3) << "Calling CreateReaders()";
VLOG(3) << "thread num in Dataset: " << thread_num_;
VLOG(3) << "Filelist size in Dataset: " << filelist_.size();
VLOG(3) << "channel num in Dataset: " << channel_num_;
CHECK(thread_num_ > 0) << "thread num should > 0";
CHECK(channel_num_ > 0) << "channel num should > 0";
CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num";
VLOG(3) << "readers size: " << readers_.size();
if (readers_.size() != 0) {
VLOG(3) << "readers_.size() = " << readers_.size()
<< ", will not create again";
return;
}
VLOG(3) << "data feed class name: " << data_feed_desc_.name();
for (int i = 0; i < thread_num_; ++i) {
readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
readers_[i]->Init(data_feed_desc_);
readers_[i]->SetThreadId(i);
readers_[i]->SetThreadNum(thread_num_);
readers_[i]->SetFileListMutex(&mutex_for_pick_file_);
readers_[i]->SetFileListIndex(&file_idx_);
readers_[i]->SetFeaNumMutex(&mutex_for_fea_num_);
readers_[i]->SetFeaNum(&total_fea_num_);
readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseContent(parse_content_);
readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_);
readers_[i]->SetCurrentPhase(current_phase_);
if (input_channel_ != nullptr) {
readers_[i]->SetInputChannel(input_channel_.get());
}
}
VLOG(3) << "readers size: " << readers_.size();
}
void SlotRecordDataset::ReleaseMemory() {
VLOG(3) << "SlotRecordDataset::ReleaseMemory() begin";
platform::Timer timeline;
timeline.Start();
if (input_channel_) {
input_channel_->Clear();
input_channel_ = nullptr;
}
if (enable_heterps_) {
VLOG(3) << "put pool records size: " << input_records_.size();
SlotRecordPool().put(&input_records_);
input_records_.clear();
input_records_.shrink_to_fit();
VLOG(3) << "release heterps input records records size: "
<< input_records_.size();
}
readers_.clear();
readers_.shrink_to_fit();
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
VLOG(3) << "SlotRecordDataset::ReleaseMemory() end";
VLOG(3) << "total_feasign_num_(" << STAT_GET(STAT_total_feasign_num_in_mem)
<< ") - current_fea_num_(" << total_fea_num_ << ") = ("
<< STAT_GET(STAT_total_feasign_num_in_mem) - total_fea_num_ << ")"
<< " object pool size=" << SlotRecordPool().capacity(); // For Debug
STAT_SUB(STAT_total_feasign_num_in_mem, total_fea_num_);
}
void SlotRecordDataset::GlobalShuffle(int thread_num) {
// TODO(yaoxuefeng)
return;
}
void SlotRecordDataset::DynamicAdjustChannelNum(int channel_num,
bool discard_remaining_ins) {
if (channel_num_ == channel_num) {
VLOG(3) << "DatasetImpl<T>::DynamicAdjustChannelNum channel_num_="
<< channel_num_ << ", channel_num_=channel_num, no need to adjust";
return;
}
VLOG(3) << "adjust channel num from " << channel_num_ << " to "
<< channel_num;
channel_num_ = channel_num;
if (static_cast<int>(input_channel_->Size()) >= channel_num) {
input_channel_->SetBlockSize(input_channel_->Size() / channel_num +
(discard_remaining_ins ? 0 : 1));
}
VLOG(3) << "adjust channel num done";
}
void SlotRecordDataset::PrepareTrain() {
#ifdef PADDLE_WITH_GLOO
return;
#else
PADDLE_THROW(platform::errors::Unavailable(
"dataset set heterps need compile with GLOO"));
#endif
return;
}
void SlotRecordDataset::DynamicAdjustReadersNum(int thread_num) {
if (thread_num_ == thread_num) {
VLOG(3) << "DatasetImpl<T>::DynamicAdjustReadersNum thread_num_="
<< thread_num_ << ", thread_num_=thread_num, no need to adjust";
return;
}
VLOG(3) << "adjust readers num from " << thread_num_ << " to " << thread_num;
thread_num_ = thread_num;
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
CreateReaders();
VLOG(3) << "adjust readers num done";
PrepareTrain();
}
} // end namespace framework
} // end namespace paddle
......@@ -149,7 +149,6 @@ class Dataset {
virtual void DynamicAdjustReadersNum(int thread_num) = 0;
// set fleet send sleep seconds
virtual void SetFleetSendSleepSeconds(int seconds) = 0;
virtual void SetHeterPs(bool enable_heterps) = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
......@@ -207,7 +206,7 @@ class DatasetImpl : public Dataset {
virtual void WaitPreLoadDone();
virtual void ReleaseMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle(int thread_num = -1);
virtual void GlobalShuffle(int thread_num = -1) {}
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
virtual const std::vector<T>& GetSlotsOriginalData() {
return slots_shuffle_original_data_;
......@@ -233,7 +232,11 @@ class DatasetImpl : public Dataset {
bool discard_remaining_ins = false);
virtual void DynamicAdjustReadersNum(int thread_num);
virtual void SetFleetSendSleepSeconds(int seconds);
virtual void SetHeterPs(bool enable_heterps);
/* for enable_heterps_
virtual void EnableHeterps(bool enable_heterps) {
enable_heterps_ = enable_heterps;
}
*/
std::vector<paddle::framework::Channel<T>>& GetMultiOutputChannel() {
return multi_output_channel_;
......@@ -251,7 +254,10 @@ class DatasetImpl : public Dataset {
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg);
const std::string& msg) {
// TODO(yaoxuefeng) for SlotRecordDataset
return -1;
}
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
paddle::framework::Channel<T> input_channel_;
......@@ -327,6 +333,32 @@ class MultiSlotDataset : public DatasetImpl<Record> {
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual ~MultiSlotDataset() {}
virtual void GlobalShuffle(int thread_num = -1);
virtual void DynamicAdjustReadersNum(int thread_num);
virtual void PrepareTrain();
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg);
};
class SlotRecordDataset : public DatasetImpl<SlotRecord> {
public:
SlotRecordDataset() { SlotRecordPool(); }
virtual ~SlotRecordDataset() {}
// create input channel
virtual void CreateChannel();
// create readers
virtual void CreateReaders();
// release memory
virtual void ReleaseMemory();
virtual void GlobalShuffle(int thread_num = -1);
virtual void DynamicAdjustChannelNum(int channel_num,
bool discard_remaining_ins);
virtual void PrepareTrain();
virtual void DynamicAdjustReadersNum(int thread_num);
protected:
bool enable_heterps_ = true;
};
} // end namespace framework
......
......@@ -53,7 +53,7 @@ std::unique_ptr<Dataset> DatasetFactory::CreateDataset(
std::string dataset_class) {
if (g_dataset_map.count(dataset_class) < 1) {
LOG(WARNING) << "Your Dataset " << dataset_class
<< "is not supported currently";
<< " is not supported currently";
LOG(WARNING) << "Supported Dataset: " << DatasetTypeList();
exit(-1);
}
......@@ -61,5 +61,6 @@ std::unique_ptr<Dataset> DatasetFactory::CreateDataset(
}
REGISTER_DATASET_CLASS(MultiSlotDataset);
REGISTER_DATASET_CLASS(SlotRecordDataset);
} // namespace framework
} // namespace paddle
......@@ -680,3 +680,11 @@ PADDLE_DEFINE_EXPORTED_int32(get_host_by_name_time, 120,
PADDLE_DEFINE_EXPORTED_bool(
apply_pass_to_program, false,
"It controls whether to apply IR pass to program when using Fleet APIs");
DEFINE_int32(record_pool_max_size, 2000000,
"SlotRecordDataset slot record pool max size");
DEFINE_int32(slotpool_thread_num, 1, "SlotRecordDataset slot pool thread num");
DEFINE_bool(enable_slotpool_wait_release, false,
"enable slotrecord obejct wait release, default false");
DEFINE_bool(enable_slotrecord_reset_shrink, false,
"enable slotrecord obejct reset shrink memory, default false");
\ No newline at end of file
......@@ -309,8 +309,6 @@ void BindDataset(py::module *m) {
&framework::Dataset::SetFleetSendSleepSeconds,
py::call_guard<py::gil_scoped_release>())
.def("enable_pv_merge", &framework::Dataset::EnablePvMerge,
py::call_guard<py::gil_scoped_release>())
.def("set_heter_ps", &framework::Dataset::SetHeterPs,
py::call_guard<py::gil_scoped_release>());
py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册