From 73e2cde617df2a13ffe647d1fceb5ab4ec967f8b Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Mon, 27 Feb 2023 14:03:55 +0000 Subject: [PATCH] mv some attrs to __init__() --- ppcls/engine/engine.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index a0bd0376..cb43b6a9 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -50,6 +50,9 @@ class Engine(object): assert mode in ["train", "eval", "infer", "export"] self.mode = mode self.config = config + self.start_eval_epoch = self.config["Global"].get("start_eval_epoch", + 0) - 1 + self.epochs = self.config["Global"].get("epochs", 1) # set seed self._init_seed() @@ -113,8 +116,6 @@ class Engine(object): assert self.mode == "train" print_batch_step = self.config['Global']['print_batch_step'] save_interval = self.config["Global"]["save_interval"] - start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1 - epochs = self.config["Global"]["epochs"] best_metric = { "metric": -1.0, @@ -140,20 +141,20 @@ class Engine(object): # global iter counter self.global_step = 0 - for epoch_id in range(best_metric["epoch"] + 1, epochs + 1): + for epoch_id in range(best_metric["epoch"] + 1, self.epochs + 1): # for one epoch train self.train_epoch_func(self, epoch_id, print_batch_step) metric_msg = ", ".join( [self.output_info[key].avg_info for key in self.output_info]) - logger.info("[Train][Epoch {}/{}][Avg]{}".format(epoch_id, epochs, - metric_msg)) + logger.info("[Train][Epoch {}/{}][Avg]{}".format( + epoch_id, self.epochs, metric_msg)) self.output_info.clear() acc = 0.0 if self.config["Global"][ "eval_during_train"] and epoch_id % self.config["Global"][ - "eval_interval"] == 0 and epoch_id > start_eval_epoch: + "eval_interval"] == 0 and epoch_id > self.start_eval_epoch: acc = self.eval(epoch_id) # step lr (by epoch) according to given metric, such as acc @@ -469,7 +470,6 @@ class Engine(object): self.model_ema) if metric_info is not None: best_metric.update(metric_info) - return best_metric class ExportModel(TheseusLayer): -- GitLab