Unverified commit 02c370c3 authored by jiaqi, committed by GitHub

support filelist size < trainer num && fix pull dense (#18956)

* support filelist size < trainer num
* pull dense params when the worker stops, to make sure local dense params are the same as on the pserver, so saving the paddle model will save the same dense params the pserver holds
* enable QueueDataset to train on the same filelist several times
Parent: e7da0940
@@ -352,8 +352,6 @@ void DatasetImpl<T>::CreateReaders() {
   VLOG(3) << "Filelist size in Dataset: " << filelist_.size();
   VLOG(3) << "channel num in Dataset: " << channel_num_;
   CHECK(thread_num_ > 0) << "thread num should > 0";
-  CHECK(thread_num_ <= filelist_.size())
-      << "thread num should <= filelist size";
   CHECK(channel_num_ > 0) << "channel num should > 0";
   CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num";
   VLOG(3) << "readers size: " << readers_.size();
@@ -80,6 +80,9 @@ void PullDenseWorker::Stop() {
   if (running_) {
     running_ = false;
     t_.join();
+    // pull dense when stop, to make sure local dense params are same as
+    // pserver, so save paddle model will save dense model same as pserver
+    PullDense(true);
   }
 }
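For context, a hedged sketch of a PSLib-style training flow where this matters. The import path, `fleet.save_persistables`, and `fleet.stop_worker` are the 1.5-era fleet APIs, but the exact sequence here is an assumption, not taken from this commit:

```python
# Hedged sketch; assumes a PSLib fleet has been initialized elsewhere
# (role maker, fleet.init, etc.), and that `dataset` is set up.
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet

exe = fluid.Executor(fluid.CPUPlace())
for pass_id in range(10):
    exe.train_from_dataset(
        program=fluid.default_main_program(), dataset=dataset)

# With this commit, the pull-dense worker pulls the latest dense params
# from the pserver once more when it stops, so the model saved here
# holds the same dense params as the pserver.
fleet.save_persistables(executor=exe, dirname="./paddle_model")
fleet.stop_worker()
```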
@@ -246,6 +246,8 @@ class InMemoryDataset(DatasetBase):
         """
         if self.thread_num > len(self.filelist):
             self.thread_num = len(self.filelist)
+        if self.thread_num == 0:
+            self.thread_num = 1
         self.dataset.set_thread_num(self.thread_num)
         if self.queue_num is None:
             self.queue_num = self.thread_num
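A hedged usage sketch of the clamping above; `fluid.DatasetFactory` and the setters are the public dataset API, while the file names are placeholders:

```python
import paddle.fluid as fluid

# Ask for more threads than there are files.
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_thread_num(10)
dataset.set_filelist(["part-000", "part-001"])  # only 2 files

# _prepare_to_run (called internally before training) now clamps the
# effective thread num to len(filelist) == 2, and bumps it to 1 for an
# empty filelist, instead of tripping the removed C++ CHECK.
```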
@@ -545,6 +547,20 @@ class QueueDataset(DatasetBase):
         super(QueueDataset, self).__init__()
         self.proto_desc.name = "MultiSlotDataFeed"
 
+    def _prepare_to_run(self):
+        """
+        Set data_feed_desc/thread num/filelist before run,
+        user no need to call this function.
+        """
+        if self.thread_num > len(self.filelist):
+            self.thread_num = len(self.filelist)
+        if self.thread_num == 0:
+            self.thread_num = 1
+        self.dataset.set_thread_num(self.thread_num)
+        self.dataset.set_filelist(self.filelist)
+        self.dataset.set_data_feed_desc(self.desc())
+        self.dataset.create_readers()
+
     def local_shuffle(self):
         """
         Local shuffle data.
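With `QueueDataset` now creating its readers in `_prepare_to_run`, the same filelist can be consumed for several passes in a row. A hedged sketch, assuming the network definition and `dataset.set_use_var` setup happen elsewhere; file names are placeholders:

```python
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_thread_num(2)
dataset.set_filelist(["part-000", "part-001"])

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for pass_id in range(3):  # same filelist, several passes
    exe.train_from_dataset(
        program=fluid.default_main_program(), dataset=dataset)
```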
@@ -148,8 +148,10 @@ class Fleet(object):
     def split_files(self, files):
         """
         split files before distributed training,
-        for example, files is [a, b, c ,d, e] and trainer_num = 2,
-        then trainer 0 gets [a, b, c] and trainer 1 gets [d, e]
+        example 1: files is [a, b, c ,d, e] and trainer_num = 2, then trainer
+        0 gets [a, b, c] and trainer 1 gets [d, e].
+        example 2: files is [a, b], and trainer_num = 3, then trainer 0 gets
+        [a], trainer 1 gets [b], trainer 2 gets []
 
         Args:
             files(list): file list need to be read.
@@ -160,9 +162,6 @@ class Fleet(object):
         trainer_id = self.worker_index()
         trainers = self.worker_num()
-        if len(files) < trainers:
-            raise ValueError("file number must gather or equal trainer number")
-
         remainder = len(files) % trainers
         blocksize = len(files) / trainers
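Since the length check is gone, `split_files` has to behave sensibly when `len(files) < trainers`: with an integer blocksize the split degrades gracefully to empty lists. A minimal standalone sketch of the block-based logic implied by the hunk above (a free function for illustration, not the exact Fleet method):

```python
def split_files(files, trainer_id, trainers):
    # Each trainer gets floor(len(files) / trainers) files; the remainder
    # is spread one file each over the first trainers.
    remainder = len(files) % trainers
    blocksize = len(files) // trainers  # 0 when files < trainers
    blocks = [blocksize] * trainers
    for i in range(remainder):
        blocks[i] += 1
    begin = 0
    for i in range(trainers):
        if i == trainer_id:
            return files[begin:begin + blocks[i]]
        begin += blocks[i]

# example 2 from the docstring: 2 files, 3 trainers
print(split_files(["a", "b"], 0, 3))  # ['a']
print(split_files(["a", "b"], 2, 3))  # []
```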