提交 b24e2f92 编写于 作者: W wangxiao

fix save ckpt % infermodel

上级 7175222f
...@@ -593,7 +593,7 @@ class Controller(object): ...@@ -593,7 +593,7 @@ class Controller(object):
cur_task.cur_train_step += 1 cur_task.cur_train_step += 1
cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch
if cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0: if cur_task.is_target and cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0:
cur_task.save(suffix='.step'+str(cur_task_global_step)) cur_task.save(suffix='.step'+str(cur_task_global_step))
if global_step % main_conf.get('print_every_n_steps', 5) == 0: if global_step % main_conf.get('print_every_n_steps', 5) == 0:
...@@ -612,7 +612,7 @@ class Controller(object): ...@@ -612,7 +612,7 @@ class Controller(object):
print(cur_task.name+': train finished!') print(cur_task.name+': train finished!')
cur_task.save() cur_task.save()
if 'save_every_n_steps' in main_conf and global_step % main_conf['save_every_n_steps'] == 0: if 'save_ckpt_every_n_steps' in main_conf and global_step % main_conf['save_ckpt_every_n_steps'] == 0:
save_path = os.path.join(main_conf['save_path'], 'ckpt', save_path = os.path.join(main_conf['save_path'], 'ckpt',
"step_" + str(global_step)) "step_" + str(global_step))
fluid.io.save_persistables(self.exe, save_path, saver_program) fluid.io.save_persistables(self.exe, save_path, saver_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册