提交 53409a29 编写于 作者: T tangwei12

code optimized

上级 f28f41db
......@@ -141,14 +141,10 @@ class Trainer(object):
self.chief = True
self.checkpoint = checkpoint_config
if self.checkpoint:
if not isinstance(self.checkpoint, CheckpointConfig):
raise TypeError(
"The checkpoint_config shoule be an instance of CheckpointConfig"
)
else:
serial = io.get_latest_checkpoint_serial(
self.checkpoint.checkpoint_dir)
self.checkpoint.load_serial = serial if serial >= 0 else None
assert isinstance(self.checkpoint, CheckpointConfig)
serial = io.get_latest_checkpoint_serial(
self.checkpoint.checkpoint_dir)
self.checkpoint.load_serial = serial if serial >= 0 else None
self.scope = core.Scope()
......@@ -385,8 +381,8 @@ class Trainer(object):
else:
metrics = exe.run(feed=data, fetch_list=[])
event_handler(EndStepEvent(epoch_id, step_id, metrics))
self._save_checkpoint(epoch_id, step_id)
event_handler(EndStepEvent(epoch_id, step_id, metrics))
event_handler(EndEpochEvent(epoch_id))
self._clean_checkpoint()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册