提交 d3374e89 编写于 作者: G gaotingquan 提交者: Wei Shengyu

revert for running

上级 d3b7690f
......@@ -207,25 +207,3 @@ def build_dataloader(config, mode, seed=None):
logger.debug("build data_loader({}) success...".format(data_loader))
return data_loader
# # TODO(gaotingquan): the length of dataloader should be determined by sampler
# class DataIterator(object):
# def __init__(self, dataloader, use_dali=False):
# self.dataloader = dataloader
# self.use_dali = use_dali
# self.iterator = iter(dataloader)
# self.max_iter = dataloader.max_iter
# self.total_samples = dataloader.total_samples
# def get_batch(self):
# # fetch data batch from dataloader
# try:
# batch = next(self.iterator)
# except Exception:
# # NOTE: reset DALI dataloader manually
# if self.use_dali:
# self.dataloader.reset()
# self.iterator = iter(self.dataloader)
# batch = next(self.iterator)
# return batch
......@@ -31,7 +31,7 @@ class ClassEval(object):
self.model = model
self.print_batch_step = self.config["Global"]["print_batch_step"]
self.use_dali = self.config["Global"].get("use_dali", False)
self.eval_metric_func = build_metrics(self.config, "eval")
self.eval_metric_func = build_metrics(self.config, "Eval")
self.eval_dataloader = build_dataloader(self.config, "Eval")
self.eval_loss_func = build_loss(self.config, "Eval")
self.output_info = dict()
......
......@@ -48,12 +48,13 @@ class ClassTrainer(object):
# build dataloader
self.use_dali = self.config["Global"].get("use_dali", False)
self.dataloader = build_dataloader(self.config, "Train")
self.dataloader_iter = iter(self.dataloader)
# build loss
self.loss_func = build_loss(config, "Train")
# build metric
self.train_metric_func = build_metrics(config, "train")
self.train_metric_func = build_metrics(config, "Train")
# build optimizer
self.optimizer, self.lr_sch = build_optimizer(
......@@ -174,7 +175,17 @@ class ClassTrainer(object):
self.model.train()
tic = time.time()
for iter_id, batch in enumerate(self.dataloader):
for iter_id in range(self.dataloader.max_iter):
# fetch data batch from dataloader
try:
batch = next(self.dataloader_iter)
except Exception:
# NOTE: reset DALI dataloader manually
if self.use_dali:
self.dataloader.reset()
self.dataloader_iter = iter(self.dataloader)
batch = next(self.dataloader_iter)
profiler.add_profiler_step(self.config["profiler_options"])
if iter_id == 5:
for key in self.time_info:
......
......@@ -66,7 +66,7 @@ class CombinedMetrics(AvgMetrics):
def build_metrics(config, mode):
if mode == 'train' and "Metric" in config and "Train" in config[
if mode == 'Train' and "Metric" in config and "Train" in config[
"Metric"] and config["Metric"]["Train"]:
metric_config = config["Metric"]["Train"]
if config["DataLoader"]["Train"]["dataset"].get("batch_transform_ops",
......@@ -76,22 +76,17 @@ def build_metrics(config, mode):
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
logger.warning(msg)
metric_config.pop(m_idx)
train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
return train_metric_func
return CombinedMetrics(copy.deepcopy(metric_config))
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
eval_mode = config["Global"].get("eval_mode", "classification")
if eval_mode == "classification":
if mode == "Eval":
task = config["Global"].get("task", "classification")
assert task in ["classification", "retrieval"]
if task == "classification":
if "Metric" in config and "Eval" in config["Metric"]:
eval_metric_func = CombinedMetrics(
copy.deepcopy(config["Metric"]["Eval"]))
else:
eval_metric_func = None
elif eval_mode == "retrieval":
return CombinedMetrics(copy.deepcopy(config["Metric"]["Eval"]))
elif task == "retrieval":
if "Metric" in config and "Eval" in config["Metric"]:
metric_config = config["Metric"]["Eval"]
else:
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
return eval_metric_func
return CombinedMetrics(copy.deepcopy(metric_config))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册