From c30df630356604fe0846de769d92a04d0130af61 Mon Sep 17 00:00:00 2001
From: gaotingquan
Date: Fri, 10 Mar 2023 03:25:04 +0000
Subject: [PATCH] support Static

---
 ppcls/engine/train/classification.py |  2 +-
 ppcls/optimizer/__init__.py          |  2 +-
 ppcls/static/program.py              | 10 +++++-----
 ppcls/static/train.py                |  6 ++----
 4 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py
index 5ed04e2a..9074a6b4 100644
--- a/ppcls/engine/train/classification.py
+++ b/ppcls/engine/train/classification.py
@@ -65,7 +65,7 @@ class ClassTrainer(object):
         # build optimizer
         self.optimizer, self.lr_sch = build_optimizer(
             self.config, self.dataloader.max_iter,
-            [self.model, self.loss_func], self.update_freq)
+            [self.model, self.loss_func])
 
         # build model saver
         self.model_saver = ModelSaver(
diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py
index ccccd6f3..609636ea 100644
--- a/ppcls/optimizer/__init__.py
+++ b/ppcls/optimizer/__init__.py
@@ -44,7 +44,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 
 
 # model_list is None in static graph
-def build_optimizer(config, max_iter, model_list, update_freq):
+def build_optimizer(config, max_iter, model_list=None):
     optim_config = copy.deepcopy(config["Optimizer"])
     epochs = config["Global"]["epochs"]
     update_freq = config["Global"].get("update_freq", 1)
diff --git a/ppcls/static/program.py b/ppcls/static/program.py
index 188393d1..2464fef1 100644
--- a/ppcls/static/program.py
+++ b/ppcls/static/program.py
@@ -109,7 +109,7 @@ def create_fetchs(out,
     else:
         target = paddle.reshape(feeds['label'], [-1, 1])
 
-    loss_func = build_loss(config["Loss"][mode])
+    loss_func = build_loss(config, mode)
     loss_dict = loss_func(out, target)
 
     loss_out = loss_dict["loss"]
@@ -117,7 +117,7 @@ def create_fetchs(out,
 
     # build metric
     if not use_mix:
-        metric_func = build_metrics(config["Metric"][mode])
+        metric_func = build_metrics(config, mode)
 
         metric_dict = metric_func(out, target)
 
@@ -268,9 +268,9 @@ def build(config,
             lr_scheduler = None
             optimizer = None
             if is_train:
-                optimizer, lr_scheduler = build_optimizer(
-                    config["Optimizer"], config["Global"]["epochs"],
-                    step_each_epoch)
+                optimizer, lr_scheduler = build_optimizer(config,
+                                                          step_each_epoch)
+
                 optimizer = mixed_precision_optimizer(config, optimizer)
                 if is_distributed:
                     optimizer = dist_optimizer(config, optimizer)
diff --git a/ppcls/static/train.py b/ppcls/static/train.py
index 53566267..8255ab8e 100755
--- a/ppcls/static/train.py
+++ b/ppcls/static/train.py
@@ -105,11 +105,9 @@ def main(args):
 
     class_num = config["Arch"].get("class_num", None)
     config["DataLoader"].update({"class_num": class_num})
 
-    train_dataloader = build_dataloader(
-        config["DataLoader"], "Train", device=device, use_dali=use_dali)
+    train_dataloader = build_dataloader(config, "Train")
     if global_config["eval_during_train"]:
-        eval_dataloader = build_dataloader(
-            config["DataLoader"], "Eval", device=device, use_dali=use_dali)
+        eval_dataloader = build_dataloader(config, "Eval")
 
     step_each_epoch = len(train_dataloader)
-- 
GitLab