diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 668e5d25733e534fb255f24348ece79e36e14db2..902a33b614675eeac0d6bf643b3b519325fd150d 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -235,7 +235,7 @@ class InMemoryDataset(DatasetBase): """ Init. """ super(InMemoryDataset, self).__init__() self.proto_desc.name = "MultiSlotInMemoryDataFeed" - self.fleet_send_batch_size = 80000 + self.fleet_send_batch_size = None self.queue_num = None self.merge_by_lineid = False @@ -413,6 +413,8 @@ class InMemoryDataset(DatasetBase): if fleet is not None: fleet._role_maker._barrier_worker() trainer_num = fleet.worker_num() + if self.fleet_send_batch_size is None: + self.fleet_send_batch_size = 800 * trainer_num self.dataset.register_client2client_msg_handler() self.dataset.set_trainer_num(trainer_num) self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 92f1314816250781254d580e5265fb22981eaa41..ac56142245b6ab3b4d94546c0abce7bc9f6f0971 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -320,11 +320,13 @@ class PSLib(Fleet): scope = kwargs.get("scope", None) model_proto_file = kwargs.get("model_proto_file", None) load_combine = kwargs.get("load_combine", False) + self._role_maker._barrier_worker() if scope is not None and model_proto_file is not None: self._load_one_table_from_paddle_model( scope, table_id, model_path, model_proto_file, load_combine) - else: + elif self._role_maker.is_first_worker(): self._fleet_ptr.load_model_one_table(table_id, model_path, mode) + self._role_maker._barrier_worker() def _load_one_table_from_paddle_model(self, scope,