未验证 提交 02c370c3 编写于 作者: J jiaqi 提交者: GitHub

support filelist size < trainer num && fix pull dense (#18956)

* support filelist size < trainer num
* pull dense params when stopping, to make sure the local dense params are the same as on the pserver, so that saving the paddle model saves the same dense model as the pserver
* enable QueueDataset to train on the same filelist several times
上级 e7da0940
...@@ -352,8 +352,6 @@ void DatasetImpl<T>::CreateReaders() { ...@@ -352,8 +352,6 @@ void DatasetImpl<T>::CreateReaders() {
VLOG(3) << "Filelist size in Dataset: " << filelist_.size(); VLOG(3) << "Filelist size in Dataset: " << filelist_.size();
VLOG(3) << "channel num in Dataset: " << channel_num_; VLOG(3) << "channel num in Dataset: " << channel_num_;
CHECK(thread_num_ > 0) << "thread num should > 0"; CHECK(thread_num_ > 0) << "thread num should > 0";
CHECK(thread_num_ <= filelist_.size())
<< "thread num should <= filelist size";
CHECK(channel_num_ > 0) << "channel num should > 0"; CHECK(channel_num_ > 0) << "channel num should > 0";
CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num"; CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num";
VLOG(3) << "readers size: " << readers_.size(); VLOG(3) << "readers size: " << readers_.size();
......
...@@ -80,6 +80,9 @@ void PullDenseWorker::Stop() { ...@@ -80,6 +80,9 @@ void PullDenseWorker::Stop() {
if (running_) { if (running_) {
running_ = false; running_ = false;
t_.join(); t_.join();
// pull dense when stop, to make sure local dense params are same as
// pserver, so save paddle model will save dense model same as pserver
PullDense(true);
} }
} }
......
...@@ -246,6 +246,8 @@ class InMemoryDataset(DatasetBase): ...@@ -246,6 +246,8 @@ class InMemoryDataset(DatasetBase):
""" """
if self.thread_num > len(self.filelist): if self.thread_num > len(self.filelist):
self.thread_num = len(self.filelist) self.thread_num = len(self.filelist)
if self.thread_num == 0:
self.thread_num = 1
self.dataset.set_thread_num(self.thread_num) self.dataset.set_thread_num(self.thread_num)
if self.queue_num is None: if self.queue_num is None:
self.queue_num = self.thread_num self.queue_num = self.thread_num
...@@ -545,6 +547,20 @@ class QueueDataset(DatasetBase): ...@@ -545,6 +547,20 @@ class QueueDataset(DatasetBase):
super(QueueDataset, self).__init__() super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed" self.proto_desc.name = "MultiSlotDataFeed"
def _prepare_to_run(self):
    """
    Set data_feed_desc, thread num and filelist on the underlying
    dataset before a run; users do not need to call this function
    themselves.
    """
    # Clamp the thread count to the filelist size so each reader thread
    # has at least one file to consume (supports filelist size < trainer
    # num), but never drop below one thread.
    if self.thread_num > len(self.filelist):
        self.thread_num = len(self.filelist)
    if self.thread_num == 0:
        self.thread_num = 1
    self.dataset.set_thread_num(self.thread_num)
    self.dataset.set_filelist(self.filelist)
    # Push the serialized data-feed description proto to the C++ side,
    # then create one reader per thread.
    self.dataset.set_data_feed_desc(self.desc())
    self.dataset.create_readers()
def local_shuffle(self): def local_shuffle(self):
""" """
Local shuffle data. Local shuffle data.
......
...@@ -148,8 +148,10 @@ class Fleet(object): ...@@ -148,8 +148,10 @@ class Fleet(object):
def split_files(self, files): def split_files(self, files):
""" """
split files before distributed training, split files before distributed training,
for example, files is [a, b, c ,d, e] and trainer_num = 2, example 1: files is [a, b, c ,d, e] and trainer_num = 2, then trainer
then trainer 0 gets [a, b, c] and trainer 1 gets [d, e] 0 gets [a, b, c] and trainer 1 gets [d, e].
example 2: files is [a, b], and trainer_num = 3, then trainer 0 gets
[a], trainer 1 gets [b], trainer 2 gets []
Args: Args:
files(list): file list need to be read. files(list): file list need to be read.
...@@ -160,9 +162,6 @@ class Fleet(object): ...@@ -160,9 +162,6 @@ class Fleet(object):
trainer_id = self.worker_index() trainer_id = self.worker_index()
trainers = self.worker_num() trainers = self.worker_num()
if len(files) < trainers:
raise ValueError("file number must gather or equal trainer number")
remainder = len(files) % trainers remainder = len(files) % trainers
blocksize = len(files) / trainers blocksize = len(files) / trainers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册