提交 d5ee580c 编写于 作者: X xjqbest

move split filelist from trainer.py to fleet & fix error

test=develop
上级 126d2a2f
......@@ -128,7 +128,7 @@ class MPIRoleMaker(RoleMakerBase):
"""
finalize the current MPI instance.
"""
self.comm_.finalize()
pass
class MPISymetricRoleMaker(MPIRoleMaker):
......
......@@ -241,6 +241,35 @@ class Fleet(object):
"""
self._fleet_ptr.save_model(save_path)
def split_filelist(self, filelist):
"""
split filelist before distributed training,
for example, filelist is [a, b, c ,d, e] and trainer_num = 2,
then trainer 0 gets [a, b, c] and trainer 1 gets [d, e]
Args:
filelist(list): list of filename, can be local or hdfs/afs.
Returns: list of filename which belongs to this trainer.
"""
file_num = len(filelist)
trainer_id = self.get_worker_index()
trainer_num = self.get_worker_num()
if trainer_num > file_num:
raise ValueError(
"trainer_num should be <= file_num : "
"%s > %s" % (trainer_num, file_num)
)
# get interval of filelist, it's [ )
start = 0
end = 0
for i in range(0, trainer_id + 1):
length = file_num / trainer_num + (i < (file_num % trainer_num))
start = end
end += length
myfilelist = filelist[start : end]
return myfilelist
def _set_opt_info(self, opt_info):
"""
this function saves the result from DistributedOptimizer.minimize()
......@@ -337,3 +366,4 @@ save_pserver_model = fleet_instance.save_pserver_model
worker_num = fleet_instance.get_worker_num
server_num = fleet_instance.get_server_num
worker_index = fleet_instance.get_worker_index
split_filelist = fleet_instance.split_filelist
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册