From c7aeec28e248a7c6b08aa2553d5f18616bf50659 Mon Sep 17 00:00:00 2001
From: gaotingquan
Date: Thu, 30 Sep 2021 06:52:15 +0000
Subject: [PATCH] fix: support static graph

---
 ppcls/optimizer/__init__.py  |  3 ++-
 ppcls/optimizer/optimizer.py | 21 +++++++++++++++------
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py
index cc64a9ca..61db39f8 100644
--- a/ppcls/optimizer/__init__.py
+++ b/ppcls/optimizer/__init__.py
@@ -41,7 +41,8 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     return lr
 
 
-def build_optimizer(config, epochs, step_each_epoch, model_list):
+# model_list is None in static graph
+def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py
index 72310f27..eb6e4f4a 100644
--- a/ppcls/optimizer/optimizer.py
+++ b/ppcls/optimizer/optimizer.py
@@ -43,7 +43,9 @@ class Momentum(object):
         self.multi_precision = multi_precision
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -79,7 +81,9 @@ class Adam(object):
         self.multi_precision = multi_precision
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.Adam(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -123,7 +127,9 @@ class RMSProp(object):
         self.grad_clip = grad_clip
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -160,18 +166,21 @@ class AdamW(object):
         self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
 
+        # TODO(gaotingquan): model_list is None in static graph, so "no_weight_decay" does not work.
         self.no_weight_decay_param_name_list = [
             p.name for model in model_list for n, p in model.named_parameters()
             if any(nd in n for nd in self.no_weight_decay_name_list)
-        ]
+        ] if model_list else []
 
         if self.one_dim_param_no_weight_decay:
             self.no_weight_decay_param_name_list += [
                 p.name for model in model_list
                 for n, p in model.named_parameters() if len(p.shape) == 1
-            ]
+            ] if model_list else []
 
         opt = optim.AdamW(
             learning_rate=self.learning_rate,
--
GitLab
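
The patch makes model_list optional so the same optimizer-building path serves both execution modes: in dynamic graph the dygraph layers' parameters are flattened into one list, while in static graph model_list is None and the optimizer is created with parameters=None, so that Paddle can resolve the trainable parameters from the static program when minimize() is called. Below is a minimal standalone sketch of that guard; collect_parameters and FakeLayer are hypothetical stand-ins for illustration only and are not part of ppcls.

# Illustrative sketch of the guard introduced by the patch.
# FakeLayer mimics the parameters() method of a paddle.nn.Layer;
# in the real code, model_list holds the dygraph models being trained.

class FakeLayer:
    def __init__(self, params):
        self._params = params

    def parameters(self):
        return list(self._params)


def collect_parameters(model_list=None):
    # Dynamic graph: flatten the parameter lists of every model.
    # Static graph: model_list is None, so return None and let the
    # optimizer pick up the program's trainable parameters itself.
    return sum([m.parameters() for m in model_list],
               []) if model_list else None


if __name__ == "__main__":
    # Dynamic graph: an explicit parameter list is handed to the optimizer.
    print(collect_parameters([FakeLayer(["w0", "b0"]), FakeLayer(["w1"])]))
    # ['w0', 'b0', 'w1']

    # Static graph: no dygraph models exist; the optimizer gets parameters=None.
    print(collect_parameters(None))
    # None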