提交 7dcb2d4f 编写于 作者: G gaotingquan

fix: raise exception

raise exception about using no_weight_decay of AdamW in static graph
上级 c7aeec28
......@@ -18,6 +18,8 @@ from __future__ import print_function
from paddle import optimizer as optim
from ppcls.utils import logger
class Momentum(object):
"""
......@@ -171,6 +173,13 @@ class AdamW(object):
[]) if model_list else None
# TODO(gaotingquan): model_list is None when in static graph, "no_weight_decay" not work.
if model_list is None:
if self.one_dim_param_no_weight_decay or len(
self.no_weight_decay_name_list) != 0:
msg = "\"AdamW\" does not support setting \"no_weight_decay\" in static graph. Please use dynamic graph."
logger.error(Exception(msg))
raise Exception(msg)
self.no_weight_decay_param_name_list = [
p.name for model in model_list for n, p in model.named_parameters()
if any(nd in n for nd in self.no_weight_decay_name_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册