checkpoint载入后第一步训练前大量时间空置
Created by: ChuRao
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
    """Drive the epoch/step training loop on ``exe``.

    When ``self.checkpoint_cfg`` is set, epochs before
    ``checkpoint_cfg.epoch_id`` are not run, and, if ``load_serial`` is set,
    steps up to and including ``checkpoint_cfg.step_id`` of that epoch are
    skipped.

    Args:
        event_handler: callable invoked with the begin/end epoch and
            begin/end step event objects.
        exe: object whose ``run(feed=..., fetch_list=...)`` executes one step.
        num_epochs: exclusive upper bound of the epoch ids to run.
        reader: zero-argument callable returning an iterable of per-step
            feed data; it is called once per epoch.
    """
    # Without a checkpoint config the threshold is 0, so every epoch id
    # in range(num_epochs) passes the filter.
    first_epoch = self.checkpoint_cfg.epoch_id if self.checkpoint_cfg else 0
    epochs = [
        epoch_id for epoch_id in range(num_epochs) if epoch_id >= first_epoch
    ]

    for epoch_id in epochs:
        event_handler(BeginEpochEvent(epoch_id))

        for step_id, data in enumerate(reader()):
            if self.__stop:
                if self.checkpoint_cfg:
                    self._clean_checkpoint()
                return

            # NOTE: a skipped step has still pulled its batch from the
            # reader, so skipping costs one reader iteration per step and
            # emits no events until the first non-skipped step.
            already_trained = (
                self.checkpoint_cfg
                and self.checkpoint_cfg.load_serial
                and self.checkpoint_cfg.step_id >= step_id
                and self.checkpoint_cfg.epoch_id == epoch_id
            )
            if already_trained:
                continue

            begin_event = BeginStepEvent(epoch_id, step_id)
            event_handler(begin_event)

            # The handler decides, via the begin event, whether this step
            # fetches the train function outputs or nothing.
            fetch_list = (
                [var.name for var in self.train_func_outputs]
                if begin_event.fetch_metrics
                else []
            )
            metrics = exe.run(feed=data, fetch_list=fetch_list)

            if self.checkpoint_cfg:
                self._save_checkpoint(epoch_id, step_id)
            event_handler(EndStepEvent(epoch_id, step_id, metrics))

        event_handler(EndEpochEvent(epoch_id))

    if self.checkpoint_cfg:
        self._clean_checkpoint()
    if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \
            and self.checkpoint_cfg.step_id >= step_id \
            and self.checkpoint_cfg.epoch_id == epoch_id:
        continue
这一段非常影响性能。在读取样本较慢、训练的 batch 总数较大的时候(比如读取一个 batch 数据需要 1 秒,总共有 2 万个 batch),如果存档点的 step_id 在 5000,那么加载存档点之后,岂不是要空读 5000 秒的数据才真正开始训练?
在有 shuffle 的时候,一个 epoch 里的一些数据被重复用于训练,问题应该不是很大吧?
我还以为是怎么回事,一下午 event_handler 一点输出也没有。建议提供一个参数,用来丢弃存档点的 step_id。