diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 750f532d265dccd3d4bf898ddbccd7d92174d0c6..86c63ababbbfdbc9b7d07c95e37dda8c67d18d2f 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -351,6 +351,7 @@ class InMemoryDataset(DatasetBase):
         self.enable_pv_merge = False
         self.merge_by_lineid = False
         self.fleet_send_sleep_seconds = None
+        self.trainer_num = -1
 
     @deprecated(
         since="2.0.0",
@@ -480,6 +481,23 @@ class InMemoryDataset(DatasetBase):
         """
         self.parse_logkey = parse_logkey
 
+    def _set_trainer_num(self, trainer_num):
+        """
+        Set trainer num
+
+        Args:
+            trainer_num(int): trainer num
+
+        Examples:
+            .. code-block:: python
+
+              import paddle.fluid as fluid
+              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+              dataset._set_trainer_num(1)
+
+        """
+        self.trainer_num = trainer_num
+
     @deprecated(
         since="2.0.0",
         update_to="paddle.distributed.InMemoryDataset._set_merge_by_sid")
@@ -766,16 +784,16 @@ class InMemoryDataset(DatasetBase):
             thread_num(int): shuffle thread num. Default is 12.
 
         """
-        trainer_num = 1
         if fleet is not None:
             fleet._role_maker.barrier_worker()
-            trainer_num = fleet.worker_num()
+            if self.trainer_num == -1:
+                self.trainer_num = fleet.worker_num()
         if self.fleet_send_batch_size is None:
             self.fleet_send_batch_size = 1024
         if self.fleet_send_sleep_seconds is None:
             self.fleet_send_sleep_seconds = 0
         self.dataset.register_client2client_msg_handler()
-        self.dataset.set_trainer_num(trainer_num)
+        self.dataset.set_trainer_num(self.trainer_num)
         self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
         self.dataset.set_fleet_send_sleep_seconds(self.fleet_send_sleep_seconds)
         if fleet is not None:
diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py
index 7facc99a0736e4a44077b02cfff1de9e668a10b5..fcdac1d62412e74ff1749a52cb365ef2b530e7d6 100644
--- a/python/paddle/fluid/tests/unittests/test_dataset.py
+++ b/python/paddle/fluid/tests/unittests/test_dataset.py
@@ -65,8 +65,10 @@ class TestDataset(unittest.TestCase):
         dataset = fluid.InMemoryDataset()
         dataset.set_parse_ins_id(True)
         dataset.set_parse_content(True)
+        dataset._set_trainer_num(1)
         self.assertTrue(dataset.parse_ins_id)
         self.assertTrue(dataset.parse_content)
+        self.assertEqual(dataset.trainer_num, 1)
 
     def test_run_with_dump(self):
         """