未验证 提交 4adddcc8 编写于 作者: T Thunderbrook 提交者: GitHub

add set_trainer_num api in dataset (#29133)

上级 e0344081
......@@ -351,6 +351,7 @@ class InMemoryDataset(DatasetBase):
self.enable_pv_merge = False
self.merge_by_lineid = False
self.fleet_send_sleep_seconds = None
self.trainer_num = -1
@deprecated(
since="2.0.0",
......@@ -480,6 +481,23 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_logkey = parse_logkey
def _set_trainer_num(self, trainer_num):
    """
    Store the number of trainers on this dataset object.

    The given value is saved as the ``trainer_num`` attribute of the
    instance; nothing else is touched by this call.

    Args:
        trainer_num(int): the trainer count to record.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
          dataset._set_trainer_num(1)

    """
    # Plain attribute write; no validation or conversion is applied.
    setattr(self, "trainer_num", trainer_num)
@deprecated(
since="2.0.0",
update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid")
......@@ -766,16 +784,16 @@ class InMemoryDataset(DatasetBase):
thread_num(int): shuffle thread num. Default is 12.
"""
trainer_num = 1
if fleet is not None:
fleet._role_maker.barrier_worker()
trainer_num = fleet.worker_num()
if self.trainer_num == -1:
self.trainer_num = fleet.worker_num()
if self.fleet_send_batch_size is None:
self.fleet_send_batch_size = 1024
if self.fleet_send_sleep_seconds is None:
self.fleet_send_sleep_seconds = 0
self.dataset.register_client2client_msg_handler()
self.dataset.set_trainer_num(trainer_num)
self.dataset.set_trainer_num(self.trainer_num)
self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds)
if fleet is not None:
......
......@@ -65,8 +65,10 @@ class TestDataset(unittest.TestCase):
dataset = fluid.InMemoryDataset()
dataset.set_parse_ins_id(True)
dataset.set_parse_content(True)
dataset._set_trainer_num(1)
self.assertTrue(dataset.parse_ins_id)
self.assertTrue(dataset.parse_content)
self.assertEqual(dataset.trainer_num, 1)
def test_run_with_dump(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册