Unverified · Commit 0218742b authored by Walter, committed by GitHub

Merge pull request #1868 from HydrogenSulfate/fix_optimizer_init

fix bug for static graph
@@ -224,7 +224,7 @@ class Engine(object):
         # build optimizer
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
-                self.config, self.config["Global"]["epochs"],
+                self.config["Optimizer"], self.config["Global"]["epochs"],
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])
......
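The call-site change above narrows what `build_optimizer` receives: instead of the full config dict, from which it previously extracted the `Optimizer` section itself, the caller now passes `config["Optimizer"]` directly. A minimal, self-contained sketch of the two call shapes (the config values here are made up for illustration; real configs come from PaddleClas YAML files):

```python
config = {
    "Global": {"epochs": 10},
    "Optimizer": {
        "name": "Momentum",
        "momentum": 0.9,
        "lr": {"name": "Cosine", "learning_rate": 0.1},
    },
}

# Old behavior: build_optimizer(config, ...) received everything and
# indexed config["Optimizer"] internally.
# New behavior: the caller hands over only the optimizer sub-config.
optim_config = config["Optimizer"]
epochs = config["Global"]["epochs"]
print(optim_config["name"], epochs)  # Momentum 10
```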
@@ -44,8 +44,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 # model_list is None in static graph
 def build_optimizer(config, epochs, step_each_epoch, model_list=None):
-    config = copy.deepcopy(config)
-    optim_config = config["Optimizer"]
+    optim_config = copy.deepcopy(config)
     if isinstance(optim_config, dict):
         # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
         optim_name = optim_config.pop("name")
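With the argument now being the optimizer sub-config itself, the deepcopy-and-normalize step shrinks to the two lines in the new version. A hedged sketch of the dict-to-list conversion described in the comment above (the `'scope': 'all'` default is an assumption; the real code may differ in details):

```python
import copy

def normalize_optim_config(optim_config):
    """Convert {'name': xxx, **cfg} to [{xxx: {'scope': ..., **cfg}}]."""
    optim_config = copy.deepcopy(optim_config)
    if isinstance(optim_config, dict):
        optim_name = optim_config.pop("name")
        optim_config = [{optim_name: {"scope": "all", **optim_config}}]
    return optim_config

print(normalize_optim_config({"name": "Momentum", "momentum": 0.9}))
# [{'Momentum': {'scope': 'all', 'momentum': 0.9}}]
```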
@@ -93,6 +92,15 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     else:
         grad_clip = None
     optim_model = []
+
+    # for static graph
+    if model_list is None:
+        optim = getattr(optimizer, optim_name)(
+            learning_rate=lr, grad_clip=grad_clip,
+            **optim_cfg)(model_list=optim_model)
+        return optim, lr
+
+    # for dynamic graph
     for i in range(len(model_list)):
         if len(model_list[i].parameters()) == 0:
             continue
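The added branch handles static graph mode, where no layer objects exist yet: `model_list is None`, so the optimizer wrapper is built over the still-empty `optim_model` list and returned immediately. A runnable sketch of the `getattr` dispatch pattern used above, with a dummy class standing in for the PaddleClas optimizer wrappers (their exact semantics are assumed here):

```python
import types

class DummyMomentum:
    # stand-in for an optimizer wrapper: configured first, then called
    # with the modules whose parameters it should optimize
    def __init__(self, learning_rate, grad_clip=None, **kwargs):
        self.learning_rate = learning_rate
        self.grad_clip = grad_clip
        self.cfg = kwargs

    def __call__(self, model_list):
        print(f"optimizer over {len(model_list)} modules, lr={self.learning_rate}")
        return self

# `optimizer` plays the role of the ppcls.optimizer module here
optimizer = types.SimpleNamespace(Momentum=DummyMomentum)

optim_name, lr, grad_clip, optim_cfg = "Momentum", 0.1, None, {"momentum": 0.9}
optim_model = []  # static graph: empty, parameters are attached later
optim = getattr(optimizer, optim_name)(
    learning_rate=lr, grad_clip=grad_clip, **optim_cfg)(model_list=optim_model)
```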
@@ -103,7 +111,7 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
         if optim_scope.endswith("Loss"):
             # optimizer for loss
             for m in model_list[i].sublayers(True):
-                if m.__class_name == optim_scope:
+                if m.__class__.__name__ == optim_scope:
                     optim_model.append(m)
         else:
             # optimizer for module in model, such as backbone, neck, head...
......
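The last hunk is the actual crash fix: `__class_name` is not a real attribute (and inside a module-level function there is no name mangling to rescue it), so the old comparison raised `AttributeError`; `m.__class__.__name__` is the standard spelling for an object's class name. A self-contained demonstration, with a plain class standing in for a `paddle.nn.Layer` loss:

```python
class CELoss:
    pass

m = CELoss()
print(m.__class__.__name__)            # CELoss
print(type(m).__name__ == "CELoss")    # True, equivalent spelling

try:
    m.__class_name                     # the pre-fix attribute access
except AttributeError as exc:
    print("old code failed:", exc)
```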