From a742bb0ff430cfe513fd58177de0507e5d81c745 Mon Sep 17 00:00:00 2001
From: guru4elephant <35550832+guru4elephant@users.noreply.github.com>
Date: Tue, 16 Apr 2019 17:52:13 +0800
Subject: [PATCH] Merge pull request #16746 from xjqbest/dataset_merge_develop

move split filelist from trainer.py to fleet & fix error

---
 python/paddle/fluid/dataset.py                | 21 +++++++++++
 .../fluid/incubate/fleet/base/role_maker.py   |  2 +-
 .../fleet/parameter_server/__init__.py        | 35 +++++++++++++++++++
 3 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index e90c36da9a..d83ad37a89 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -213,6 +213,7 @@ class InMemoryDataset(DatasetBase):
             >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
             >>> filelist = ["a.txt", "b.txt"]
             >>> dataset.set_filelist(filelist)
+            >>> dataset.load_into_memory()
             >>> dataset.local_shuffle()
         """
         self.dataset.local_shuffle()
@@ -230,6 +231,7 @@ class InMemoryDataset(DatasetBase):
             >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
             >>> filelist = ["a.txt", "b.txt"]
             >>> dataset.set_filelist(filelist)
+            >>> dataset.load_into_memory()
             >>> dataset.global_shuffle(fleet)
 
         Args:
@@ -247,6 +249,25 @@ class InMemoryDataset(DatasetBase):
         if fleet is not None:
             fleet.fleet_instance.role_maker_._barrier_worker()
 
+    def release_memory(self):
+        """
+        Release InMemoryDataset memory data when the data will not be used again.
+
+        Example:
+            >>> import paddle.fluid as fluid
+            >>> import paddle.fluid.incubate.fleet.parameter_server as fleet
+            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+            >>> filelist = ["a.txt", "b.txt"]
+            >>> dataset.set_filelist(filelist)
+            >>> dataset.load_into_memory()
+            >>> dataset.global_shuffle(fleet)
+            >>> exe = fluid.Executor(fluid.CPUPlace())
+            >>> exe.run(fluid.default_startup_program())
+            >>> exe.train_from_dataset(fluid.default_main_program(), dataset)
+            >>> dataset.release_memory()
+        """
+        self.dataset.release_memory()
+
 
 class QueueDataset(DatasetBase):
     """
diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py
index 528f7b3269..506a38059c 100644
--- a/python/paddle/fluid/incubate/fleet/base/role_maker.py
+++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py
@@ -128,7 +128,7 @@ class MPIRoleMaker(RoleMakerBase):
         """
         finalize the current MPI instance.
         """
-        self.comm_.finalize()
+        pass
 
 
 class MPISymetricRoleMaker(MPIRoleMaker):
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
index 044aa33c2b..ee5dd5acf1 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
@@ -228,6 +228,40 @@ class Fleet(object):
         """
         self._fleet_ptr.save_model(save_path)
 
+    def split_filelist(self, filelist):
+        """
+        Split the filelist before distributed training. For example, if the
+        filelist is [a, b, c, d, e] and trainer_num is 2, then trainer 0
+        gets [a, b, c] and trainer 1 gets [d, e].
+
+        Example:
+            >>> all_filelist = ["a.txt", "b.txt", "c.txt"]
+            >>> my_filelist = fleet.split_filelist(all_filelist)
+            >>> dataset = fluid.DatasetFactory().create_dataset()
+            >>> dataset.set_filelist(my_filelist)
+
+        Args:
+            filelist(list): list of filenames; paths can be local or hdfs/afs.
+
+        Returns:
+            list of the filenames that belong to this trainer.
+        """
+        file_num = len(filelist)
+        trainer_id = self.get_worker_index()
+        trainer_num = self.get_worker_num()
+        if trainer_num > file_num:
+            raise ValueError("trainer_num should be <= file_num : "
+                             "%s > %s" % (trainer_num, file_num))
+        # compute this trainer's slice of the filelist; the interval is half-open [start, end)
+        start = 0
+        end = 0
+        for i in range(0, trainer_id + 1):
+            length = file_num // trainer_num + (i < (file_num % trainer_num))
+            start = end
+            end += length
+        my_filelist = filelist[start:end]
+        return my_filelist
+
     def _set_opt_info(self, opt_info):
         """
         this function saves the result from DistributedOptimizer.minimize()
@@ -324,3 +358,4 @@ save_pserver_model = fleet_instance.save_pserver_model
 worker_num = fleet_instance.get_worker_num
 server_num = fleet_instance.get_server_num
 worker_index = fleet_instance.get_worker_index
+split_filelist = fleet_instance.split_filelist
-- 
GitLab
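
Taken together, the pieces added by this patch are used in a fixed order: shard the filelist with split_filelist(), load the shard into memory, shuffle globally, train, then release the memory. The sketch below is assembled only from the docstring examples in the patch; the concrete file names are placeholders, and dataset configuration (feed variables, pipe command) as well as fleet/MPI initialization are omitted, so treat it as an illustration of the call order under those assumptions rather than a complete training script.

    import paddle.fluid as fluid
    import paddle.fluid.incubate.fleet.parameter_server as fleet

    # Each trainer keeps only its own shard of the global file list
    # (file names here are placeholders).
    all_filelist = ["a.txt", "b.txt", "c.txt", "d.txt", "e.txt"]
    my_filelist = fleet.split_filelist(all_filelist)

    # An InMemoryDataset must be loaded into memory before it can be shuffled,
    # which is what the added load_into_memory() calls in the docstrings show.
    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    dataset.set_filelist(my_filelist)
    dataset.load_into_memory()
    dataset.global_shuffle(fleet)

    # Train from the shuffled in-memory data.
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    exe.train_from_dataset(fluid.default_main_program(), dataset)

    # Free the in-memory copy once it is no longer needed.
    dataset.release_memory()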