diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index ce6b4919825b0a9d8669c9d66c2bebd0808313d1..5b470541e51dcf545d5679a1b7f7357acd284d18 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -187,10 +187,37 @@ def build(config, mode, use_dali=False, seed=None):
         collate_fn=batch_collate_fn,
         worker_init_fn=init_fn)
 
+    total_samples = len(
+        data_loader.dataset) if not use_dali else data_loader.size
+    max_iter = len(data_loader) - 1 if platform.system() == "Windows" else len(
+        data_loader)
+    data_loader.max_iter = max_iter
+    data_loader.total_samples = total_samples
+
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
 
 
+# TODO(gaotingquan): perf
+class DataIterator(object):
+    def __init__(self, dataloader, use_dali=False):
+        self.dataloader = dataloader
+        self.use_dali = use_dali
+        self.iterator = iter(dataloader)
+
+    def get_batch(self):
+        # fetch data batch from dataloader
+        try:
+            batch = next(self.iterator)
+        except Exception:
+            # NOTE: reset DALI dataloader manually
+            if self.use_dali:
+                self.dataloader.reset()
+            self.iterator = iter(self.dataloader)
+            batch = next(self.iterator)
+        return batch
+
+
 def build_dataloader(engine):
     if "class_num" in engine.config["Global"]:
         global_class_num = engine.config["Global"]["class_num"]
@@ -222,12 +249,15 @@ def build_dataloader(engine):
         iter_per_epoch = len(train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(train_dataloader)
         if engine.config["Global"].get("iter_per_epoch", None):
+            # TODO(gaotingquan): iter_per_epoch should be set in Dataloader.Train, not Global
             # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
             iter_per_epoch = engine.config["Global"].get("iter_per_epoch")
         iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq
-        engine.iter_per_epoch = iter_per_epoch
+        # engine.iter_per_epoch = iter_per_epoch
         train_dataloader.iter_per_epoch = iter_per_epoch
         dataloader_dict["Train"] = train_dataloader
+        # TODO(gaotingquan): set the iterator field in config, such as Dataloader.Train.convert_iterator=True
+        dataloader_dict["TrainIter"] = DataIterator(train_dataloader, use_dali)
 
     if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
         dataloader_dict["UnLabelTrain"] = build(
@@ -249,5 +279,4 @@ def build_dataloader(engine):
             engine.config["DataLoader"]["Eval"], "Gallery", use_dali)
         dataloader_dict["Query"] = build(
             engine.config["DataLoader"]["Eval"], "Query", use_dali)
-
     return dataloader_dict
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index a11f00ed9b77c9224ee51e57fb3e3b1a14a10588..147cd3b802c6fbc374e102066b87061edd867ca8 100755
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -63,7 +63,7 @@ class Engine(object):
 
         # init train_func and eval_func
         self.train_epoch_func = build_train_epoch_func(self.config)
-        self.eval_epoch_func = build_eval_func(self.config)
+        self.eval_func = build_eval_func(self.config)
 
         # set device
         self._init_device()
@@ -73,12 +73,6 @@ class Engine(object):
 
         # build dataloader
         self.dataloader_dict = build_dataloader(self)
-        self.train_dataloader, self.unlabel_train_dataloader, self.eval_dataloader = self.dataloader_dict[
-            "Train"], self.dataloader_dict[
-                "UnLabelTrain"], self.dataloader_dict["Eval"]
-        self.gallery_query_dataloader, self.gallery_dataloader, self.query_dataloader = self.dataloader_dict[
-            "GalleryQuery"], self.dataloader_dict[
-                "Gallery"], self.dataloader_dict["Query"]
 
         # build loss
         self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
@@ -94,9 +88,7 @@ class Engine(object):
         self._init_pretrained()
 
         # build optimizer
-        self.optimizer, self.lr_sch = build_optimizer(
-            self.config, self.train_dataloader,
-            [self.model, self.train_loss_func])
+        self.optimizer, self.lr_sch = build_optimizer(self)
 
         # AMP training and evaluating
         self._init_amp()
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index 637b54f8cb7844d3dcb7e4d73231b35123b9c2bc..ada851c097cb1e7c4c4fb3d2f67a42440d54bf3b 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -35,13 +35,10 @@ def classification_eval(engine, epoch_id=0):
     print_batch_step = engine.config["Global"]["print_batch_step"]
 
     tic = time.time()
+    total_samples = engine.dataloader_dict["Eval"].total_samples
     accum_samples = 0
-    total_samples = len(
-        engine.eval_dataloader.
-        dataset) if not engine.use_dali else engine.eval_dataloader.size
-    max_iter = len(engine.eval_dataloader) - 1 if platform.system(
-    ) == "Windows" else len(engine.eval_dataloader)
-    for iter_id, batch in enumerate(engine.eval_dataloader):
+    max_iter = engine.dataloader_dict["Eval"].max_iter
+    for iter_id, batch in enumerate(engine.dataloader_dict["Eval"]):
         if iter_id >= max_iter:
             break
         if iter_id == 5:
@@ -61,9 +58,9 @@ def classification_eval(engine, epoch_id=0):
                         "flatten_contiguous_range", "greater_than"
                     },
                     level=engine.amp_level):
-                out = engine.model(batch[0])
+                out = engine.model(batch)
         else:
-            out = engine.model(batch[0])
+            out = engine.model(batch)
 
         # just for DistributedBatchSampler issue: repeat sampling
         current_samples = batch_size * paddle.distributed.get_world_size()
@@ -95,7 +92,8 @@ def classification_eval(engine, epoch_id=0):
                 paddle.distributed.all_gather(pred_list, out)
                 preds = paddle.concat(pred_list, 0)
 
-            if accum_samples > total_samples and not engine.use_dali:
+            if accum_samples > total_samples and not engine.config[
+                    "Global"].get("use_dali", False):
                 if isinstance(preds, list):
                     preds = [
                         pred[:total_samples + current_samples - accum_samples]
@@ -151,12 +149,11 @@ def classification_eval(engine, epoch_id=0):
             ])
             metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
             logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
-                epoch_id, iter_id,
-                len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))
+                epoch_id, iter_id, max_iter, metric_msg, time_msg, ips_msg))
             tic = time.time()
 
-    if engine.use_dali:
-        engine.eval_dataloader.reset()
+    if engine.config["Global"].get("use_dali", False):
+        engine.dataloader_dict["Eval"].reset()
 
     if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
         metric_msg = ", ".join([
diff --git a/ppcls/engine/train/regular_train_epoch.py b/ppcls/engine/train/regular_train_epoch.py
index f5933d74b88f9d517ced4287d8b240810caf6c3a..ac2a00134cc540af022bb3a05283a5ed9a3f67d2 100644
--- a/ppcls/engine/train/regular_train_epoch.py
+++ b/ppcls/engine/train/regular_train_epoch.py
@@ -22,19 +22,8 @@ from ppcls.utils import profiler
 
 def regular_train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()
-    if not hasattr(engine, "train_dataloader_iter"):
-        engine.train_dataloader_iter = iter(engine.train_dataloader)
-
-    for iter_id in range(engine.iter_per_epoch):
-        # fetch data batch from dataloader
-        try:
-            batch = next(engine.train_dataloader_iter)
-        except Exception:
-            # NOTE: reset DALI dataloader manually
-            if engine.use_dali:
-                engine.train_dataloader.reset()
-            engine.train_dataloader_iter = iter(engine.train_dataloader)
-            batch = next(engine.train_dataloader_iter)
+    for iter_id in range(engine.dataloader_dict["Train"].iter_per_epoch):
+        batch = engine.dataloader_dict["TrainIter"].get_batch()
 
         profiler.add_profiler_step(engine.config["profiler_options"])
         if iter_id == 5:
diff --git a/ppcls/engine/train/utils.py b/ppcls/engine/train/utils.py
index 091a64d8326699100538ceb238ca454e3192b1af..ea68714bf03f40dc2e655aae9ac7508c91543110 100644
--- a/ppcls/engine/train/utils.py
+++ b/ppcls/engine/train/utils.py
@@ -54,13 +54,14 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
     ips_msg = "ips: {:.5f} samples/s".format(
         batch_size / trainer.time_info["batch_cost"].avg)
 
-    eta_sec = (
-        (trainer.config["Global"]["epochs"] - epoch_id + 1) *
-        trainer.iter_per_epoch - iter_id) * trainer.time_info["batch_cost"].avg
+    eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
+                ) * trainer.dataloader_dict["Train"].iter_per_epoch - iter_id
+               ) * trainer.time_info["batch_cost"].avg
     eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
     logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
-        epoch_id, trainer.config["Global"]["epochs"], iter_id, trainer.
-        iter_per_epoch, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
+        epoch_id, trainer.config["Global"][
+            "epochs"], iter_id, trainer.dataloader_dict["Train"]
+        .iter_per_epoch, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
 
     for i, lr in enumerate(trainer.lr_sch):
         logger.scaler(
diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py
index 77a3e48ef6de511521899fd467ca916528bfb991..55a3f345ace90d10e7214c8dc0d1446289b88fc9 100644
--- a/ppcls/metric/__init__.py
+++ b/ppcls/metric/__init__.py
@@ -70,8 +70,9 @@ def build_metrics(engine):
     if mode == 'train' and "Metric" in config and "Train" in config[
             "Metric"] and config["Metric"]["Train"]:
         metric_config = config["Metric"]["Train"]
-        if hasattr(engine.train_dataloader, "collate_fn"
-                   ) and engine.train_dataloader.collate_fn is not None:
+        if hasattr(engine.dataloader_dict["Train"],
+                   "collate_fn") and engine.dataloader_dict[
+                       "Train"].collate_fn is not None:
             for m_idx, m in enumerate(metric_config):
                 if "TopkAcc" in m:
                     msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py
index 782a80d688bb09b415af8e514f64cd4ca58fd1d3..8e19325b7ef8da799f65c8df89ece7bfd9a32ffc 100644
--- a/ppcls/optimizer/__init__.py
+++ b/ppcls/optimizer/__init__.py
@@ -45,11 +45,15 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 
 
 # model_list is None in static graph
-def build_optimizer(config, dataloader, model_list=None):
+def build_optimizer(engine):
+    if engine.mode != "train":
+        return None, None
+    config, iter_per_epoch, model_list = engine.config, engine.dataloader_dict[
+        "Train"].iter_per_epoch, [engine.model, engine.train_loss_func]
     optim_config = copy.deepcopy(config["Optimizer"])
     epochs = config["Global"]["epochs"]
     update_freq = config["Global"].get("update_freq", 1)
-    step_each_epoch = dataloader.iter_per_epoch // update_freq
+    step_each_epoch = iter_per_epoch // update_freq
     if isinstance(optim_config, dict):
         # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
         optim_name = optim_config.pop("name")
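
For reference, the snippet below is a minimal, standalone sketch of the iteration pattern this patch introduces: the `DataIterator` wrapper added in `ppcls/data/__init__.py` hands out one batch per `get_batch()` call and silently restarts the underlying loader when it is exhausted, which is what lets `regular_train_epoch` run a fixed `iter_per_epoch` independently of the dataset length. The list-based `fake_loader` and the `__main__` driver are illustrative stand-ins only, and the sketch catches `StopIteration` rather than the broader `Exception` used in the patch (which also has to cover DALI's manual reset).

```python
# Standalone sketch (not part of the patch): mimics the DataIterator added in
# ppcls/data/__init__.py. A plain list stands in for the real (paddle.io or
# DALI) dataloader so the example runs on its own.


class DataIterator(object):
    def __init__(self, dataloader, use_dali=False):
        self.dataloader = dataloader
        self.use_dali = use_dali
        self.iterator = iter(dataloader)

    def get_batch(self):
        # Fetch the next batch; when the loader is exhausted, restart it so
        # iteration-based training can keep asking for more batches.
        try:
            batch = next(self.iterator)
        except StopIteration:
            if self.use_dali:
                # DALI pipelines must be reset manually before re-iteration.
                self.dataloader.reset()
            self.iterator = iter(self.dataloader)
            batch = next(self.iterator)
        return batch


if __name__ == "__main__":
    fake_loader = [[0, 1], [2, 3], [4, 5]]  # three "batches"
    train_iter = DataIterator(fake_loader)

    # iter_per_epoch may exceed len(dataloader); the iterator simply wraps
    # around, decoupling epoch length from dataset size (as in XBM, FixMatch).
    for iter_id in range(5):
        print(iter_id, train_iter.get_batch())
```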