diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py
index b7b4d42105ea8f71d41f12f4a93423ea65707e2f..a440eac463e2a5ebd74cf676b2487973baea9546 100644
--- a/ppcls/optimizer/__init__.py
+++ b/ppcls/optimizer/__init__.py
@@ -58,6 +58,12 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
         }]
     optim_list = []
     lr_list = []
+    # NOTE:
+    # Currently only support optim objects below.
+    # 1. single optimizer config.
+    # 2. next level under Arch, such as Arch.backbone, Arch.neck, Arch.head.
+    # 3. loss which has parameters, such as CenterLoss.
+    # The scope of each optim item decides which case applies below.
     for optim_item in optim_config:
         # optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
         # step1 build lr
@@ -91,11 +97,19 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
             if len(model_list[i].parameters()) == 0:
                 continue
             if optim_scope == "all":
+                # optimizer for all
                 optim_model.append(model_list[i])
             else:
-                for m in model_list[i].sublayers(True):
-                    if m.__class__.__name__ == optim_scope:
-                        optim_model.append(model_list[i])
+                if optim_scope.endswith("Loss"):
+                    # optimizer for loss: match the loss sublayer by its
+                    # class name and optimize that sublayer's parameters.
+                    for m in model_list[i].sublayers(True):
+                        if m.__class__.__name__ == optim_scope:
+                            optim_model.append(m)
+                else:
+                    # optimizer for a module in the model, such as backbone, neck, head...
+                    if hasattr(model_list[i], optim_scope):
+                        optim_model.append(getattr(model_list[i], optim_scope))
+
         assert len(optim_model) == 1, \
             "Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
         optim = getattr(optimizer, optim_name)(