diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 71a7e182ae925e18173589dd80bb345bb04a1109..7a7bbde6b4bd431275f8e5def2190d3061c5ec6f 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -224,7 +224,7 @@ class Engine(object):
         # build optimizer
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
-                self.config, self.config["Global"]["epochs"],
+                self.config["Optimizer"], self.config["Global"]["epochs"],
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])
 
diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py
index a440eac463e2a5ebd74cf676b2487973baea9546..3f63777202c1cc31292bcdfd6077b3b995302be8 100644
--- a/ppcls/optimizer/__init__.py
+++ b/ppcls/optimizer/__init__.py
@@ -44,8 +44,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 
 # model_list is None in static graph
 def build_optimizer(config, epochs, step_each_epoch, model_list=None):
-    config = copy.deepcopy(config)
-    optim_config = config["Optimizer"]
+    optim_config = copy.deepcopy(config)
     if isinstance(optim_config, dict):
         # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
         optim_name = optim_config.pop("name")
@@ -93,6 +92,15 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     else:
         grad_clip = None
     optim_model = []
+
+    # for static graph
+    if model_list is None:
+        optim = getattr(optimizer, optim_name)(
+            learning_rate=lr, grad_clip=grad_clip,
+            **optim_cfg)(model_list=optim_model)
+        return optim, lr
+
+    # for dynamic graph
     for i in range(len(model_list)):
         if len(model_list[i].parameters()) == 0:
             continue
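
For context, a minimal usage sketch of the new calling convention (not part of the diff): the caller now passes only the `Optimizer` sub-config, which `build_optimizer` deep-copies internally, so the original config dict is left untouched. The `config` dict, `steps_per_epoch` value, and `paddle.nn.Linear` model below are hypothetical placeholders, and the logger initialization is an assumed setup mirroring what `Engine.__init__` normally does so that `build_optimizer`'s logging calls work outside the engine.

```python
import paddle
from ppcls.optimizer import build_optimizer
from ppcls.utils import logger

# build_optimizer logs through ppcls' module-level logger; Engine normally
# initializes it, so do the same here (assumed setup, not part of the diff).
logger.init_logger()

# Hypothetical stand-in for a parsed PaddleClas YAML config.
config = {
    "Global": {"epochs": 10},
    "Optimizer": {
        "name": "Momentum",
        "momentum": 0.9,
        "lr": {"name": "Cosine", "learning_rate": 0.1},
    },
}
steps_per_epoch = 100           # normally len(train_dataloader)
model = paddle.nn.Linear(8, 2)  # stand-in for the real model / loss list

# Dynamic graph: pass the models whose parameters should be optimized.
# Note the first argument is now config["Optimizer"], not the whole config;
# in this scoped-optimizer version the returned values are typically lists
# (one optimizer / LR scheduler per scope).
optim, lr_sch = build_optimizer(
    config["Optimizer"],
    config["Global"]["epochs"],
    steps_per_epoch,
    [model])

# Static graph: with model_list=None, the early-return branch added in the
# diff builds a single optimizer without binding parameters. Paddle only
# accepts a parameter-less optimizer in static mode, hence enable_static().
paddle.enable_static()
optim_static, lr_static = build_optimizer(
    config["Optimizer"],
    config["Global"]["epochs"],
    steps_per_epoch,
    model_list=None)
```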