From 4adddcc89ad2b2e57faabe5734bae47e2f44111a Mon Sep 17 00:00:00 2001
From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com>
Date: Mon, 30 Nov 2020 11:17:12 +0800
Subject: [PATCH] add set_trainer_num api in dataset (#29133)

---
 python/paddle/fluid/dataset.py                | 24 ++++++++++++++++---
 .../fluid/tests/unittests/test_dataset.py     |  2 ++
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 750f532d265..86c63ababbb 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -351,6 +351,7 @@ class InMemoryDataset(DatasetBase):
         self.enable_pv_merge = False
         self.merge_by_lineid = False
         self.fleet_send_sleep_seconds = None
+        self.trainer_num = -1
 
     @deprecated(
         since="2.0.0",
@@ -480,6 +481,23 @@ class InMemoryDataset(DatasetBase):
         """
         self.parse_logkey = parse_logkey
 
+    def _set_trainer_num(self, trainer_num):
+        """
+        Set trainer num
+
+        Args:
+            trainer_num(int): trainer num
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+              dataset._set_trainer_num(1)
+
+        """
+        self.trainer_num = trainer_num
+
     @deprecated(
         since="2.0.0",
         update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid")
@@ -766,16 +784,16 @@ class InMemoryDataset(DatasetBase):
             thread_num(int): shuffle thread num. Default is 12.
 
         """
-        trainer_num = 1
         if fleet is not None:
             fleet._role_maker.barrier_worker()
-            trainer_num = fleet.worker_num()
+            if self.trainer_num == -1:
+                self.trainer_num = fleet.worker_num()
         if self.fleet_send_batch_size is None:
             self.fleet_send_batch_size = 1024
         if self.fleet_send_sleep_seconds is None:
             self.fleet_send_sleep_seconds = 0
         self.dataset.register_client2client_msg_handler()
-        self.dataset.set_trainer_num(trainer_num)
+        self.dataset.set_trainer_num(self.trainer_num)
         self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
         self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds)
         if fleet is not None:
diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py
index 7facc99a073..fcdac1d6241 100644
--- a/python/paddle/fluid/tests/unittests/test_dataset.py
+++ b/python/paddle/fluid/tests/unittests/test_dataset.py
@@ -65,8 +65,10 @@ class TestDataset(unittest.TestCase):
         dataset = fluid.InMemoryDataset()
         dataset.set_parse_ins_id(True)
         dataset.set_parse_content(True)
+        dataset._set_trainer_num(1)
         self.assertTrue(dataset.parse_ins_id)
         self.assertTrue(dataset.parse_content)
+        self.assertEqual(dataset.trainer_num, 1)
 
     def test_run_with_dump(self):
         """
-- 
GitLab