未验证 提交 233746d8 编写于 作者: J jiaqi 提交者: GitHub

Set a default value for fleet_send_batch_size according to the trainer num

(1) Set the default value of fleet_send_batch_size according to the trainer num (800 * trainer_num). The previous default was fixed at 80000; if the trainer num is much smaller or larger than 100, global shuffle may hit a timeout error.

(2) Fix the load-one-table bug: add worker barriers before and after the load, and have only the first worker call load_model_one_table.
上级 ea6ee76f
...@@ -235,7 +235,7 @@ class InMemoryDataset(DatasetBase): ...@@ -235,7 +235,7 @@ class InMemoryDataset(DatasetBase):
""" Init. """ """ Init. """
super(InMemoryDataset, self).__init__() super(InMemoryDataset, self).__init__()
self.proto_desc.name = "MultiSlotInMemoryDataFeed" self.proto_desc.name = "MultiSlotInMemoryDataFeed"
self.fleet_send_batch_size = 80000 self.fleet_send_batch_size = None
self.queue_num = None self.queue_num = None
self.merge_by_lineid = False self.merge_by_lineid = False
...@@ -413,6 +413,8 @@ class InMemoryDataset(DatasetBase): ...@@ -413,6 +413,8 @@ class InMemoryDataset(DatasetBase):
if fleet is not None: if fleet is not None:
fleet._role_maker._barrier_worker() fleet._role_maker._barrier_worker()
trainer_num = fleet.worker_num() trainer_num = fleet.worker_num()
if self.fleet_send_batch_size is None:
self.fleet_send_batch_size = 800 * trainer_num
self.dataset.register_client2client_msg_handler() self.dataset.register_client2client_msg_handler()
self.dataset.set_trainer_num(trainer_num) self.dataset.set_trainer_num(trainer_num)
self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size) self.dataset.set_fleet_send_batch_size(self.fleet_send_batch_size)
......
...@@ -320,11 +320,13 @@ class PSLib(Fleet): ...@@ -320,11 +320,13 @@ class PSLib(Fleet):
scope = kwargs.get("scope", None) scope = kwargs.get("scope", None)
model_proto_file = kwargs.get("model_proto_file", None) model_proto_file = kwargs.get("model_proto_file", None)
load_combine = kwargs.get("load_combine", False) load_combine = kwargs.get("load_combine", False)
self._role_maker._barrier_worker()
if scope is not None and model_proto_file is not None: if scope is not None and model_proto_file is not None:
self._load_one_table_from_paddle_model( self._load_one_table_from_paddle_model(
scope, table_id, model_path, model_proto_file, load_combine) scope, table_id, model_path, model_proto_file, load_combine)
else: elif self._role_maker.is_first_worker():
self._fleet_ptr.load_model_one_table(table_id, model_path, mode) self._fleet_ptr.load_model_one_table(table_id, model_path, mode)
self._role_maker._barrier_worker()
def _load_one_table_from_paddle_model(self, def _load_one_table_from_paddle_model(self,
scope, scope,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册