未验证 提交 ed4313d6 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #222 from tink2123/fix_cosine_decay

fix can not find tmp_4
...@@ -114,7 +114,10 @@ def merge_config(config): ...@@ -114,7 +114,10 @@ def merge_config(config):
global_config[key] = value global_config[key] = value
else: else:
sub_keys = key.split('.') sub_keys = key.split('.')
assert (sub_keys[0] in global_config), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(global_config.keys(), sub_keys[0]) assert (
sub_keys[0] in global_config
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
global_config.keys(), sub_keys[0])
cur = global_config[sub_keys[0]] cur = global_config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]): for idx, sub_key in enumerate(sub_keys[1:]):
assert (sub_key in cur) assert (sub_key in cur)
...@@ -177,7 +180,6 @@ def build(config, main_prog, startup_prog, mode): ...@@ -177,7 +180,6 @@ def build(config, main_prog, startup_prog, mode):
optimizer.minimize(opt_loss) optimizer.minimize(opt_loss)
opt_loss_name = opt_loss.name opt_loss_name = opt_loss.name
global_lr = optimizer._global_learning_rate() global_lr = optimizer._global_learning_rate()
global_lr.persistable = True
fetch_name_list.insert(0, "lr") fetch_name_list.insert(0, "lr")
fetch_varname_list.insert(0, global_lr.name) fetch_varname_list.insert(0, global_lr.name)
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name) return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册