未验证 提交 2644ef3e 编写于 作者: J jiaqi 提交者: GitHub

Merge pull request #16923 from xjqbest/my_cherry_pick_16746

Merge pull request #16746 from xjqbest/dataset_merge_develop
...@@ -213,6 +213,7 @@ class InMemoryDataset(DatasetBase): ...@@ -213,6 +213,7 @@ class InMemoryDataset(DatasetBase):
>>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset") >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
>>> filelist = ["a.txt", "b.txt"] >>> filelist = ["a.txt", "b.txt"]
>>> dataset.set_filelist(filelist) >>> dataset.set_filelist(filelist)
>>> dataset.load_into_memory()
>>> dataset.local_shuffle() >>> dataset.local_shuffle()
""" """
self.dataset.local_shuffle() self.dataset.local_shuffle()
...@@ -230,6 +231,7 @@ class InMemoryDataset(DatasetBase): ...@@ -230,6 +231,7 @@ class InMemoryDataset(DatasetBase):
>>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset") >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
>>> filelist = ["a.txt", "b.txt"] >>> filelist = ["a.txt", "b.txt"]
>>> dataset.set_filelist(filelist) >>> dataset.set_filelist(filelist)
>>> dataset.load_into_memory()
>>> dataset.global_shuffle(fleet) >>> dataset.global_shuffle(fleet)
Args: Args:
...@@ -249,6 +251,25 @@ class InMemoryDataset(DatasetBase): ...@@ -249,6 +251,25 @@ class InMemoryDataset(DatasetBase):
if fleet is not None: if fleet is not None:
fleet.fleet_instance.role_maker_._barrier_worker() fleet.fleet_instance.role_maker_._barrier_worker()
def release_memory(self):
"""
Release InMemoryDataset memory data, when data will not be used again.
Example:
>>> import paddle.fluid as fluid
>>> import paddle.fluid.incubate.fleet.parameter_server as fleet
>>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
>>> filelist = ["a.txt", "b.txt"]
>>> dataset.set_filelist(filelist)
>>> dataset.load_into_memory()
>>> dataset.global_shuffle(fleet)
>>> exe = fluid.Executor(fluid.CPUPlace())
>>> exe.run(fluid.default_startup_program())
>>> exe.train_from_dataset(fluid.default_main_program(), dataset)
>>> dataset.release_memory()
"""
self.dataset.release_memory()
class QueueDataset(DatasetBase): class QueueDataset(DatasetBase):
""" """
......
...@@ -128,7 +128,7 @@ class MPIRoleMaker(RoleMakerBase): ...@@ -128,7 +128,7 @@ class MPIRoleMaker(RoleMakerBase):
""" """
finalize the current MPI instance. finalize the current MPI instance.
""" """
self.comm_.finalize() pass
class MPISymetricRoleMaker(MPIRoleMaker): class MPISymetricRoleMaker(MPIRoleMaker):
......
...@@ -241,6 +241,40 @@ class Fleet(object): ...@@ -241,6 +241,40 @@ class Fleet(object):
""" """
self._fleet_ptr.save_model(save_path) self._fleet_ptr.save_model(save_path)
def split_filelist(self, filelist):
"""
split filelist before distributed training,
for example, filelist is [a, b, c ,d, e] and trainer_num = 2,
then trainer 0 gets [a, b, c] and trainer 1 gets [d, e]
Example:
>>> all_filelist = ["a.txt", "b.txt", "c.txt"]
>>> my_filelist = fleet.split_filelist(all_filelist)
>>> dataset = fluid.DatasetFactory().create_dataset()
>>> dataset.set_filelist(my_filelist)
Args:
filelist(list): list of filename, can be local or hdfs/afs.
Returns:
list of filename which belongs to this trainer.
"""
file_num = len(filelist)
trainer_id = self.get_worker_index()
trainer_num = self.get_worker_num()
if trainer_num > file_num:
raise ValueError("trainer_num should be <= file_num : "
"%s > %s" % (trainer_num, file_num))
# get interval of filelist, it's [ )
start = 0
end = 0
for i in range(0, trainer_id + 1):
length = file_num / trainer_num + (i < (file_num % trainer_num))
start = end
end += length
my_filelist = filelist[start:end]
return my_filelist
def _set_opt_info(self, opt_info): def _set_opt_info(self, opt_info):
""" """
this function saves the result from DistributedOptimizer.minimize() this function saves the result from DistributedOptimizer.minimize()
...@@ -337,3 +371,4 @@ save_pserver_model = fleet_instance.save_pserver_model ...@@ -337,3 +371,4 @@ save_pserver_model = fleet_instance.save_pserver_model
worker_num = fleet_instance.get_worker_num worker_num = fleet_instance.get_worker_num
server_num = fleet_instance.get_server_num server_num = fleet_instance.get_server_num
worker_index = fleet_instance.get_worker_index worker_index = fleet_instance.get_worker_index
split_filelist = fleet_instance.split_filelist
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册