Unverified commit cc230f01, authored by u010070587, committed by GitHub

Merge pull request #1277 from TingquanGao/dev/fix_opt

fix: support static graph
......@@ -79,7 +79,7 @@ class UnifiedResize(object):
             if isinstance(interpolation, str):
                 interpolation = _cv2_interp_from_str[interpolation.lower()]
             # compatible with opencv < version 4.4.0
-            elif not interpolation:
+            elif interpolation is None:
                 interpolation = cv2.INTER_LINEAR
             self.resize_func = partial(cv2.resize, interpolation=interpolation)
         elif backend.lower() == "pil":
......
......@@ -60,7 +60,7 @@ class UnifiedResize(object):
             if isinstance(interpolation, str):
                 interpolation = _cv2_interp_from_str[interpolation.lower()]
             # compatible with opencv < version 4.4.0
-            elif not interpolation:
+            elif interpolation is None:
                 interpolation = cv2.INTER_LINEAR
             self.resize_func = partial(cv2.resize, interpolation=interpolation)
         elif backend.lower() == "pil":
......
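The same one-line fix is applied to both copies of UnifiedResize touched by this diff. The change matters because OpenCV interpolation flags are plain integers and cv2.INTER_NEAREST is 0: the old truthiness test "not interpolation" treated an explicit nearest-neighbor request as unset and silently swapped it for bilinear. A minimal sketch of the difference (the flag values are standard OpenCV constants; everything else is illustrative):

import cv2

interpolation = cv2.INTER_NEAREST  # == 0, a valid flag that is falsy

# Old check: 0 is falsy, so the caller's choice is wrongly overridden.
if not interpolation:
    interpolation = cv2.INTER_LINEAR  # becomes 1 (bilinear)

# New check: only a genuinely unset value falls back to the default.
interpolation = cv2.INTER_NEAREST
if interpolation is None:
    interpolation = cv2.INTER_LINEAR  # branch not taken; nearest is kept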
......@@ -41,7 +41,8 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     return lr
 
 
-def build_optimizer(config, epochs, step_each_epoch, model_list):
+# model_list is None in static graph
+def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
......
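With model_list defaulting to None, the static-graph training entry point can build an optimizer before any nn.Layer objects exist. A hypothetical calling sketch (the variable names, the config contents, and the assumption that the function returns the optimizer together with the lr scheduler are inferred from the surrounding code, not shown in this diff):

# Dynamic graph: pass the layers whose parameters should be optimized.
optimizer, lr = build_optimizer(opt_config, epochs, step_each_epoch,
                                model_list=[model, loss_fn])

# Static graph: no layer objects yet; parameters are attached later,
# e.g. via optimizer.minimize(loss) on the static program.
optimizer, lr = build_optimizer(opt_config, epochs, step_each_epoch)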
......@@ -18,6 +18,8 @@ from __future__ import print_function
 
 from paddle import optimizer as optim
 
+from ppcls.utils import logger
+
 
 class Momentum(object):
     """
......@@ -43,7 +45,9 @@ class Momentum(object):
         self.multi_precision = multi_precision
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
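The same fallback appears in all four optimizer wrappers in this file: in dynamic graph mode the per-layer parameter lists are flattened into one list, while in static graph mode parameters stays None, which paddle.optimizer accepts and later resolves to the trainable parameters of the program when minimize() is called. A small self-contained sketch of the flattening idiom (the two Linear layers are toy stand-ins for model_list):

import paddle

model_list = [paddle.nn.Linear(4, 4), paddle.nn.Linear(4, 2)]

# sum(..., []) concatenates the per-layer parameter lists into one flat list.
parameters = sum([m.parameters() for m in model_list],
                 []) if model_list else None
print(len(parameters))  # 4: a weight and a bias for each Linear layer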
......@@ -79,7 +83,9 @@ class Adam(object):
         self.multi_precision = multi_precision
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.Adam(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
......@@ -123,7 +129,9 @@ class RMSProp(object):
         self.grad_clip = grad_clip
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
......@@ -160,18 +168,28 @@ class AdamW(object):
         self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
 
     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
+
+        # TODO(gaotingquan): model_list is None when in static graph, "no_weight_decay" not work.
+        if model_list is None:
+            if self.one_dim_param_no_weight_decay or len(
+                    self.no_weight_decay_name_list) != 0:
+                msg = "\"AdamW\" does not support setting \"no_weight_decay\" in static graph. Please use dynamic graph."
+                logger.error(Exception(msg))
+                raise Exception(msg)
 
         self.no_weight_decay_param_name_list = [
             p.name for model in model_list for n, p in model.named_parameters()
             if any(nd in n for nd in self.no_weight_decay_name_list)
-        ]
+        ] if model_list else []
 
         if self.one_dim_param_no_weight_decay:
             self.no_weight_decay_param_name_list += [
                 p.name for model in model_list
                 for n, p in model.named_parameters() if len(p.shape) == 1
-            ]
+            ] if model_list else []
 
         opt = optim.AdamW(
             learning_rate=self.learning_rate,
......
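AdamW needs more than the parameters fallback because its selective weight decay works on parameter names gathered through named_parameters(), which requires live layer objects; with model_list None that information does not exist, so any no_weight_decay configuration must be rejected up front. For context, a sketch of how such a name list is typically consumed through paddle.optimizer.AdamW's apply_decay_param_fun hook (this wiring is an assumption about the elided remainder of the class, not code from this commit):

import paddle

no_decay_names = ["norm.weight"]  # hypothetical collected parameter names

opt = paddle.optimizer.AdamW(
    learning_rate=0.001,
    parameters=paddle.nn.Linear(4, 2).parameters(),
    # Called with each parameter's name; return True to apply weight decay.
    apply_decay_param_fun=lambda name: name not in no_decay_names)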