diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py
index cc64a9caaaf2c90099f111c8863515f0bebec351..61db39f89a3775eec70fe3cfa63d7e820e3e60e6 100644
--- a/ppcls/optimizer/__init__.py
+++ b/ppcls/optimizer/__init__.py
@@ -41,7 +41,8 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     return lr
 
 
-def build_optimizer(config, epochs, step_each_epoch, model_list):
+# model_list is None in static graph
+def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py
index 72310f27b4ea15b5ba738820e5a5b7858f7ff2eb..eb6e4f4ab4dee24107761f3ce4c80ea5fb20887d 100644
--- a/ppcls/optimizer/optimizer.py
+++ b/ppcls/optimizer/optimizer.py
@@ -43,7 +43,9 @@ class Momentum(object):
         self.multi_precision = multi_precision
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -79,7 +81,9 @@ class Adam(object):
         self.multi_precision = multi_precision
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.Adam(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -123,7 +127,9 @@ class RMSProp(object):
         self.grad_clip = grad_clip
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -160,18 +166,21 @@ class AdamW(object):
         self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
 
+        # TODO(gaotingquan): model_list is None in static graph, so "no_weight_decay" does not work.
         self.no_weight_decay_param_name_list = [
             p.name for model in model_list for n, p in model.named_parameters()
             if any(nd in n for nd in self.no_weight_decay_name_list)
-        ]
+        ] if model_list else []
 
         if self.one_dim_param_no_weight_decay:
             self.no_weight_decay_param_name_list += [
                 p.name for model in model_list
                 for n, p in model.named_parameters() if len(p.shape) == 1
-            ]
+            ] if model_list else []
 
         opt = optim.AdamW(
             learning_rate=self.learning_rate,
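
The change applies one pattern throughout: build the underlying paddle optimizer with parameters=None whenever no model_list is supplied (static graph), and with the flattened parameter list otherwise (dynamic graph). Below is a minimal, self-contained sketch of that pattern outside PaddleClas; the helper make_momentum and the variable names are hypothetical and only illustrate the idea.

import paddle
from paddle import optimizer as optim


def make_momentum(model_list=None, lr=0.1, momentum=0.9):
    # model_list is None in static graph; a list of nn.Layer objects otherwise.
    parameters = sum([m.parameters() for m in model_list],
                     []) if model_list else None
    return optim.Momentum(
        learning_rate=lr, momentum=momentum, parameters=parameters)


# Dynamic graph: the flattened parameter list is passed to the optimizer.
model = paddle.nn.Linear(8, 2)
opt_dygraph = make_momentum(model_list=[model])

# Static graph: no model list is available; the optimizer is built with
# parameters=None and picks up the program's trainable parameters when
# opt.minimize(loss) is called while the graph is being composed.
paddle.enable_static()
opt_static = make_momentum()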