From 73f4d8e4cedf169b3978b0bb8c7743cf2c5ed6cd Mon Sep 17 00:00:00 2001
From: gaotingquan
Date: Wed, 12 Apr 2023 11:48:56 +0000
Subject: [PATCH] avoid causing issues for models that do not set
 no_weight_decay. There seems to be a difference in optimizer behavior
 between passing a flat [] parameter list and grouped [{"params":},
 {"params":}] params

---
 ppcls/optimizer/optimizer.py | 41 ++++++++++++++++++++-------------------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py
index 77269949..87190f56 100644
--- a/ppcls/optimizer/optimizer.py
+++ b/ppcls/optimizer/optimizer.py
@@ -113,22 +113,26 @@ class Momentum(object):
         # model_list is None in static graph
         parameters = None
         if model_list:
-            params_with_decay = []
-            params_without_decay = []
-            for m in model_list:
-                for n, p in m.named_parameters():
-                    if any(nd in n for nd in self.no_weight_decay_name_list) \
-                        or (self.one_dim_param_no_weight_decay and len(p.shape) == 1):
-                        params_without_decay.append(p)
-                    else:
-                        params_with_decay.append(p)
-            parameters = [{
-                "params": params_with_decay,
-                "weight_decay": self.weight_decay
-            }, {
-                "params": params_without_decay,
-                "weight_decay": 0.0
-            }]
+            # TODO(gaotingquan): avoid issues for models without no_weight_decay set
+            if len(self.no_weight_decay_name_list) > 0:
+                params_with_decay = []
+                params_without_decay = []
+                for m in model_list:
+                    for n, p in m.named_parameters():
+                        if any(nd in n for nd in self.no_weight_decay_name_list) \
+                            or (self.one_dim_param_no_weight_decay and len(p.shape) == 1):
+                            params_without_decay.append(p)
+                        else:
+                            params_with_decay.append(p)
+                parameters = [{
+                    "params": params_with_decay,
+                    "weight_decay": self.weight_decay
+                }, {
+                    "params": params_without_decay,
+                    "weight_decay": 0.0
+                }]
+            else:
+                parameters = sum([m.parameters() for m in model_list], [])
         opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -279,9 +283,8 @@ class AdamW(object):
 
         if self.one_dim_param_no_weight_decay:
             self.no_weight_decay_param_name_list += [
-                p.name
-                for model in model_list for n, p in model.named_parameters()
-                if len(p.shape) == 1
+                p.name for model in model_list
+                for n, p in model.named_parameters() if len(p.shape) == 1
             ] if model_list else []
 
         opt = optim.AdamW(
--
GitLab
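
Reviewer note (illustrative sketch, not part of the patch): the subject line
contrasts passing the optimizer a flat parameter list with passing grouped
parameter dicts. A minimal Python sketch of the two forms follows; the toy
paddle.nn.Linear model and the hyperparameter values are assumptions made up
for the example.

    import paddle
    from paddle import optimizer as optim

    model = paddle.nn.Linear(4, 4)  # stand-in for one entry of model_list

    # Flat list: the optimizer-level weight_decay applies to every parameter.
    # This mirrors what the patch builds when no_weight_decay_name_list is
    # empty: parameters = sum([m.parameters() for m in model_list], [])
    opt_flat = optim.Momentum(
        learning_rate=0.1,
        momentum=0.9,
        parameters=model.parameters(),
        weight_decay=1e-4)

    # Grouped dicts: each group may override weight_decay, which is how the
    # grouped branch turns decay off for matched / one-dimensional params.
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        (no_decay if len(p.shape) == 1 else decay).append(p)

    opt_grouped = optim.Momentum(
        learning_rate=0.1,
        momentum=0.9,
        parameters=[
            {"params": decay, "weight_decay": 1e-4},
            {"params": no_decay, "weight_decay": 0.0},
        ])

For a model that never sets no_weight_decay, every parameter lands in the
with-decay group, so the grouped form ought to match the flat form; the
commit message observes that in practice the two forms do not behave
identically, hence the fallback to the flat list.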