diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index 9915dc2c40536599c7d72a2618b00fed5916f99b..5bc21e15444c51e90218c2a9a97ff8c2bd6b912f 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -207,25 +207,3 @@ def build_dataloader(config, mode, seed=None):
 
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
-
-
-# # TODO(gaotingquan): the length of dataloader should be determined by sampler
-# class DataIterator(object):
-#     def __init__(self, dataloader, use_dali=False):
-#         self.dataloader = dataloader
-#         self.use_dali = use_dali
-#         self.iterator = iter(dataloader)
-#         self.max_iter = dataloader.max_iter
-#         self.total_samples = dataloader.total_samples
-
-#     def get_batch(self):
-#         # fetch data batch from dataloader
-#         try:
-#             batch = next(self.iterator)
-#         except Exception:
-#             # NOTE: reset DALI dataloader manually
-#             if self.use_dali:
-#                 self.dataloader.reset()
-#             self.iterator = iter(self.dataloader)
-#             batch = next(self.iterator)
-#         return batch
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index 07f1703f8acdd29b6a69680e366a9824a731a101..a35ff1ea63d3f4a7c78497af3788f5ac4e9a1722 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -31,7 +31,7 @@ class ClassEval(object):
         self.model = model
         self.print_batch_step = self.config["Global"]["print_batch_step"]
         self.use_dali = self.config["Global"].get("use_dali", False)
-        self.eval_metric_func = build_metrics(self.config, "eval")
+        self.eval_metric_func = build_metrics(self.config, "Eval")
         self.eval_dataloader = build_dataloader(self.config, "Eval")
         self.eval_loss_func = build_loss(self.config, "Eval")
         self.output_info = dict()
diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py
index d6a5c8228fa6acff2b923361613b4650219f5fa3..735beaccc4d8c307d3ce831f55f3d96bd5d03626 100644
--- a/ppcls/engine/train/classification.py
+++ b/ppcls/engine/train/classification.py
@@ -48,12 +48,13 @@ class ClassTrainer(object):
         # build dataloader
         self.use_dali = self.config["Global"].get("use_dali", False)
         self.dataloader = build_dataloader(self.config, "Train")
+        self.dataloader_iter = iter(self.dataloader)
 
         # build loss
         self.loss_func = build_loss(config, "Train")
 
         # build metric
-        self.train_metric_func = build_metrics(config, "train")
+        self.train_metric_func = build_metrics(config, "Train")
 
         # build optimizer
         self.optimizer, self.lr_sch = build_optimizer(
@@ -174,7 +175,17 @@ class ClassTrainer(object):
         self.model.train()
 
         tic = time.time()
-        for iter_id, batch in enumerate(self.dataloader):
+        for iter_id in range(self.dataloader.max_iter):
+            # fetch data batch from dataloader
+            try:
+                batch = next(self.dataloader_iter)
+            except Exception:
+                # NOTE: reset DALI dataloader manually
+                if self.use_dali:
+                    self.dataloader.reset()
+                self.dataloader_iter = iter(self.dataloader)
+                batch = next(self.dataloader_iter)
+
             profiler.add_profiler_step(self.config["profiler_options"])
             if iter_id == 5:
                 for key in self.time_info:
diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py
index b614c116efe3ad55e094cdbc3c0c24ad567634e9..48b37153ac0040ee5b09ec2ed3e82cf5024c7e77 100644
--- a/ppcls/metric/__init__.py
+++ b/ppcls/metric/__init__.py
@@ -66,7 +66,7 @@ class CombinedMetrics(AvgMetrics):
 
 
 def build_metrics(config, mode):
-    if mode == 'train' and "Metric" in config and "Train" in config[
+    if mode == 'Train' and "Metric" in config and "Train" in config[
             "Metric"] and config["Metric"]["Train"]:
         metric_config = config["Metric"]["Train"]
         if config["DataLoader"]["Train"]["dataset"].get("batch_transform_ops",
@@ -76,22 +76,17 @@ def build_metrics(config, mode):
                     msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
                     logger.warning(msg)
                     metric_config.pop(m_idx)
-        train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
-        return train_metric_func
+        return CombinedMetrics(copy.deepcopy(metric_config))
 
-    if mode == "eval" or (mode == "train" and
-                          config["Global"]["eval_during_train"]):
-        eval_mode = config["Global"].get("eval_mode", "classification")
-        if eval_mode == "classification":
+    if mode == "Eval":
+        task = config["Global"].get("task", "classification")
+        assert task in ["classification", "retrieval"]
+        if task == "classification":
             if "Metric" in config and "Eval" in config["Metric"]:
-                eval_metric_func = CombinedMetrics(
-                    copy.deepcopy(config["Metric"]["Eval"]))
-            else:
-                eval_metric_func = None
-        elif eval_mode == "retrieval":
+                return CombinedMetrics(copy.deepcopy(config["Metric"]["Eval"]))
+        elif task == "retrieval":
             if "Metric" in config and "Eval" in config["Metric"]:
                 metric_config = config["Metric"]["Eval"]
             else:
                 metric_config = [{"name": "Recallk", "topk": (1, 5)}]
-            eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
-            return eval_metric_func
+            return CombinedMetrics(copy.deepcopy(metric_config))