From 7dcb2d4fd0fab1cfb021fc25bd868976100e5559 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Thu, 30 Sep 2021 10:48:23 +0000 Subject: [PATCH] fix: raise exception raise exception about using no_weight_decay of AdamW in static graph --- ppcls/optimizer/optimizer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py index eb6e4f4a..f429755f 100644 --- a/ppcls/optimizer/optimizer.py +++ b/ppcls/optimizer/optimizer.py @@ -18,6 +18,8 @@ from __future__ import print_function from paddle import optimizer as optim +from ppcls.utils import logger + class Momentum(object): """ @@ -171,6 +173,13 @@ class AdamW(object): []) if model_list else None # TODO(gaotingquan): model_list is None when in static graph, "no_weight_decay" not work. + if model_list is None: + if self.one_dim_param_no_weight_decay or len( + self.no_weight_decay_name_list) != 0: + msg = "\"AdamW\" does not support setting \"no_weight_decay\" in static graph. Please use dynamic graph." + logger.error(Exception(msg)) + raise Exception(msg) + self.no_weight_decay_param_name_list = [ p.name for model in model_list for n, p in model.named_parameters() if any(nd in n for nd in self.no_weight_decay_name_list) -- GitLab