Commit 80b8ca3f authored by HydrogenSulfate

fix optimizer/__init__.py

Parent 3f117428
@@ -51,7 +51,7 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
         optim_name = optim_config.pop("name")
         optim_config: List[Dict[str, Dict]] = [{
             optim_name: {
-                'scope': config["Arch"].get("name"),
+                'scope': "all",
                 **
                 optim_config
             }
@@ -59,10 +59,10 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     optim_list = []
     lr_list = []
     for optim_item in optim_config:
-        # optim_cfg = {optim_name1: {scope: xxx, **optim_cfg}}
+        # optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
         # step1 build lr
-        optim_name = list(optim_item.keys())[0]  # get optim_name1
-        optim_scope = optim_item[optim_name].pop('scope')  # get scope
+        optim_name = list(optim_item.keys())[0]  # get optim_name
+        optim_scope = optim_item[optim_name].pop('scope')  # get optim_scope
         optim_cfg = optim_item[optim_name]  # get optim_cfg
         lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
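
Note: after the normalization in the first hunk, every element of `optim_config` has the shape the comment above describes. A minimal sketch of that shape (the optimizer name and hyperparameters below are made-up examples, not taken from this commit):

```python
# One entry per optimizer; 'scope' now defaults to "all" instead of the Arch name.
optim_config = [{
    "Momentum": {                                         # optim_name
        "scope": "all",                                   # optim_scope: which model(s) to update
        "lr": {"name": "Cosine", "learning_rate": 0.04},  # consumed by build_lr_scheduler
        "momentum": 0.9,                                  # remaining keys form optim_cfg
    }
}]
```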
@@ -78,7 +78,8 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
             reg_name = reg_config.pop('name') + 'Decay'
             reg = getattr(paddle.regularizer, reg_name)(**reg_config)
             optim_cfg["weight_decay"] = reg
-            logger.debug("build regularizer ({}) success..".format(reg))
+            logger.debug("build regularizer ({}) for scope ({}) success..".
+                         format(reg, optim_scope))
         # step3 build optimizer
         if 'clip_norm' in optim_cfg:
             clip_norm = optim_cfg.pop('clip_norm')
@@ -87,11 +88,16 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
             grad_clip = None
         optim_model = []
         for i in range(len(model_list)):
-            class_name = model_list[i].__class__.__name__
-            if class_name == optim_scope:
+            if len(model_list[i].parameters()) == 0:
+                continue
+            if optim_scope == "all":
                 optim_model.append(model_list[i])
-        assert len(optim_model) == 1 and len(optim_model[0].parameters()) > 0, \
-            f"Invalid optim model for optim scope({optim_scope}), number of optim_model={len(optim_model)}, and number of optim_model's params={len(optim_model[0].parameters())}"
+            else:
+                for m in model_list[i].sublayers(True):
+                    if m.__class__.__name__ == optim_scope:
+                        optim_model.append(model_list[i])
+        assert len(optim_model) == 1, \
+            "Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
         optim = getattr(optimizer, optim_name)(
             learning_rate=lr, grad_clip=grad_clip,
             **optim_cfg)(model_list=optim_model)
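
Note: the rewritten scope matching can be exercised on its own. A minimal sketch, assuming Paddle is installed; the layer classes and scope name below are hypothetical, and the selection loop mirrors the logic added in the last hunk:

```python
import paddle.nn as nn

class Backbone(nn.Layer):                       # hypothetical sublayer class
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

class RecModel(nn.Layer):                       # hypothetical top-level model
    def __init__(self):
        super().__init__()
        self.backbone = Backbone()

model_list = [RecModel(), nn.CrossEntropyLoss()]   # the loss layer has no parameters
optim_scope = "Backbone"                           # hypothetical scope from the config

optim_model = []
for i in range(len(model_list)):
    if len(model_list[i].parameters()) == 0:       # skip parameter-free entries
        continue
    if optim_scope == "all":                       # "all": optimize every model
        optim_model.append(model_list[i])
    else:                                          # otherwise match by sublayer class name;
        for m in model_list[i].sublayers(True):    # sublayers(True) includes the layer itself
            if m.__class__.__name__ == optim_scope:
                optim_model.append(model_list[i])

assert len(optim_model) == 1    # exactly one model may match a scope
```

Matching against `sublayers(True)` rather than only the top-level class name lets a scope such as a backbone's class select the model that contains it, while the new `len(parameters()) == 0` guard replaces the old post-hoc parameter-count assertion.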