提交 db45f4cb 编写于 作者: K kinghuin

fix data_feeder bug

上级 635b2ca3
......@@ -836,9 +836,16 @@ class BaseTask(object):
global_run_states = []
period_run_states = []
parallel_batch = []
for run_step, batch in enumerate(self.reader(), start=1):
if self.config.use_data_parallel and len(batch) < self.device_count:
continue
if self.config.use_data_parallel:
parallel_batch += batch
if len(parallel_batch) < self.device_count:
continue
else:
batch = parallel_batch
parallel_batch = []
step_run_state = RunState(len(self.fetch_list))
step_run_state.run_step = 1
num_batch_examples = len(batch)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册