未验证 提交 4adddcc8 编写于 作者: T Thunderbrook 提交者: GitHub

add set_trainer_num api in dataset (#29133)

上级 e0344081
...@@ -351,6 +351,7 @@ class InMemoryDataset(DatasetBase): ...@@ -351,6 +351,7 @@ class InMemoryDataset(DatasetBase):
self.enable_pv_merge = False self.enable_pv_merge = False
self.merge_by_lineid = False self.merge_by_lineid = False
self.fleet_send_sleep_seconds = None self.fleet_send_sleep_seconds = None
self.trainer_num = -1
@deprecated( @deprecated(
since="2.0.0", since="2.0.0",
...@@ -480,6 +481,23 @@ class InMemoryDataset(DatasetBase): ...@@ -480,6 +481,23 @@ class InMemoryDataset(DatasetBase):
""" """
self.parse_logkey = parse_logkey self.parse_logkey = parse_logkey
def _set_trainer_num(self, trainer_num):
"""
Set trainer num
Args:
trainer_num(int): trainer num
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset._set_trainer_num(1)
"""
self.trainer_num = trainer_num
@deprecated( @deprecated(
since="2.0.0", since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid") update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid")
...@@ -766,16 +784,16 @@ class InMemoryDataset(DatasetBase): ...@@ -766,16 +784,16 @@ class InMemoryDataset(DatasetBase):
thread_num(int): shuffle thread num. Default is 12. thread_num(int): shuffle thread num. Default is 12.
""" """
trainer_num = 1
if fleet is not None: if fleet is not None:
fleet._role_maker.barrier_worker() fleet._role_maker.barrier_worker()
trainer_num = fleet.worker_num() if self.trainer_num == -1:
self.trainer_num = fleet.worker_num()
if self.fleet_send_batch_size is None: if self.fleet_send_batch_size is None:
self.fleet_send_batch_size = 1024 self.fleet_send_batch_size = 1024
if self.fleet_send_sleep_seconds is None: if self.fleet_send_sleep_seconds is None:
self.fleet_send_sleep_seconds = 0 self.fleet_send_sleep_seconds = 0
self.dataset.register_client2client_msg_handler() self.dataset.register_client2client_msg_handler()
self.dataset.set_trainer_num(trainer_num) self.dataset.set_trainer_num(self.trainer_num)
self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size) self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds) self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds)
if fleet is not None: if fleet is not None:
......
...@@ -65,8 +65,10 @@ class TestDataset(unittest.TestCase): ...@@ -65,8 +65,10 @@ class TestDataset(unittest.TestCase):
dataset = fluid.InMemoryDataset() dataset = fluid.InMemoryDataset()
dataset.set_parse_ins_id(True) dataset.set_parse_ins_id(True)
dataset.set_parse_content(True) dataset.set_parse_content(True)
dataset._set_trainer_num(1)
self.assertTrue(dataset.parse_ins_id) self.assertTrue(dataset.parse_ins_id)
self.assertTrue(dataset.parse_content) self.assertTrue(dataset.parse_content)
self.assertEqual(dataset.trainer_num, 1)
def test_run_with_dump(self): def test_run_with_dump(self):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册