Commit 17fd1bc2 authored by H HydrogenSulfate

refine code

Parent 15242df1
@@ -39,7 +39,7 @@ def update_loss(trainer, loss_dict, batch_size):
def log_info(trainer, batch_size, epoch_id, iter_id):
lr_msg = ", ".join([
"lr_{}: {:.8f}".format(i + 1, lr.get_lr())
"lr({}): {:.8f}".format(lr.__class__.__name__, lr.get_lr())
for i, lr in enumerate(trainer.lr_sch)
])
metric_msg = ", ".join([
@@ -44,10 +44,9 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
def build_optimizer(config, epochs, step_each_epoch, model_list=None):
config = copy.deepcopy(config)
optim_config = config["Optimizer"]
optim_config = copy.deepcopy(config)
if isinstance(optim_config, dict):
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
# convert {'name': xxx, **optim_cfg} to [{'name': {'scope': xxx, **optim_cfg}}]
optim_name = optim_config.pop("name")
optim_config: List[Dict[str, Dict]] = [{
optim_name: {
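
Note: after this hunk, build_optimizer deep-copies the config it is given (previously it pulled config["Optimizer"] out itself) and normalizes a single-optimizer dict into the list-of-dicts form consumed by the per-scope loop. A rough sketch of that normalization with made-up config values; the default scope is not visible in the truncated hunk, so "all" below is an assumption:

import copy
from typing import Dict, List

config = {
    "name": "Momentum",
    "momentum": 0.9,
    "lr": {"name": "Cosine", "learning_rate": 0.04},
    "regularizer": {"name": "L2", "coeff": 0.0001},
}

optim_config = copy.deepcopy(config)
if isinstance(optim_config, dict):
    # convert {'name': xxx, **optim_cfg} to [{'name': {'scope': xxx, **optim_cfg}}]
    optim_name = optim_config.pop("name")
    optim_config: List[Dict[str, Dict]] = [{
        optim_name: {
            "scope": "all",  # assumed default; not shown in the hunk above
            **optim_config,
        }
    }]

print(optim_config)
# [{'Momentum': {'scope': 'all', 'momentum': 0.9, 'lr': {...}, 'regularizer': {...}}}]
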
@@ -61,19 +60,19 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
"""NOTE:
Currently only support optim objets below.
1. single optimizer config.
2. next level uner Arch, such as Arch.backbone, Arch.neck, Arch.head.
3. loss which has parameters, such as CenterLoss.
2. model(entire Arch), backbone, neck, head.
3. loss(entire Loss), specific loss listed in ppcls/loss/__init__.py.
"""
for optim_item in optim_config:
# optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
# optim_cfg = {optim_name: {'scope': xxx, **optim_cfg}}
# step1 build lr
optim_name = list(optim_item.keys())[0] # get optim_name
optim_scope = optim_item[optim_name].pop('scope') # get optim_scope
optim_cfg = optim_item[optim_name] # get optim_cfg
lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
logger.debug("build lr ({}) for scope ({}) success..".format(
lr, optim_scope))
logger.info("build lr ({}) for scope ({}) success..".format(
lr.__class__.__name__, optim_scope))
# step2 build regularization
if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
if 'weight_decay' in optim_cfg:
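
Note: "step1" above unpacks one entry of the normalized list: the optimizer class name is the key, scope is popped out, and the remaining keys (minus lr) become the optimizer kwargs. A small sketch with illustrative values; build_lr_scheduler is the helper named in the hunk header:

optim_item = {
    "Momentum": {
        "scope": "backbone",
        "momentum": 0.9,
        "lr": {"name": "Cosine", "learning_rate": 0.04},
    }
}

optim_name = list(optim_item.keys())[0]            # "Momentum"
optim_scope = optim_item[optim_name].pop("scope")  # "backbone"
optim_cfg = optim_item[optim_name]                 # remaining optimizer kwargs
lr_cfg = optim_cfg.pop("lr")                       # would be passed to build_lr_scheduler(...)
print(optim_name, optim_scope, optim_cfg, lr_cfg)
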
@@ -84,8 +83,8 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
reg_name = reg_config.pop('name') + 'Decay'
reg = getattr(paddle.regularizer, reg_name)(**reg_config)
optim_cfg["weight_decay"] = reg
logger.debug("build regularizer ({}) for scope ({}) success..".
format(reg, optim_scope))
logger.info("build regularizer ({}) for scope ({}) success..".
format(reg.__class__.__name__, optim_scope))
# step3 build optimizer
if 'clip_norm' in optim_cfg:
clip_norm = optim_cfg.pop('clip_norm')
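
Note: steps 2 and 3 above map the remaining config keys onto Paddle objects: the regularizer name gets a "Decay" suffix and is looked up on paddle.regularizer, and clip_norm (if present) becomes a ClipGradByNorm instance. A standalone sketch with illustrative config values:

import paddle

reg_config = {"name": "L2", "coeff": 0.0001}
reg_name = reg_config.pop("name") + "Decay"                 # -> "L2Decay"
reg = getattr(paddle.regularizer, reg_name)(**reg_config)   # paddle.regularizer.L2Decay(coeff=0.0001)

optim_cfg = {"momentum": 0.9, "clip_norm": 10.0}
optim_cfg["weight_decay"] = reg

if "clip_norm" in optim_cfg:
    clip_norm = optim_cfg.pop("clip_norm")
    grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
else:
    grad_clip = None

print(type(reg).__name__, type(grad_clip).__name__)  # L2Decay ClipGradByNorm
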
@@ -93,30 +92,42 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
else:
grad_clip = None
optim_model = []
for i in range(len(model_list)):
if len(model_list[i].parameters()) == 0:
continue
# for static graph
if model_list is None:
optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip,
**optim_cfg)(model_list=optim_model)
return optim, lr
# for dynamic graph
if optim_scope == "all":
# optimizer for all
optim_model.append(model_list[i])
else:
if optim_scope.endswith("Loss"):
# optimizer for loss
for m in model_list[i].sublayers(True):
if m.__class_name == optim_scope:
optim_model.append(m)
optim_model = model_list
elif optim_scope == "model":
optim_model = [model_list[0], ]
elif optim_scope in ["backbone", "neck", "head"]:
optim_model = [getattr(model_list[0], optim_scope, None), ]
elif optim_scope == "loss":
optim_model = [model_list[1], ]
else:
# opmizer for module in model, such as backbone, neck, head...
if hasattr(model_list[i], optim_scope):
optim_model.append(getattr(model_list[i], optim_scope))
optim_model = [
model_list[1].loss_func[i]
for i in range(len(model_list[1].loss_func))
if model_list[1].loss_func[i].__class__.__name__ == optim_scope
]
optim_model = [
optim_model[i] for i in range(len(optim_model))
if (optim_model[i] is not None
) and (len(optim_model[i].parameters()) > 0)
]
assert len(optim_model) > 0, \
f"optim_model is empty for optim_scope({optim_scope})"
assert len(optim_model) == 1, \
"Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
optim = getattr(optimizer, optim_name)(
learning_rate=lr, grad_clip=grad_clip,
**optim_cfg)(model_list=optim_model)
logger.debug("build optimizer ({}) for scope ({}) success..".format(
optim, optim_scope))
logger.info("build optimizer ({}) for scope ({}) success..".format(
optim.__class__.__name__, optim_scope))
optim_list.append(optim)
lr_list.append(lr)
return optim_list, lr_list
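
Note: the last hunk replaces the per-module append loop with an explicit scope dispatch, plus an early return for static-graph mode (model_list is None). The indices [0] and [1] suggest the dynamic-graph convention model_list = [arch, loss]; the sketch below assumes that convention and uses stand-in layers rather than real PaddleClas modules to show which modules each scope would select:

import paddle.nn as nn

class Arch(nn.Layer):
    # stand-in for the network: backbone / neck / head sub-layers
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.neck = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)

class CenterLoss(nn.Layer):
    # stand-in for a loss that owns trainable parameters
    def __init__(self):
        super().__init__()
        self.centers = self.create_parameter(shape=[2, 8])

class CombinedLoss(nn.Layer):
    # stand-in for the combined loss holding a loss_func list
    def __init__(self):
        super().__init__()
        self.loss_func = nn.LayerList([CenterLoss()])

model_list = [Arch(), CombinedLoss()]  # assumed [arch, loss] convention

def select(optim_scope):
    if optim_scope == "all":
        optim_model = list(model_list)
    elif optim_scope == "model":
        optim_model = [model_list[0]]
    elif optim_scope in ["backbone", "neck", "head"]:
        optim_model = [getattr(model_list[0], optim_scope, None)]
    elif optim_scope == "loss":
        optim_model = [model_list[1]]
    else:
        # a specific loss, matched by class name against loss_func entries
        optim_model = [
            m for m in model_list[1].loss_func
            if m.__class__.__name__ == optim_scope
        ]
    # drop missing scopes and modules without trainable parameters
    return [m for m in optim_model if m is not None and len(m.parameters()) > 0]

print([type(m).__name__ for m in select("backbone")])    # ['Linear']
print([type(m).__name__ for m in select("CenterLoss")])  # ['CenterLoss']
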