From d5ee580c5ca97ba21020c6ff83ba6ecb2f0aa79c Mon Sep 17 00:00:00 2001
From: xjqbest <173596896@qq.com>
Date: Tue, 9 Apr 2019 20:46:04 +0800
Subject: [PATCH] move split filelist from trainer.py to fleet & fix error

test=develop
---
 .../fluid/incubate/fleet/base/role_maker.py   |  2 +-
 .../fleet/parameter_server/__init__.py        | 30 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py
index 528f7b3269e..506a38059c3 100644
--- a/python/paddle/fluid/incubate/fleet/base/role_maker.py
+++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py
@@ -128,7 +128,7 @@ class MPIRoleMaker(RoleMakerBase):
         """
         finalize the current MPI instance.
         """
-        self.comm_.finalize()
+        pass
 
 
 class MPISymetricRoleMaker(MPIRoleMaker):
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
index 9b1ec412c73..1c49ea1f55e 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
@@ -241,6 +241,35 @@ class Fleet(object):
         """
         self._fleet_ptr.save_model(save_path)
 
+    def split_filelist(self, filelist):
+        """
+        Split the filelist before distributed training.
+        For example, if filelist is [a, b, c, d, e] and trainer_num = 2,
+        then trainer 0 gets [a, b, c] and trainer 1 gets [d, e].
+
+        Args:
+            filelist(list): list of filenames, can be local or hdfs/afs.
+
+        Returns: the list of filenames that belong to this trainer.
+        """
+        file_num = len(filelist)
+        trainer_id = self.get_worker_index()
+        trainer_num = self.get_worker_num()
+        if trainer_num > file_num:
+            raise ValueError(
+                "trainer_num should be <= file_num : "
+                "%s > %s" % (trainer_num, file_num)
+            )
+        # this trainer's half-open interval [start, end) of the filelist
+        start = 0
+        end = 0
+        for i in range(0, trainer_id + 1):
+            length = file_num // trainer_num + (i < (file_num % trainer_num))
+            start = end
+            end += length
+        myfilelist = filelist[start:end]
+        return myfilelist
+
     def _set_opt_info(self, opt_info):
         """
         this function saves the result from DistributedOptimizer.minimize()
@@ -337,3 +366,4 @@ save_pserver_model = fleet_instance.save_pserver_model
 worker_num = fleet_instance.get_worker_num
 server_num = fleet_instance.get_server_num
 worker_index = fleet_instance.get_worker_index
+split_filelist = fleet_instance.split_filelist
--
GitLab
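
A minimal standalone sketch of the half-open interval arithmetic behind split_filelist, for illustration only and not part of the patch: the helper split_for is hypothetical and simply mirrors the patched method, so the [a, b, c] / [d, e] docstring example can be checked without a running fleet. In real training code each trainer would call fleet.split_filelist(filelist) instead.

# split_for mirrors Fleet.split_filelist, but takes trainer_id/trainer_num
# explicitly instead of reading them from the fleet runtime.
def split_for(filelist, trainer_id, trainer_num):
    file_num = len(filelist)
    if trainer_num > file_num:
        raise ValueError("trainer_num should be <= file_num : "
                         "%s > %s" % (trainer_num, file_num))
    start = end = 0
    for i in range(trainer_id + 1):
        # the first (file_num % trainer_num) trainers each get one extra file
        length = file_num // trainer_num + (i < file_num % trainer_num)
        start = end
        end += length
    return filelist[start:end]

files = ["a", "b", "c", "d", "e"]
print(split_for(files, 0, 2))  # ['a', 'b', 'c']
print(split_for(files, 1, 2))  # ['d', 'e']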