From a1e840e0daa9f92cd48b0223a94ec226a3ff0500 Mon Sep 17 00:00:00 2001
From: Tingquan Gao <35441050@qq.com>
Date: Tue, 14 Mar 2023 16:16:40 +0800
Subject: [PATCH] Revert "refactor: iter_per_epoch -> max_iter"

This reverts commit a38e42f644fcba8b60a9672762211b6f7054b290.
---
 ppcls/data/__init__.py                    | 21 +++++++++++----------
 ppcls/engine/engine.py                    |  1 -
 ppcls/engine/evaluation/classification.py |  5 +++--
 ppcls/engine/train/regular_train_epoch.py |  4 ++--
 ppcls/engine/train/utils.py               |  4 ++--
 ppcls/optimizer/__init__.py               |  8 ++++----
 6 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index a964a831..5b470541 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -204,8 +204,6 @@ class DataIterator(object):
         self.dataloader = dataloader
         self.use_dali = use_dali
         self.iterator = iter(dataloader)
-        self.max_iter = dataloader.max_iter
-        self.total_samples = dataloader.total_samples

     def get_batch(self):
         # fetch data batch from dataloader
@@ -236,7 +234,7 @@ def build_dataloader(engine):
             "epochs": engine.config["Global"]["epochs"]
         })

-    use_dali = engine.use_dali
+    use_dali = engine.config['Global'].get("use_dali", False)
     dataloader_dict = {
         "Train": None,
         "UnLabelTrain": None,
@@ -248,15 +246,18 @@ def build_dataloader(engine):
     if engine.mode == 'train':
         train_dataloader = build(
             engine.config["DataLoader"], "Train", use_dali, seed=None)
-
-        if engine.config["DataLoader"]["Train"].get("max_iter", None):
+        iter_per_epoch = len(train_dataloader) - 1 if platform.system(
+        ) == "Windows" else len(train_dataloader)
+        if engine.config["Global"].get("iter_per_epoch", None):
+            # TODO(gaotingquan): iter_per_epoch should be set in Dataloader.Train, not Global
             # set max iteration per epoch manually when training by iteration(s), such as XBM, FixMatch.
- max_iter = engine.config["Train"].get("max_iter") - max_iter = train_dataloader.max_iter // engine.update_freq * engine.update_freq - train_dataloader.max_iter = max_iter - if engine.config["DataLoader"]["Train"].get("convert_iterator", True): - train_dataloader = DataIterator(train_dataloader, use_dali) + iter_per_epoch = engine.config["Global"].get("iter_per_epoch") + iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq + # engine.iter_per_epoch = iter_per_epoch + train_dataloader.iter_per_epoch = iter_per_epoch dataloader_dict["Train"] = train_dataloader + # TODO(gaotingquan): set the iterator field in config, such as Dataloader.Train.convert_iterator=True + dataloader_dict["TrainIter"] = DataIterator(train_dataloader, use_dali) if engine.config["DataLoader"].get('UnLabelTrain', None) is not None: dataloader_dict["UnLabelTrain"] = build( diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 307cc968..147cd3b8 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -72,7 +72,6 @@ class Engine(object): self.update_freq = self.config["Global"].get("update_freq", 1) # build dataloader - self.use_dali = self.config["Global"].get("use_dali", False) self.dataloader_dict = build_dataloader(self) # build loss diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index e37f0e2a..ada851c0 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -92,7 +92,8 @@ def classification_eval(engine, epoch_id=0): paddle.distributed.all_gather(pred_list, out) preds = paddle.concat(pred_list, 0) - if accum_samples > total_samples and not engine.use_dali: + if accum_samples > total_samples and not engine.config[ + "Global"].get("use_dali", False): if isinstance(preds, list): preds = [ pred[:total_samples + current_samples - accum_samples] @@ -151,7 +152,7 @@ def classification_eval(engine, epoch_id=0): epoch_id, iter_id, max_iter, metric_msg, time_msg, ips_msg)) tic = time.time() - if engine.use_dali: + if engine.config["Global"].get("use_dali", False): engine.dataloader_dict["Eval"].reset() if "ATTRMetric" in engine.config["Metric"]["Eval"][0]: diff --git a/ppcls/engine/train/regular_train_epoch.py b/ppcls/engine/train/regular_train_epoch.py index 78629396..ac2a0013 100644 --- a/ppcls/engine/train/regular_train_epoch.py +++ b/ppcls/engine/train/regular_train_epoch.py @@ -22,8 +22,8 @@ from ppcls.utils import profiler def regular_train_epoch(engine, epoch_id, print_batch_step): tic = time.time() - for iter_id in range(engine.dataloader_dict["Train"].max_iter): - batch = engine.dataloader_dict["Train"].get_batch() + for iter_id in range(engine.dataloader_dict["Train"].iter_per_epoch): + batch = engine.dataloader_dict["TrainIter"].get_batch() profiler.add_profiler_step(engine.config["profiler_options"]) if iter_id == 5: diff --git a/ppcls/engine/train/utils.py b/ppcls/engine/train/utils.py index c31084e8..ea68714b 100644 --- a/ppcls/engine/train/utils.py +++ b/ppcls/engine/train/utils.py @@ -55,13 +55,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id): batch_size / trainer.time_info["batch_cost"].avg) eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1 - ) * trainer.dataloader_dict["Train"].max_iter - iter_id + ) * trainer.dataloader_dict["Train"].iter_per_epoch - iter_id ) * trainer.time_info["batch_cost"].avg eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec)))) logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, 
{}".format( epoch_id, trainer.config["Global"][ "epochs"], iter_id, trainer.dataloader_dict["Train"] - .max_iter, lr_msg, metric_msg, time_msg, ips_msg, eta_msg)) + .iter_per_epoch, lr_msg, metric_msg, time_msg, ips_msg, eta_msg)) for i, lr in enumerate(trainer.lr_sch): logger.scaler( diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py index 559a340d..8e19325b 100644 --- a/ppcls/optimizer/__init__.py +++ b/ppcls/optimizer/__init__.py @@ -48,12 +48,12 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch): def build_optimizer(engine): if engine.mode != "train": return None, None - config, max_iter, model_list = engine.config, engine.dataloader_dict[ - "Train"].max_iter, [engine.model, engine.train_loss_func] + config, iter_per_epoch, model_list = engine.config, engine.dataloader_dict[ + "Train"].iter_per_epoch, [engine.mode, engine.train_loss_func] optim_config = copy.deepcopy(config["Optimizer"]) epochs = config["Global"]["epochs"] - update_freq = engine.update_freq - step_each_epoch = max_iter // update_freq + update_freq = config["Global"].get("update_freq", 1) + step_each_epoch = iter_per_epoch // update_freq if isinstance(optim_config, dict): # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}] optim_name = optim_config.pop("name") -- GitLab