From 02c370c3dc5424941efa6e231a122e8ee80593d6 Mon Sep 17 00:00:00 2001 From: jiaqi <173596896@qq.com> Date: Fri, 2 Aug 2019 10:39:39 +0800 Subject: [PATCH] support filelist size < trainer num && fix pull dense (#18956) * support filelist size < trainer num * pull dense when stop, to make sure local dense params are same as pserver, so save paddle model will save dense model same as pserver * enable QueueDataset train same filelist for several times --- paddle/fluid/framework/data_set.cc | 2 -- paddle/fluid/framework/pull_dense_worker.cc | 3 +++ python/paddle/fluid/dataset.py | 16 ++++++++++++++++ .../fluid/incubate/fleet/base/fleet_base.py | 9 ++++----- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index f0c8ccc243c..11449608542 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -352,8 +352,6 @@ void DatasetImpl::CreateReaders() { VLOG(3) << "Filelist size in Dataset: " << filelist_.size(); VLOG(3) << "channel num in Dataset: " << channel_num_; CHECK(thread_num_ > 0) << "thread num should > 0"; - CHECK(thread_num_ <= filelist_.size()) - << "thread num should <= filelist size"; CHECK(channel_num_ > 0) << "channel num should > 0"; CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num"; VLOG(3) << "readers size: " << readers_.size(); diff --git a/paddle/fluid/framework/pull_dense_worker.cc b/paddle/fluid/framework/pull_dense_worker.cc index 20d7f98e936..3fe0d516e2d 100644 --- a/paddle/fluid/framework/pull_dense_worker.cc +++ b/paddle/fluid/framework/pull_dense_worker.cc @@ -80,6 +80,9 @@ void PullDenseWorker::Stop() { if (running_) { running_ = false; t_.join(); + // pull dense when stop, to make sure local dense params are same as + // pserver, so save paddle model will save dense model same as pserver + PullDense(true); } } diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 
902a33b6146..20ffd13d605 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -246,6 +246,8 @@ class InMemoryDataset(DatasetBase): """ if self.thread_num > len(self.filelist): self.thread_num = len(self.filelist) + if self.thread_num == 0: + self.thread_num = 1 self.dataset.set_thread_num(self.thread_num) if self.queue_num is None: self.queue_num = self.thread_num @@ -545,6 +547,20 @@ class QueueDataset(DatasetBase): super(QueueDataset, self).__init__() self.proto_desc.name = "MultiSlotDataFeed" + def _prepare_to_run(self): + """ + Set data_feed_desc/thread num/filelist before run, + users do not need to call this function. + """ + if self.thread_num > len(self.filelist): + self.thread_num = len(self.filelist) + if self.thread_num == 0: + self.thread_num = 1 + self.dataset.set_thread_num(self.thread_num) + self.dataset.set_filelist(self.filelist) + self.dataset.set_data_feed_desc(self.desc()) + self.dataset.create_readers() + def local_shuffle(self): """ Local shuffle data. diff --git a/python/paddle/fluid/incubate/fleet/base/fleet_base.py b/python/paddle/fluid/incubate/fleet/base/fleet_base.py index a52970fad12..ac9b0f23276 100644 --- a/python/paddle/fluid/incubate/fleet/base/fleet_base.py +++ b/python/paddle/fluid/incubate/fleet/base/fleet_base.py @@ -148,8 +148,10 @@ class Fleet(object): def split_files(self, files): """ split files before distributed training, - for example, files is [a, b, c ,d, e] and trainer_num = 2, - then trainer 0 gets [a, b, c] and trainer 1 gets [d, e] + example 1: files is [a, b, c, d, e] and trainer_num = 2, then trainer + 0 gets [a, b, c] and trainer 1 gets [d, e]. + example 2: files is [a, b], and trainer_num = 3, then trainer 0 gets + [a], trainer 1 gets [b], trainer 2 gets []. Args: files(list): file list need to be read. 
@@ -160,9 +162,6 @@ class Fleet(object): trainer_id = self.worker_index() trainers = self.worker_num() - if len(files) < trainers: - raise ValueError("file number must gather or equal trainer number") - remainder = len(files) % trainers blocksize = len(files) / trainers -- GitLab