提交 c30df630 编写于 作者: G gaotingquan 提交者: Wei Shengyu

support Static

上级 392b75b1
......@@ -65,7 +65,7 @@ class ClassTrainer(object):
# build optimizer
self.optimizer, self.lr_sch = build_optimizer(
self.config, self.dataloader.max_iter,
[self.model, self.loss_func], self.update_freq)
[self.model, self.loss_func])
# build model saver
self.model_saver = ModelSaver(
......
......@@ -44,7 +44,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
def build_optimizer(config, max_iter, model_list, update_freq):
def build_optimizer(config, max_iter, model_list=None):
optim_config = copy.deepcopy(config["Optimizer"])
epochs = config["Global"]["epochs"]
update_freq = config["Global"].get("update_freq", 1)
......
......@@ -109,7 +109,7 @@ def create_fetchs(out,
else:
target = paddle.reshape(feeds['label'], [-1, 1])
loss_func = build_loss(config["Loss"][mode])
loss_func = build_loss(config, mode)
loss_dict = loss_func(out, target)
loss_out = loss_dict["loss"]
......@@ -117,7 +117,7 @@ def create_fetchs(out,
# build metric
if not use_mix:
metric_func = build_metrics(config["Metric"][mode])
metric_func = build_metrics(config, mode)
metric_dict = metric_func(out, target)
......@@ -268,9 +268,9 @@ def build(config,
lr_scheduler = None
optimizer = None
if is_train:
optimizer, lr_scheduler = build_optimizer(
config["Optimizer"], config["Global"]["epochs"],
step_each_epoch)
optimizer, lr_scheduler = build_optimizer(config,
step_each_epoch)
optimizer = mixed_precision_optimizer(config, optimizer)
if is_distributed:
optimizer = dist_optimizer(config, optimizer)
......
......@@ -105,11 +105,9 @@ def main(args):
class_num = config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num})
train_dataloader = build_dataloader(
config["DataLoader"], "Train", device=device, use_dali=use_dali)
train_dataloader = build_dataloader(config, "Train")
if global_config["eval_during_train"]:
eval_dataloader = build_dataloader(
config["DataLoader"], "Eval", device=device, use_dali=use_dali)
eval_dataloader = build_dataloader(config, "Eval")
step_each_epoch = len(train_dataloader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册