Commit 03795249 authored by Tingquan Gao

Revert "revert for running"

This reverts commit d3374e89.
Parent 7353b073
@@ -207,3 +207,25 @@ def build_dataloader(config, mode, seed=None):
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
+
+
+# # TODO(gaotingquan): the length of dataloader should be determined by sampler
+# class DataIterator(object):
+#     def __init__(self, dataloader, use_dali=False):
+#         self.dataloader = dataloader
+#         self.use_dali = use_dali
+#         self.iterator = iter(dataloader)
+#         self.max_iter = dataloader.max_iter
+#         self.total_samples = dataloader.total_samples
+
+#     def get_batch(self):
+#         # fetch data batch from dataloader
+#         try:
+#             batch = next(self.iterator)
+#         except Exception:
+#             # NOTE: reset DALI dataloader manually
+#             if self.use_dali:
+#                 self.dataloader.reset()
+#             self.iterator = iter(self.dataloader)
+#             batch = next(self.iterator)
+#         return batch
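The hunk above only re-adds `DataIterator` as commented-out code (the TODO notes its length should eventually come from the sampler). As a reference for what that wrapper would do if it were enabled, here is a minimal, self-contained sketch; `DummyLoader` and its `reset()` method are illustrative stand-ins for a real (e.g. DALI) dataloader, not part of this diff.

```python
# Minimal sketch of the commented-out DataIterator: wrap a dataloader, hand out
# one batch per call, and restart iteration (with an explicit reset for DALI)
# once the underlying iterator is exhausted. DummyLoader is a stand-in.


class DummyLoader:
    def __init__(self, batches):
        self.batches = batches
        self.max_iter = len(batches)
        self.total_samples = sum(len(b) for b in batches)

    def __iter__(self):
        return iter(self.batches)

    def reset(self):
        # A DALI pipeline must be reset before it can be iterated again.
        pass


class DataIterator:
    def __init__(self, dataloader, use_dali=False):
        self.dataloader = dataloader
        self.use_dali = use_dali
        self.iterator = iter(dataloader)
        self.max_iter = dataloader.max_iter
        self.total_samples = dataloader.total_samples

    def get_batch(self):
        try:
            return next(self.iterator)
        except StopIteration:
            if self.use_dali:
                self.dataloader.reset()
            self.iterator = iter(self.dataloader)
            return next(self.iterator)


if __name__ == "__main__":
    it = DataIterator(DummyLoader([[0, 1], [2, 3]]))
    print([it.get_batch() for _ in range(5)])  # wraps around after two batches
```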
@@ -31,7 +31,7 @@ class ClassEval(object):
         self.model = model
         self.print_batch_step = self.config["Global"]["print_batch_step"]
         self.use_dali = self.config["Global"].get("use_dali", False)
-        self.eval_metric_func = build_metrics(self.config, "Eval")
+        self.eval_metric_func = build_metrics(self.config, "eval")
         self.eval_dataloader = build_dataloader(self.config, "Eval")
         self.eval_loss_func = build_loss(self.config, "Eval")
         self.output_info = dict()
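The only change in this hunk is the mode string passed to `build_metrics`: `"Eval"` becomes lowercase `"eval"`, while `build_dataloader` and `build_loss` keep `"Eval"`. The casing matters because the reworked `build_metrics` further down compares against lowercase values, so the old capitalized argument would match neither branch and the builder would fall through and return `None`. A tiny, hypothetical reduction of that check:

```python
# Reduced illustration of the new lowercase mode check in build_metrics: any
# other spelling silently falls through and produces no metric function.
def pick_metrics(mode):
    if mode == "train":
        return "train metrics"
    if mode == "eval":
        return "eval metrics"
    return None


print(pick_metrics("eval"))  # eval metrics
print(pick_metrics("Eval"))  # None -> evaluation would run without metrics
```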
@@ -48,13 +48,12 @@ class ClassTrainer(object):
         # build dataloader
         self.use_dali = self.config["Global"].get("use_dali", False)
         self.dataloader = build_dataloader(self.config, "Train")
-        self.dataloader_iter = iter(self.dataloader)
 
         # build loss
         self.loss_func = build_loss(config, "Train")
 
         # build metric
-        self.train_metric_func = build_metrics(config, "Train")
+        self.train_metric_func = build_metrics(config, "train")
 
         # build optimizer
         self.optimizer, self.lr_sch = build_optimizer(
@@ -175,17 +174,7 @@ class ClassTrainer(object):
         self.model.train()
 
         tic = time.time()
-        for iter_id in range(self.dataloader.max_iter):
-            # fetch data batch from dataloader
-            try:
-                batch = next(self.dataloader_iter)
-            except Exception:
-                # NOTE: reset DALI dataloader manually
-                if self.use_dali:
-                    self.dataloader.reset()
-                self.dataloader_iter = iter(self.dataloader)
-                batch = next(self.dataloader_iter)
-
+        for iter_id, batch in enumerate(self.dataloader):
             profiler.add_profiler_step(self.config["profiler_options"])
             if iter_id == 5:
                 for key in self.time_info:
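Together with the removal of `self.dataloader_iter` above, the training-loop hunk replaces the manual `next()`/restart pattern with a plain `enumerate` over the dataloader, so one epoch simply consumes the loader once. A minimal sketch of the two patterns, using a plain list of batches as a stand-in for the real dataloader:

```python
# Before: drive the iterator by hand for a fixed number of steps, restarting it
# (plus an explicit DALI reset) if it runs out early.
# After: let the for-loop consume the dataloader once per epoch.
loader = [[0, 1], [2, 3], [4, 5]]  # stand-in list of batches

# old pattern (simplified)
loader_iter = iter(loader)
max_iter = len(loader)
for iter_id in range(max_iter):
    try:
        batch = next(loader_iter)
    except StopIteration:
        loader_iter = iter(loader)  # a DALI loader would also need .reset()
        batch = next(loader_iter)
    print("old", iter_id, batch)

# new pattern
for iter_id, batch in enumerate(loader):
    print("new", iter_id, batch)
```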
@@ -66,7 +66,7 @@ class CombinedMetrics(AvgMetrics):
 
 
 def build_metrics(config, mode):
-    if mode == 'Train' and "Metric" in config and "Train" in config[
+    if mode == 'train' and "Metric" in config and "Train" in config[
             "Metric"] and config["Metric"]["Train"]:
         metric_config = config["Metric"]["Train"]
         if config["DataLoader"]["Train"]["dataset"].get("batch_transform_ops",
@@ -76,17 +76,22 @@ def build_metrics(config, mode):
                     msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
                     logger.warning(msg)
                     metric_config.pop(m_idx)
-        return CombinedMetrics(copy.deepcopy(metric_config))
+        train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
+        return train_metric_func
 
-    if mode == "Eval":
-        task = config["Global"].get("task", "classification")
-        assert task in ["classification", "retrieval"]
-        if task == "classification":
+    if mode == "eval" or (mode == "train" and
+                          config["Global"]["eval_during_train"]):
+        eval_mode = config["Global"].get("eval_mode", "classification")
+        if eval_mode == "classification":
             if "Metric" in config and "Eval" in config["Metric"]:
-                return CombinedMetrics(copy.deepcopy(config["Metric"]["Eval"]))
-        elif task == "retrieval":
+                eval_metric_func = CombinedMetrics(
+                    copy.deepcopy(config["Metric"]["Eval"]))
+            else:
+                eval_metric_func = None
+        elif eval_mode == "retrieval":
             if "Metric" in config and "Eval" in config["Metric"]:
                 metric_config = config["Metric"]["Eval"]
             else:
                 metric_config = [{"name": "Recallk", "topk": (1, 5)}]
-        return CombinedMetrics(copy.deepcopy(metric_config))
+            eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
+        return eval_metric_func
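After this hunk, `build_metrics` branches on lowercase mode strings and on `Global.eval_mode` / `Global.eval_during_train` instead of the old `Global.task`, and it returns `None` when no matching metric config exists. Below is a simplified, self-contained sketch of the resulting control flow; it omits the `batch_transform_ops` filtering shown above, and `CombinedMetrics` is reduced to a stub.

```python
import copy


class CombinedMetrics:
    """Stub standing in for the real CombinedMetrics; it only keeps its config."""

    def __init__(self, config_list):
        self.config_list = config_list


def build_metrics(config, mode):
    # Train metrics: only built for lowercase "train" with a Train metric config.
    if mode == "train" and config.get("Metric", {}).get("Train"):
        return CombinedMetrics(copy.deepcopy(config["Metric"]["Train"]))

    # Eval metrics: built for "eval", or during training when
    # Global.eval_during_train is enabled.
    if mode == "eval" or (mode == "train" and
                          config["Global"]["eval_during_train"]):
        eval_mode = config["Global"].get("eval_mode", "classification")
        if eval_mode == "classification":
            eval_cfg = config.get("Metric", {}).get("Eval")
            return CombinedMetrics(copy.deepcopy(eval_cfg)) if eval_cfg else None
        elif eval_mode == "retrieval":
            eval_cfg = config.get("Metric", {}).get(
                "Eval", [{"name": "Recallk", "topk": (1, 5)}])
            return CombinedMetrics(copy.deepcopy(eval_cfg))
    return None


config = {
    "Global": {"eval_during_train": True, "eval_mode": "retrieval"},
    "Metric": {"Train": [{"name": "TopkAcc", "topk": [1, 5]}]},
}
print(build_metrics(config, "train").config_list)  # the Train metric config
print(build_metrics(config, "eval").config_list)   # default Recallk config
```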