提交 a34fe624 编写于 作者: X xjqbest 提交者: dongdaxiang

add some doc

上级 20b76f3d
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// constructor
template <typename T> template <typename T>
DatasetImpl<T>::DatasetImpl() { DatasetImpl<T>::DatasetImpl() {
thread_num_ = 1; thread_num_ = 1;
...@@ -31,37 +32,24 @@ DatasetImpl<T>::DatasetImpl() { ...@@ -31,37 +32,24 @@ DatasetImpl<T>::DatasetImpl() {
file_idx_ = 0; file_idx_ = 0;
} }
// set filelist, file_idx_ will reset to zero.
template <typename T> template <typename T>
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) { void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
VLOG(3) << "filelist size: " << filelist.size(); VLOG(3) << "filelist size: " << filelist.size();
filelist_ = filelist; filelist_ = filelist;
file_idx_ = 0; file_idx_ = 0;
/*
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num_ = file_cnt;
}*/
} }
// buggy here, a user should set filelist first before this function // set expect thread num. actually it may change
// not user friendly
template <typename T> template <typename T>
void DatasetImpl<T>::SetThreadNum(int thread_num) { void DatasetImpl<T>::SetThreadNum(int thread_num) {
VLOG(3) << "SetThreadNum thread_num=" << thread_num; VLOG(3) << "SetThreadNum thread_num=" << thread_num;
//int file_cnt = filelist_.size();
/*
if (file_cnt != 0 && thread_num > file_cnt) {
VLOG(3) << "DataSet thread num = " << thread_num
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num = file_cnt;
}*/
thread_num_ = thread_num; thread_num_ = thread_num;
} }
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerNum
template <typename T> template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) { void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num; trainer_num_ = trainer_num;
...@@ -86,12 +74,16 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) { ...@@ -86,12 +74,16 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
&data_feed_desc_); &data_feed_desc_);
} }
// readers_.size() may not be equal to thread_num_,
// it changes when filelist_.size() < thread_num_
template <typename T> template <typename T>
std::vector<std::shared_ptr<paddle::framework::DataFeed>>& std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
DatasetImpl<T>::GetReaders() { DatasetImpl<T>::GetReaders() {
return readers_; return readers_;
} }
// load data into memory, Dataset hold this memory,
// which will later be fed into readers' channel
template <typename T> template <typename T>
void DatasetImpl<T>::LoadIntoMemory() { void DatasetImpl<T>::LoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin"; VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
...@@ -114,6 +106,7 @@ void DatasetImpl<T>::LoadIntoMemory() { ...@@ -114,6 +106,7 @@ void DatasetImpl<T>::LoadIntoMemory() {
<< ", cost time=" << timeline.ElapsedSec() << " seconds"; << ", cost time=" << timeline.ElapsedSec() << " seconds";
} }
// do local shuffle
template <typename T> template <typename T>
void DatasetImpl<T>::LocalShuffle() { void DatasetImpl<T>::LocalShuffle() {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin"; VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
......
...@@ -26,6 +26,16 @@ ...@@ -26,6 +26,16 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Dataset is a abstract class, which defines user interfaces
// Example Usage:
// Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
// dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
// dataset->SetThreadNum(1)
// dataset->CreateReaders();
// dataset->SetDataFeedDesc(your_data_feed_desc);
// dataset->LoadIntoMemory();
// dataset->SetTrainerNum(2);
// dataset->GlobalShuffle();
class Dataset { class Dataset {
public: public:
Dataset() {} Dataset() {}
...@@ -53,6 +63,8 @@ class Dataset { ...@@ -53,6 +63,8 @@ class Dataset {
const std::string& msg) = 0; const std::string& msg) = 0;
}; };
// DatasetImpl is the implementation of Dataset,
// it holds memory data if user calls load_into_memory
template <typename T> template <typename T>
class DatasetImpl : public Dataset { class DatasetImpl : public Dataset {
public: public:
...@@ -95,6 +107,7 @@ class DatasetImpl : public Dataset { ...@@ -95,6 +107,7 @@ class DatasetImpl : public Dataset {
std::mutex mutex_for_pick_file_; std::mutex mutex_for_pick_file_;
}; };
// use std::vector<MultiSlotType> as data type
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> { class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
public: public:
MultiSlotDataset() {} MultiSlotDataset() {}
......
...@@ -146,7 +146,9 @@ class FleetWrapper { ...@@ -146,7 +146,9 @@ class FleetWrapper {
private: private:
static std::shared_ptr<FleetWrapper> s_instance_; static std::shared_ptr<FleetWrapper> s_instance_;
#ifdef PADDLE_WITH_PSLIB
std::map<uint64_t, std::vector<paddle::ps::Region>> _regions; std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
#endif
protected: protected:
static bool is_initialized_; static bool is_initialized_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册