From d3374e897e162053d93a20c21142135c3e7ee11c Mon Sep 17 00:00:00 2001
From: gaotingquan
Date: Wed, 8 Mar 2023 08:12:09 +0000
Subject: [PATCH] revert for running

---
 ppcls/data/__init__.py                    | 22 ----------------------
 ppcls/engine/evaluation/classification.py |  2 +-
 ppcls/engine/train/classification.py      | 15 +++++++++++++--
 ppcls/metric/__init__.py                  | 23 +++++++++--------------
 4 files changed, 23 insertions(+), 39 deletions(-)

diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index 9915dc2c..5bc21e15 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -207,25 +207,3 @@ def build_dataloader(config, mode, seed=None):
     logger.debug("build data_loader({}) success...".format(data_loader))
 
     return data_loader
-
-
-# # TODO(gaotingquan): the length of dataloader should be determined by sampler
-# class DataIterator(object):
-#     def __init__(self, dataloader, use_dali=False):
-#         self.dataloader = dataloader
-#         self.use_dali = use_dali
-#         self.iterator = iter(dataloader)
-#         self.max_iter = dataloader.max_iter
-#         self.total_samples = dataloader.total_samples
-
-#     def get_batch(self):
-#         # fetch data batch from dataloader
-#         try:
-#             batch = next(self.iterator)
-#         except Exception:
-#             # NOTE: reset DALI dataloader manually
-#             if self.use_dali:
-#                 self.dataloader.reset()
-#             self.iterator = iter(self.dataloader)
-#             batch = next(self.iterator)
-#         return batch
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index 07f1703f..a35ff1ea 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -31,7 +31,7 @@ class ClassEval(object):
         self.model = model
         self.print_batch_step = self.config["Global"]["print_batch_step"]
         self.use_dali = self.config["Global"].get("use_dali", False)
-        self.eval_metric_func = build_metrics(self.config, "eval")
+        self.eval_metric_func = build_metrics(self.config, "Eval")
         self.eval_dataloader = build_dataloader(self.config, "Eval")
         self.eval_loss_func = build_loss(self.config, "Eval")
         self.output_info = dict()
diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py
index d6a5c822..735beacc 100644
--- a/ppcls/engine/train/classification.py
+++ b/ppcls/engine/train/classification.py
@@ -48,12 +48,13 @@ class ClassTrainer(object):
         # build dataloader
         self.use_dali = self.config["Global"].get("use_dali", False)
         self.dataloader = build_dataloader(self.config, "Train")
+        self.dataloader_iter = iter(self.dataloader)
 
         # build loss
         self.loss_func = build_loss(config, "Train")
 
         # build metric
-        self.train_metric_func = build_metrics(config, "train")
+        self.train_metric_func = build_metrics(config, "Train")
 
         # build optimizer
         self.optimizer, self.lr_sch = build_optimizer(
@@ -174,7 +175,17 @@ class ClassTrainer(object):
         self.model.train()
         tic = time.time()
 
-        for iter_id, batch in enumerate(self.dataloader):
+        for iter_id in range(self.dataloader.max_iter):
+            # fetch data batch from dataloader
+            try:
+                batch = next(self.dataloader_iter)
+            except Exception:
+                # NOTE: reset DALI dataloader manually
+                if self.use_dali:
+                    self.dataloader.reset()
+                self.dataloader_iter = iter(self.dataloader)
+                batch = next(self.dataloader_iter)
+
             profiler.add_profiler_step(self.config["profiler_options"])
             if iter_id == 5:
                 for key in self.time_info:
diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py
index b614c116..48b37153 100644
--- a/ppcls/metric/__init__.py
+++ b/ppcls/metric/__init__.py
@@ -66,7 +66,7 @@ class CombinedMetrics(AvgMetrics):
 
 
 def build_metrics(config, mode):
-    if mode == 'train' and "Metric" in config and "Train" in config[
+    if mode == 'Train' and "Metric" in config and "Train" in config[
             "Metric"] and config["Metric"]["Train"]:
         metric_config = config["Metric"]["Train"]
         if config["DataLoader"]["Train"]["dataset"].get("batch_transform_ops",
@@ -76,22 +76,17 @@ def build_metrics(config, mode):
                     msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
                     logger.warning(msg)
                     metric_config.pop(m_idx)
-        train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
-        return train_metric_func
+        return CombinedMetrics(copy.deepcopy(metric_config))
 
-    if mode == "eval" or (mode == "train" and
-                          config["Global"]["eval_during_train"]):
-        eval_mode = config["Global"].get("eval_mode", "classification")
-        if eval_mode == "classification":
+    if mode == "Eval":
+        task = config["Global"].get("task", "classification")
+        assert task in ["classification", "retrieval"]
+        if task == "classification":
             if "Metric" in config and "Eval" in config["Metric"]:
-                eval_metric_func = CombinedMetrics(
-                    copy.deepcopy(config["Metric"]["Eval"]))
-            else:
-                eval_metric_func = None
-        elif eval_mode == "retrieval":
+                return CombinedMetrics(copy.deepcopy(config["Metric"]["Eval"]))
+        elif task == "retrieval":
             if "Metric" in config and "Eval" in config["Metric"]:
                 metric_config = config["Metric"]["Eval"]
             else:
                 metric_config = [{"name": "Recallk", "topk": (1, 5)}]
-        eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
-        return eval_metric_func
+        return CombinedMetrics(copy.deepcopy(metric_config))
-- 
GitLab
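
Note for reviewers: below is a minimal, self-contained sketch of the fetch-with-reset pattern that this patch moves from the deleted commented-out DataIterator into ClassTrainer.train_epoch. ToyLoader and run_epoch are hypothetical names invented for illustration, not PaddleClas code; only the attributes the patch relies on (max_iter, re-iteration, a manual reset() for DALI) are modeled. The sketch catches StopIteration, whereas the patch catches a broad Exception, presumably because DALI loaders may raise their own error type on exhaustion.

    # Minimal sketch (assumed names, not PaddleClas code) of driving an epoch by
    # max_iter and restarting an exhausted iterator instead of ending the epoch.
    class ToyLoader:
        def __init__(self, samples, max_iter):
            self.samples = samples
            # epoch length decided by sampler/config, not by len(samples)
            self.max_iter = max_iter

        def __iter__(self):
            return iter(self.samples)

        def reset(self):
            # Real DALI pipelines must be reset before re-iteration;
            # nothing to do for this toy loader.
            pass

    def run_epoch(loader, use_dali=False):
        loader_iter = iter(loader)
        for iter_id in range(loader.max_iter):
            try:
                batch = next(loader_iter)
            except StopIteration:
                # NOTE: reset DALI dataloader manually, then restart iteration
                if use_dali:
                    loader.reset()
                loader_iter = iter(loader)
                batch = next(loader_iter)
            print(f"iter {iter_id}: batch {batch}")

    # Five iterations from a three-sample loader: the iterator wraps around.
    run_epoch(ToyLoader(samples=["a", "b", "c"], max_iter=5))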