提交 94eaf94c 编写于 作者: T tangwei12

bug fix about lru and save

上级 b44ede80
......@@ -495,11 +495,11 @@ def save_checkpoint(executor,
serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
save_trainer_args(cur_dir, trainer_id, trainer_args)
if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program)
save_trainer_args(cur_dir, trainer_id, trainer_args)
_lru_delete(checkpoint_dir, max_num_checkpoints)
_lru_delete(checkpoint_dir, max_num_checkpoints)
def need_load_checkpoint(checkpoint_dir):
......@@ -639,7 +639,13 @@ def _is_checkpoint_var(var):
var.desc.type() == core.VarDesc.VarType.RAW:
return False
if var.name.endswith("@GRAD"):
if "@GRAD" in var.name:
return False
if ".trainer_" in var.name:
return False
if ".block" in var.name:
return False
return var.persistable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册