提交 53409a29 编写于 作者: T tangwei12

code optimized

上级 f28f41db
...@@ -141,11 +141,7 @@ class Trainer(object): ...@@ -141,11 +141,7 @@ class Trainer(object):
self.chief = True self.chief = True
self.checkpoint = checkpoint_config self.checkpoint = checkpoint_config
if self.checkpoint: if self.checkpoint:
if not isinstance(self.checkpoint, CheckpointConfig): assert isinstance(self.checkpoint, CheckpointConfig)
raise TypeError(
"The checkpoint_config shoule be an instance of CheckpointConfig"
)
else:
serial = io.get_latest_checkpoint_serial( serial = io.get_latest_checkpoint_serial(
self.checkpoint.checkpoint_dir) self.checkpoint.checkpoint_dir)
self.checkpoint.load_serial = serial if serial >= 0 else None self.checkpoint.load_serial = serial if serial >= 0 else None
...@@ -385,8 +381,8 @@ class Trainer(object): ...@@ -385,8 +381,8 @@ class Trainer(object):
else: else:
metrics = exe.run(feed=data, fetch_list=[]) metrics = exe.run(feed=data, fetch_list=[])
event_handler(EndStepEvent(epoch_id, step_id, metrics))
self._save_checkpoint(epoch_id, step_id) self._save_checkpoint(epoch_id, step_id)
event_handler(EndStepEvent(epoch_id, step_id, metrics))
event_handler(EndEpochEvent(epoch_id)) event_handler(EndEpochEvent(epoch_id))
self._clean_checkpoint() self._clean_checkpoint()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册