diff --git a/tools/program.py b/tools/program.py index 30e9d7379bd07a0570c613ea25d1b440ed8bc682..3c71065a167fa18fc9d00535dace97737904b74d 100755 --- a/tools/program.py +++ b/tools/program.py @@ -114,7 +114,10 @@ def merge_config(config): global_config[key] = value else: sub_keys = key.split('.') - assert (sub_keys[0] in global_config), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(global_config.keys(), sub_keys[0]) + assert ( + sub_keys[0] in global_config + ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format( + global_config.keys(), sub_keys[0]) cur = global_config[sub_keys[0]] for idx, sub_key in enumerate(sub_keys[1:]): assert (sub_key in cur) @@ -177,7 +180,6 @@ def build(config, main_prog, startup_prog, mode): optimizer.minimize(opt_loss) opt_loss_name = opt_loss.name global_lr = optimizer._global_learning_rate() - global_lr.persistable = True fetch_name_list.insert(0, "lr") fetch_varname_list.insert(0, global_lr.name) return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name)