提交 514d727a 编写于 作者: X xjqbest

fix dataset bug

test=develop
上级 271b7147
...@@ -240,12 +240,17 @@ class InMemoryDataset(DatasetBase): ...@@ -240,12 +240,17 @@ class InMemoryDataset(DatasetBase):
Args: Args:
fleet: fleet singleton. Default None. fleet: fleet singleton. Default None.
""" """
trainer_id = 0
trainer_num = 1 trainer_num = 1
fleet_send_batch_size = 80000
if fleet is not None: if fleet is not None:
fleet.fleet_instance.role_maker_._barrier_worker() fleet.fleet_instance.role_maker_._barrier_worker()
trainer_id = fleet.worker_index()
trainer_num = fleet.worker_num() trainer_num = fleet.worker_num()
self.dataset.register_client2client_msg_handler() self.dataset.register_client2client_msg_handler()
self.dataset.set_trainer_id(trainer_id)
self.dataset.set_trainer_num(trainer_num) self.dataset.set_trainer_num(trainer_num)
self.dataset.set_fleet_send_batch_size(fleet_send_batch_size)
if fleet is not None: if fleet is not None:
fleet.fleet_instance.role_maker_._barrier_worker() fleet.fleet_instance.role_maker_._barrier_worker()
self.dataset.global_shuffle() self.dataset.global_shuffle()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册