Commit 3a1276d3 authored by H HydrogenSulfate

train_loss_func is only used in train mode

Parent 24abea15
@@ -214,17 +214,17 @@ class Engine(object):
         if self.config["Global"]["pretrained_model"] is not None:
             if self.config["Global"]["pretrained_model"].startswith("http"):
                 load_dygraph_pretrain_from_url(
-                    [self.model, self.train_loss_func],
+                    [self.model, getattr(self, 'train_loss_func', None)],
                     self.config["Global"]["pretrained_model"])
             else:
                 load_dygraph_pretrain(
-                    [self.model, self.train_loss_func],
+                    [self.model, getattr(self, 'train_loss_func', None)],
                     self.config["Global"]["pretrained_model"])
         # build optimizer
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
-                self.config["Optimizer"], self.config["Global"]["epochs"],
+                self.config, self.config["Global"]["epochs"],
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])
@@ -259,7 +259,8 @@ class Engine(object):
         if self.config["Global"]["distributed"]:
             dist.init_parallel_env()
             self.model = paddle.DataParallel(self.model)
-            if len(self.train_loss_func.parameters()) > 0:
+            if self.mode == 'train' and len(self.train_loss_func.parameters(
+            )) > 0:
                 self.train_loss_func = paddle.DataParallel(
                     self.train_loss_func)
         # build postprocess for infer
...
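Note on the hunk above: outside train mode the Engine never builds `train_loss_func`, so reading `self.train_loss_func` directly would raise AttributeError during weight loading. The following is a minimal sketch (not the PaddleClas implementation; `FakeEngine` and `models_to_load` are made-up names) of why the `getattr(..., None)` guard keeps eval/infer modes working:

```python
# Minimal sketch, assuming a plain Python stand-in for the Engine.
# In eval/infer mode train_loss_func is never created, so getattr with a
# None default avoids the AttributeError when collecting models to load.

class FakeEngine:
    def __init__(self, mode):
        self.mode = mode
        self.model = object()                    # stand-in for the built model
        if mode == "train":
            self.train_loss_func = object()      # only exists in train mode

    def models_to_load(self):
        # mirrors [self.model, getattr(self, 'train_loss_func', None)]
        return [self.model, getattr(self, "train_loss_func", None)]

print(FakeEngine("train").models_to_load())  # [model, loss_func]
print(FakeEngine("eval").models_to_load())   # [model, None] -- no AttributeError
```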
@@ -45,19 +45,20 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 # model_list is None in static graph
 def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     config = copy.deepcopy(config)
-    if isinstance(config, dict):
-        # convert to [{optim_name1: {scope: xxx, **optim_cfg}}, {optim_name2: {scope: xxx, **optim_cfg}}, ...]
-        optim_name = config.Optimizer.pop('name')
-        config: List[Dict[str, Dict]] = [{
+    optim_config = config["Optimizer"]
+    if isinstance(optim_config, dict):
+        # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
+        optim_name = optim_config.pop("name")
+        optim_config: List[Dict[str, Dict]] = [{
             optim_name: {
-                'scope': config.Arch.name,
+                'scope': config["Arch"].get("name"),
                 **
-                config.Optimizer
+                optim_config
             }
         }]
     optim_list = []
     lr_list = []
-    for optim_item in config:
+    for optim_item in optim_config:
         # optim_cfg = {optim_name1: {scope: xxx, **optim_cfg}}
         # step1 build lr
         optim_name = optim_item.keys()[0]  # get optim_name1
...
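The hunk above now passes the full config into `build_optimizer` and normalizes the `Optimizer` section, so a single optimizer dict and a list of per-scope optimizers are iterated the same way. A hedged sketch of that normalization step, assuming plain dicts rather than PaddleClas's config objects (`normalize_optim_config` is a made-up helper name):

```python
from typing import Dict, List


def normalize_optim_config(config: Dict) -> List[Dict[str, Dict]]:
    """Sketch of the normalization done in build_optimizer (plain dicts assumed).

    A single-optimizer config {'name': 'Momentum', ...} is rewritten to the
    list form [{'Momentum': {'scope': <arch name>, ...}}] so downstream code
    can always iterate over a list of {name: cfg} items.
    """
    optim_config = config["Optimizer"]
    if isinstance(optim_config, dict):
        optim_name = optim_config.pop("name")
        optim_config = [{
            optim_name: {
                "scope": config["Arch"].get("name"),
                **optim_config
            }
        }]
    return optim_config


cfg = {
    "Arch": {"name": "ResNet50"},
    "Optimizer": {"name": "Momentum", "momentum": 0.9, "lr": {"name": "Cosine"}},
}
print(normalize_optim_config(cfg))
# [{'Momentum': {'scope': 'ResNet50', 'momentum': 0.9, 'lr': {'name': 'Cosine'}}}]
```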
@@ -49,7 +49,8 @@ def load_dygraph_pretrain(model, path=None):
     param_state_dict = paddle.load(path + ".pdparams")
     if isinstance(model, list):
         for m in model:
-            m.set_dict(param_state_dict)
+            if hasattr(m, 'set_dict'):
+                m.set_dict(param_state_dict)
     else:
         model.set_dict(param_state_dict)
     return
...
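With the `hasattr` check above, entries in the model list that have no `set_dict` method (in particular the `None` placeholder the Engine now passes outside train mode) are skipped during loading. A hedged sketch of that guard in plain Python, without Paddle (`DummyLayer` is a made-up stand-in):

```python
# Hedged sketch of the guarded loading loop (plain Python, no Paddle).
# Entries without a set_dict method -- e.g. the None placeholder passed by
# the Engine outside train mode -- are simply skipped.

class DummyLayer:
    def set_dict(self, state):
        self.state = state


param_state_dict = {"weight": [1.0, 2.0]}
models = [DummyLayer(), None]       # second slot mimics the missing loss func

for m in models:
    if hasattr(m, "set_dict"):      # same guard as in load_dygraph_pretrain
        m.set_dict(param_state_dict)

print(models[0].state)  # {'weight': [1.0, 2.0]}; the None entry was ignored
```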