Commit 3a1276d3 authored by HydrogenSulfate

train_loss_func only used in train mode

Parent 24abea15
@@ -214,17 +214,17 @@ class Engine(object):
         if self.config["Global"]["pretrained_model"] is not None:
             if self.config["Global"]["pretrained_model"].startswith("http"):
                 load_dygraph_pretrain_from_url(
-                    [self.model, self.train_loss_func],
+                    [self.model, getattr(self, 'train_loss_func', None)],
                     self.config["Global"]["pretrained_model"])
             else:
                 load_dygraph_pretrain(
-                    [self.model, self.train_loss_func],
+                    [self.model, getattr(self, 'train_loss_func', None)],
                     self.config["Global"]["pretrained_model"])
         # build optimizer
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
-                self.config["Optimizer"], self.config["Global"]["epochs"],
+                self.config, self.config["Global"]["epochs"],
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])
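Note: `getattr(self, 'train_loss_func', None)` returns None when the attribute was never created, so eval/infer runs no longer crash on a loss module that only exists in train mode. A minimal sketch of that pattern (the ToyEngine below is illustrative, not the real class):

    # Illustrative sketch: getattr with a default avoids AttributeError
    # for attributes that are only created in one mode.
    class ToyEngine:
        def __init__(self, mode):
            self.mode = mode
            if mode == "train":
                self.train_loss_func = "loss-module"  # stand-in for a Layer

    def pretrain_targets(engine):
        # None is passed through when eval/infer never built the loss
        return [engine, getattr(engine, "train_loss_func", None)]

    print(pretrain_targets(ToyEngine("train"))[1])  # loss-module
    print(pretrain_targets(ToyEngine("eval"))[1])   # None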
@@ -259,7 +259,8 @@ class Engine(object):
         if self.config["Global"]["distributed"]:
             dist.init_parallel_env()
             self.model = paddle.DataParallel(self.model)
-            if len(self.train_loss_func.parameters()) > 0:
+            if self.mode == 'train' and len(self.train_loss_func.parameters(
+            )) > 0:
                 self.train_loss_func = paddle.DataParallel(
                     self.train_loss_func)
         # build postprocess for infer
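The distributed branch now checks `self.mode == 'train'` before reading `self.train_loss_func`, and still only wraps the loss in DataParallel when it owns trainable parameters (some losses, e.g. a center loss with learnable centers, do; a plain cross-entropy loss does not). A hedged sketch of that guard, assuming an already initialized paddle parallel environment:

    import paddle

    def maybe_parallelize_loss(mode, train_loss_func):
        # Wrap the loss for data-parallel training only when (a) we are in
        # train mode and (b) the loss actually owns learnable parameters.
        if mode == 'train' and len(train_loss_func.parameters()) > 0:
            train_loss_func = paddle.DataParallel(train_loss_func)
        return train_loss_func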
@@ -45,19 +45,20 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 # model_list is None in static graph
 def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     config = copy.deepcopy(config)
-    if isinstance(config, dict):
-        # convert to [{optim_name1: {scope: xxx, **optim_cfg}}, {optim_name2: {scope: xxx, **optim_cfg}}, ...]
-        optim_name = config.Optimizer.pop('name')
-        config: List[Dict[str, Dict]] = [{
+    optim_config = config["Optimizer"]
+    if isinstance(optim_config, dict):
+        # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
+        optim_name = optim_config.pop("name")
+        optim_config: List[Dict[str, Dict]] = [{
             optim_name: {
-                'scope': config.Arch.name,
+                'scope': config["Arch"].get("name"),
                 **
-                config.Optimizer
+                optim_config
             }
         }]
     optim_list = []
     lr_list = []
-    for optim_item in config:
+    for optim_item in optim_config:
         # optim_cfg = {optim_name1: {scope: xxx, **optim_cfg}}
         # step1 build lr
         optim_name = optim_item.keys()[0] # get optim_name1
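build_optimizer now receives the full config and normalizes the legacy single-optimizer dict into a list of `{optim_name: {scope, **cfg}}` entries, so the same loop can build one optimizer per scope (e.g. one for the model, one for the loss). A standalone sketch of that normalization step (the sample config is invented for illustration):

    import copy
    from typing import Dict, List

    def normalize_optim_config(config: Dict) -> List[Dict[str, Dict]]:
        # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
        config = copy.deepcopy(config)
        optim_config = config["Optimizer"]
        if isinstance(optim_config, dict):
            optim_name = optim_config.pop("name")
            optim_config = [{
                optim_name: {
                    "scope": config["Arch"].get("name"),
                    **optim_config,
                }
            }]
        return optim_config

    # Invented example config
    cfg = {
        "Arch": {"name": "ResNet50"},
        "Optimizer": {"name": "Momentum", "momentum": 0.9},
    }
    print(normalize_optim_config(cfg))
    # [{'Momentum': {'scope': 'ResNet50', 'momentum': 0.9}}]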
@@ -49,7 +49,8 @@ def load_dygraph_pretrain(model, path=None):
     param_state_dict = paddle.load(path + ".pdparams")
     if isinstance(model, list):
         for m in model:
-            m.set_dict(param_state_dict)
+            if hasattr(m, 'set_dict'):
+                m.set_dict(param_state_dict)
     else:
         model.set_dict(param_state_dict)
     return
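With the hasattr guard, load_dygraph_pretrain silently skips list entries that cannot receive a state dict, which covers the None placeholder forwarded from eval/infer mode above. A pure-Python sketch of the same guard (a plain dict stands in for the loaded params):

    def load_into(models, param_state_dict):
        # Skip entries (e.g. None) that expose no set_dict; paddle Layers do.
        if isinstance(models, list):
            for m in models:
                if hasattr(m, 'set_dict'):
                    m.set_dict(param_state_dict)
        else:
            models.set_dict(param_state_dict)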