From 233746d89d38486911ba28255c5bad7cd32f32f4 Mon Sep 17 00:00:00 2001
From: jiaqi <173596896@qq.com>
Date: Wed, 31 Jul 2019 21:44:16 +0800
Subject: [PATCH] set fleet_send_batch_size to a default value based on
 trainer num
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

(1) set fleet_send_batch_size to a default value computed from trainer
    num. The previous default of 80000 was fixed; when trainer num is
    much smaller or larger than 100, global shuffle may hit a timeout
    error.
(2) fix a bug in load_one_table: add worker barriers and let only the
    first worker load the table.
---
 python/paddle/fluid/dataset.py                               | 4 +++-
 .../fluid/incubate/fleet/parameter_server/pslib/__init__.py | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 668e5d25733..902a33b6146 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -235,7 +235,7 @@ class InMemoryDataset(DatasetBase):
         """ Init. """
         super(InMemoryDataset, self).__init__()
         self.proto_desc.name = "MultiSlotInMemoryDataFeed"
-        self.fleet_send_batch_size = 80000
+        self.fleet_send_batch_size = None
         self.queue_num = None
         self.merge_by_lineid = False
 
@@ -413,6 +413,8 @@ class InMemoryDataset(DatasetBase):
         if fleet is not None:
             fleet._role_maker._barrier_worker()
             trainer_num = fleet.worker_num()
+        if self.fleet_send_batch_size is None:
+            self.fleet_send_batch_size = 800 * trainer_num
         self.dataset.register_client2client_msg_handler()
         self.dataset.set_trainer_num(trainer_num)
         self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py
index 92f13148162..ac56142245b 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py
@@ -320,11 +320,13 @@ class PSLib(Fleet):
         scope = kwargs.get("scope", None)
         model_proto_file = kwargs.get("model_proto_file", None)
         load_combine = kwargs.get("load_combine", False)
+        self._role_maker._barrier_worker()
         if scope is not None and model_proto_file is not None:
             self._load_one_table_from_paddle_model(
                 scope, table_id, model_path, model_proto_file, load_combine)
-        else:
+        elif self._role_maker.is_first_worker():
             self._fleet_ptr.load_model_one_table(table_id, model_path, mode)
+        self._role_maker._barrier_worker()
 
     def _load_one_table_from_paddle_model(self,
                                           scope,
--
GitLab
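
Note: the sketch below is illustrative, not PaddlePaddle code. It restates
the default-batch-size rule from the first hunk; default_send_batch_size
and user_value are hypothetical names, while the 800-per-trainer constant
comes from the diff.

    # A fixed fleet_send_batch_size of 80000 sends the same payload per
    # client2client round no matter how many trainers participate, so a
    # fleet far from ~100 trainers can hit the RPC timeout during global
    # shuffle. Scaling the default with trainer num keeps the per-trainer
    # volume roughly constant.
    def default_send_batch_size(trainer_num, user_value=None):
        if user_value is not None:
            return user_value          # an explicit user setting always wins
        return 800 * trainer_num       # scale the default with the fleet size

    assert default_send_batch_size(100) == 80000   # reproduces the old fixed default
    assert default_send_batch_size(2) == 1600      # small fleet, small batches
    assert default_send_batch_size(2, 500) == 500  # user override respected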
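
Note: the sketch below is illustrative, not PSLib itself. It isolates the
barrier pattern that the second hunk applies to load_one_table; role_maker
and fleet_ptr stand in for the objects the real method reaches through
self.

    # Without barriers, some workers may proceed before the table is
    # loaded, or the loading worker may start while others still touch
    # the table. Restricting the load to the first worker and bracketing
    # it with barriers removes both races.
    def load_one_table(role_maker, fleet_ptr, table_id, model_path, mode):
        role_maker._barrier_worker()       # everyone waits at the entry
        if role_maker.is_first_worker():   # exactly one worker performs the load
            fleet_ptr.load_model_one_table(table_id, model_path, mode)
        role_maker._barrier_worker()       # nobody leaves before the load ends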