Commit c7aeec28 authored by gaotingquan

fix: support static graph

Parent 0dccfb91
@@ -41,7 +41,8 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     return lr


-def build_optimizer(config, epochs, step_each_epoch, model_list):
+# model_list is None in static graph
+def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
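With the new default, the static-graph training entry can call build_optimizer without a model list, while the dynamic-graph engine keeps passing one. A minimal sketch of the two call styles, assuming a typical PaddleClas-style optimizer config and assuming the function returns the optimizer together with the lr scheduler; the config keys and values below are illustrative, not taken from this commit:

# hypothetical config, for illustration only
optim_config = {
    'name': 'Momentum',
    'momentum': 0.9,
    'lr': {'name': 'Cosine', 'learning_rate': 0.1},
}

# dynamic graph: pass the models whose parameters should be optimized
optimizer, lr_sch = build_optimizer(optim_config, epochs=120,
                                    step_each_epoch=100, model_list=[model])

# static graph: omit model_list; the optimizer is built with parameters=None
optimizer, lr_sch = build_optimizer(optim_config, epochs=120,
                                    step_each_epoch=100)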
@@ -43,7 +43,9 @@ class Momentum(object):
         self.multi_precision = multi_precision

     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
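Passing parameters=None is safe here because Paddle's optimizers fall back to optimizing every trainable parameter when no list is given; in static-graph mode, minimize() collects them from the current main program. A standalone sketch of that behavior (not part of this commit, just an illustration of the assumption behind the change):

import paddle
from paddle import optimizer as optim

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    y = paddle.static.data(name='y', shape=[None, 1], dtype='float32')
    pred = paddle.static.nn.fc(x, size=1)
    loss = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

    # parameters=None: minimize() picks up every trainable parameter
    # of the current main program, so no model_list is needed.
    opt = optim.Momentum(learning_rate=0.01, momentum=0.9, parameters=None)
    opt.minimize(loss)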
@@ -79,7 +81,9 @@ class Adam(object):
         self.multi_precision = multi_precision

     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.Adam(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -123,7 +127,9 @@ class RMSProp(object):
         self.grad_clip = grad_clip

     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
         opt = optim.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -160,18 +166,21 @@ class AdamW(object):
         self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay

     def __call__(self, model_list):
-        parameters = sum([m.parameters() for m in model_list], [])
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None

+        # TODO(gaotingquan): model_list is None when in static graph, "no_weight_decay" not work.
         self.no_weight_decay_param_name_list = [
             p.name for model in model_list for n, p in model.named_parameters()
             if any(nd in n for nd in self.no_weight_decay_name_list)
-        ]
+        ] if model_list else []

         if self.one_dim_param_no_weight_decay:
             self.no_weight_decay_param_name_list += [
                 p.name for model in model_list
                 for n, p in model.named_parameters() if len(p.shape) == 1
-            ]
+            ] if model_list else []

         opt = optim.AdamW(
             learning_rate=self.learning_rate,
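The TODO flags a side effect of this fallback: with model_list absent, no_weight_decay_param_name_list stays empty, so the per-parameter decay filter can never exclude anything. An illustrative sketch of how such a name list is commonly wired into paddle.optimizer.AdamW via apply_decay_param_fun (the exact wiring in this file lies outside the shown hunk):

from paddle import optimizer as optim

no_weight_decay_param_name_list = []  # empty when model_list is None (static graph)

opt = optim.AdamW(
    learning_rate=0.001,
    weight_decay=0.05,
    # with an empty list this returns True for every parameter name, so weight
    # decay is applied everywhere and "no_weight_decay" has no effect -- the
    # situation the TODO above records
    apply_decay_param_fun=lambda n: n not in no_weight_decay_param_name_list,
    parameters=None)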