From a34fe6248fe12fd86898d044d07773fa48f70945 Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Wed, 20 Mar 2019 16:59:50 +0800 Subject: [PATCH] add some doc --- paddle/fluid/framework/data_set.cc | 29 ++++++++------------ paddle/fluid/framework/data_set.h | 13 +++++++++ paddle/fluid/framework/fleet/fleet_wrapper.h | 2 ++ 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index b0f5d1867a..fe71160c1d 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -24,6 +24,7 @@ namespace paddle { namespace framework { +// constructor template DatasetImpl::DatasetImpl() { thread_num_ = 1; @@ -31,37 +32,24 @@ DatasetImpl::DatasetImpl() { file_idx_ = 0; } +// set filelist, file_idx_ will reset to zero. template void DatasetImpl::SetFileList(const std::vector& filelist) { VLOG(3) << "filelist size: " << filelist.size(); filelist_ = filelist; file_idx_ = 0; - /* - int file_cnt = filelist_.size(); - if (thread_num_ > file_cnt) { - VLOG(1) << "DataSet thread num = " << thread_num_ - << ", file num = " << file_cnt - << ". Changing DataSet thread num = " << file_cnt; - thread_num_ = file_cnt; - }*/ } -// buggy here, a user should set filelist first before this function -// not user friendly +// set expect thread num. actually it may change template void DatasetImpl::SetThreadNum(int thread_num) { VLOG(3) << "SetThreadNum thread_num=" << thread_num; - //int file_cnt = filelist_.size(); - /* - if (file_cnt != 0 && thread_num > file_cnt) { - VLOG(3) << "DataSet thread num = " << thread_num - << ", file num = " << file_cnt - << ". Changing DataSet thread num = " << file_cnt; - thread_num = file_cnt; - }*/ thread_num_ = thread_num; } +// if you run distributed, and want to do global shuffle, +// set this before global shuffle. +// be sure you call CreateReaders before SetTrainerNum template void DatasetImpl::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; @@ -86,12 +74,16 @@ void DatasetImpl::SetDataFeedDesc(const std::string& data_feed_desc_str) { &data_feed_desc_); } +// readers_.size() may not be equal to thread_num_, +// it changes when filelist_.size() < thread_num_ template std::vector>& DatasetImpl::GetReaders() { return readers_; } +// load data into memory, Dataset hold this memory, +// which will later be fed into readers' channel template void DatasetImpl::LoadIntoMemory() { VLOG(3) << "DatasetImpl::LoadIntoMemory() begin"; @@ -114,6 +106,7 @@ void DatasetImpl::LoadIntoMemory() { << ", cost time=" << timeline.ElapsedSec() << " seconds"; } +// do local shuffle template void DatasetImpl::LocalShuffle() { VLOG(3) << "DatasetImpl::LocalShuffle() begin"; diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 02e07c5b5f..a13d0f869d 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -26,6 +26,16 @@ namespace paddle { namespace framework { +// Dataset is a abstract class, which defines user interfaces +// Example Usage: +// Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset") +// dataset->SetFileList(std::vector{"a.txt", "b.txt"}) +// dataset->SetThreadNum(1) +// dataset->CreateReaders(); +// dataset->SetDataFeedDesc(your_data_feed_desc); +// dataset->LoadIntoMemory(); +// dataset->SetTrainerNum(2); +// dataset->GlobalShuffle(); class Dataset { public: Dataset() {} @@ -53,6 +63,8 @@ class Dataset { const std::string& msg) = 0; }; +// DatasetImpl is the implementation of Dataset, +// it holds memory data if user calls load_into_memory template class DatasetImpl : public Dataset { public: @@ -95,6 +107,7 @@ class DatasetImpl : public Dataset { std::mutex mutex_for_pick_file_; }; +// use std::vector as data type class MultiSlotDataset : public DatasetImpl> { public: MultiSlotDataset() {} diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 9e08ef6474..6090ec753d 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -146,7 +146,9 @@ class FleetWrapper { private: static std::shared_ptr s_instance_; +#ifdef PADDLE_WITH_PSLIB std::map> _regions; +#endif protected: static bool is_initialized_; -- GitLab