Commit 284e2a67 authored by gaotingquan, committed by Wei Shengyu

refactor: mv all dataloaders to engine.dataloader_dict

Parent efe0d45c
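This commit collapses the engine's per-role loader attributes (train_dataloader, eval_dataloader, gallery_query_dataloader, ...) into a single engine.dataloader_dict keyed by role ("Train", "Eval", "Gallery", ...). A minimal sketch of the access-pattern change, using made-up stand-in classes rather than the real Engine:

class EngineBefore:
    def __init__(self, loaders):
        # one attribute per role; every consumer hard-codes an attribute name
        self.train_dataloader = loaders["Train"]
        self.eval_dataloader = loaders["Eval"]

class EngineAfter:
    def __init__(self, loaders):
        # single source of truth; consumers look loaders up by role key
        self.dataloader_dict = loaders

loaders = {"Train": ["batch_0", "batch_1"], "Eval": ["batch_0"]}
engine = EngineAfter(loaders)
assert engine.dataloader_dict["Train"] is loaders["Train"]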
@@ -187,10 +187,37 @@ def build(config, mode, use_dali=False, seed=None):
         collate_fn=batch_collate_fn,
         worker_init_fn=init_fn)
+    total_samples = len(
+        data_loader.dataset) if not use_dali else data_loader.size
+    max_iter = len(data_loader) - 1 if platform.system() == "Windows" else len(
+        data_loader)
+    data_loader.max_iter = max_iter
+    data_loader.total_samples = total_samples
+
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
 
 
+# TODO(gaotingquan): perf
+class DataIterator(object):
+    def __init__(self, dataloader, use_dali=False):
+        self.dataloader = dataloader
+        self.use_dali = use_dali
+        self.iterator = iter(dataloader)
+
+    def get_batch(self):
+        # fetch data batch from dataloader
+        try:
+            batch = next(self.iterator)
+        except Exception:
+            # NOTE: reset DALI dataloader manually
+            if self.use_dali:
+                self.dataloader.reset()
+            self.iterator = iter(self.dataloader)
+            batch = next(self.iterator)
+        return batch
+
+
 def build_dataloader(engine):
     if "class_num" in engine.config["Global"]:
         global_class_num = engine.config["Global"]["class_num"]
@@ -222,12 +249,15 @@ def build_dataloader(engine):
         iter_per_epoch = len(train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(train_dataloader)
         if engine.config["Global"].get("iter_per_epoch", None):
+            # TODO(gaotingquan): iter_per_epoch should be set in Dataloader.Train, not Global
            # set max iteration per epoch manually, when training by iteration(s), such as XBM, FixMatch.
             iter_per_epoch = engine.config["Global"].get("iter_per_epoch")
         iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq
-        engine.iter_per_epoch = iter_per_epoch
+        # engine.iter_per_epoch = iter_per_epoch
         train_dataloader.iter_per_epoch = iter_per_epoch
         dataloader_dict["Train"] = train_dataloader
+        # TODO(gaotingquan): set the iterator field in config, such as Dataloader.Train.convert_iterator=True
+        dataloader_dict["TrainIter"] = DataIterator(train_dataloader, use_dali)
 
     if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
         dataloader_dict["UnLabelTrain"] = build(
@@ -249,5 +279,4 @@ def build_dataloader(engine):
             engine.config["DataLoader"]["Eval"], "Gallery", use_dali)
         dataloader_dict["Query"] = build(
             engine.config["DataLoader"]["Eval"], "Query", use_dali)
-
     return dataloader_dict
...
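build() now stamps total_samples and max_iter onto every loader it returns, so downstream code (e.g. classification_eval below) reads them instead of recomputing. A runnable sketch of that bookkeeping under assumed names (_Loader is a hypothetical stand-in for paddle.io.DataLoader; DALI pipelines report their length via .size rather than len(dataset)):

import platform

class _Loader:
    """Hypothetical stand-in for paddle.io.DataLoader."""

    def __init__(self, dataset, num_batches):
        self.dataset = dataset
        self._num_batches = num_batches

    def __len__(self):
        return self._num_batches

def attach_stats(loader, use_dali=False):
    loader.total_samples = len(loader.dataset) if not use_dali else loader.size
    # drop the trailing batch on Windows, mirroring the workaround above
    loader.max_iter = len(loader) - 1 if platform.system() == "Windows" else len(loader)
    return loader

dl = attach_stats(_Loader(dataset=list(range(100)), num_batches=10))
print(dl.total_samples, dl.max_iter)  # 100 10 (100 9 on Windows)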
@@ -63,7 +63,7 @@ class Engine(object):
         # init train_func and eval_func
         self.train_epoch_func = build_train_epoch_func(self.config)
-        self.eval_epoch_func = build_eval_func(self.config)
+        self.eval_func = build_eval_func(self.config)
 
         # set device
         self._init_device()
@@ -73,12 +73,6 @@ class Engine(object):
         # build dataloader
         self.dataloader_dict = build_dataloader(self)
-        self.train_dataloader, self.unlabel_train_dataloader, self.eval_dataloader = self.dataloader_dict[
-            "Train"], self.dataloader_dict[
-                "UnLabelTrain"], self.dataloader_dict["Eval"]
-        self.gallery_query_dataloader, self.gallery_dataloader, self.query_dataloader = self.dataloader_dict[
-            "GalleryQuery"], self.dataloader_dict[
-                "Gallery"], self.dataloader_dict["Query"]
 
         # build loss
         self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(
@@ -94,9 +88,7 @@ class Engine(object):
         self._init_pretrained()
 
         # build optimizer
-        self.optimizer, self.lr_sch = build_optimizer(
-            self.config, self.train_dataloader,
-            [self.model, self.train_loss_func])
+        self.optimizer, self.lr_sch = build_optimizer(self)
 
         # AMP training and evaluating
         self._init_amp()
...
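With the loader attributes gone, Engine.__init__ also stops threading config/dataloader/model into build_optimizer; the builder now receives the engine itself and pulls what it needs, returning (None, None) outside train mode. A pared-down sketch (all names invented) of that inversion:

def build_optimizer(engine):
    if engine.mode != "train":  # eval/infer runs need no optimizer
        return None, None
    steps = engine.dataloader_dict["Train"].iter_per_epoch
    return f"optimizer(steps={steps})", f"lr_sch(steps={steps})"

class _Loader:
    iter_per_epoch = 50

class _Engine:
    def __init__(self, mode):
        self.mode = mode
        self.dataloader_dict = {"Train": _Loader()}
        self.optimizer, self.lr_sch = build_optimizer(self)

assert _Engine("eval").optimizer is None
print(_Engine("train").optimizer)  # optimizer(steps=50)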
@@ -35,13 +35,10 @@ def classification_eval(engine, epoch_id=0):
     print_batch_step = engine.config["Global"]["print_batch_step"]
 
     tic = time.time()
+    total_samples = engine.dataloader_dict["Eval"].total_samples
     accum_samples = 0
-    total_samples = len(
-        engine.eval_dataloader.
-        dataset) if not engine.use_dali else engine.eval_dataloader.size
-    max_iter = len(engine.eval_dataloader) - 1 if platform.system(
-    ) == "Windows" else len(engine.eval_dataloader)
-    for iter_id, batch in enumerate(engine.eval_dataloader):
+    max_iter = engine.dataloader_dict["Eval"].max_iter
+    for iter_id, batch in enumerate(engine.dataloader_dict["Eval"]):
         if iter_id >= max_iter:
             break
         if iter_id == 5:
@@ -61,9 +58,9 @@ def classification_eval(engine, epoch_id=0):
                     "flatten_contiguous_range", "greater_than"
                 },
                 level=engine.amp_level):
-            out = engine.model(batch[0])
+            out = engine.model(batch)
         else:
-            out = engine.model(batch[0])
+            out = engine.model(batch)
 
         # just for DistributedBatchSampler issue: repeat sampling
         current_samples = batch_size * paddle.distributed.get_world_size()
@@ -95,7 +92,8 @@ def classification_eval(engine, epoch_id=0):
                 paddle.distributed.all_gather(pred_list, out)
                 preds = paddle.concat(pred_list, 0)
 
-            if accum_samples > total_samples and not engine.use_dali:
+            if accum_samples > total_samples and not engine.config[
+                    "Global"].get("use_dali", False):
                 if isinstance(preds, list):
                     preds = [
                         pred[:total_samples + current_samples - accum_samples]
@@ -151,12 +149,11 @@ def classification_eval(engine, epoch_id=0):
             ])
             metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
             logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
-                epoch_id, iter_id,
-                len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))
+                epoch_id, iter_id, max_iter, metric_msg, time_msg, ips_msg))
             tic = time.time()
 
-    if engine.use_dali:
-        engine.eval_dataloader.reset()
+    if engine.config["Global"].get("use_dali", False):
+        engine.dataloader_dict["Eval"].reset()
 
     if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
         metric_msg = ", ".join([
...
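The repeat-sampling trim kept above exists because DistributedBatchSampler pads the last batch so all ranks do equal work; after all_gather the collected predictions can overshoot the dataset, and the tail is sliced off via total_samples + current_samples - accum_samples. The same arithmetic with toy numbers:

total_samples = 10                         # real dataset size
world_size, batch_size = 4, 3
current_samples = batch_size * world_size  # 12 predictions gathered this step
accum_samples = current_samples            # running total after the final step
preds = list(range(current_samples))       # stand-in for the gathered outputs
if accum_samples > total_samples:
    preds = preds[:total_samples + current_samples - accum_samples]
assert len(preds) == total_samples         # padded duplicates removed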
@@ -22,19 +22,8 @@ from ppcls.utils import profiler
 
 def regular_train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()
-    if not hasattr(engine, "train_dataloader_iter"):
-        engine.train_dataloader_iter = iter(engine.train_dataloader)
-
-    for iter_id in range(engine.iter_per_epoch):
-        # fetch data batch from dataloader
-        try:
-            batch = next(engine.train_dataloader_iter)
-        except Exception:
-            # NOTE: reset DALI dataloader manually
-            if engine.use_dali:
-                engine.train_dataloader.reset()
-            engine.train_dataloader_iter = iter(engine.train_dataloader)
-            batch = next(engine.train_dataloader_iter)
+    for iter_id in range(engine.dataloader_dict["Train"].iter_per_epoch):
+        batch = engine.dataloader_dict["TrainIter"].get_batch()
 
         profiler.add_profiler_step(engine.config["profiler_options"])
         if iter_id == 5:
...
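regular_train_epoch now paces the epoch by the loader's iter_per_epoch and pulls batches through the shared DataIterator, whose get_batch() restarts the underlying iterator when a pass over the data is exhausted (needed when iter_per_epoch spans dataset boundaries, as with XBM or FixMatch). A pared-down version of that restart logic, omitting the DALI reset:

class DataIterator:
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(dataloader)

    def get_batch(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # dataset exhausted mid-epoch: start a fresh pass
            self.iterator = iter(self.dataloader)
            return next(self.iterator)

loader = [0, 1, 2]                          # three batches per pass
it = DataIterator(loader)
print([it.get_batch() for _ in range(5)])   # [0, 1, 2, 0, 1] (wraps around)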
@@ -54,13 +54,14 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
     ips_msg = "ips: {:.5f} samples/s".format(
         batch_size / trainer.time_info["batch_cost"].avg)
-    eta_sec = (
-        (trainer.config["Global"]["epochs"] - epoch_id + 1) *
-        trainer.iter_per_epoch - iter_id) * trainer.time_info["batch_cost"].avg
+    eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
+                ) * trainer.dataloader_dict["Train"].iter_per_epoch - iter_id
+               ) * trainer.time_info["batch_cost"].avg
     eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
     logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
-        epoch_id, trainer.config["Global"]["epochs"], iter_id, trainer.
-        iter_per_epoch, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
+        epoch_id, trainer.config["Global"][
+            "epochs"], iter_id, trainer.dataloader_dict["Train"]
+        .iter_per_epoch, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
 
     for i, lr in enumerate(trainer.lr_sch):
         logger.scaler(
...
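The ETA now reads iter_per_epoch from the Train loader rather than the trainer; the formula itself is unchanged: remaining iterations times the running average batch cost. Worked with illustrative numbers (the batch cost is assumed):

import datetime

epochs, epoch_id, iter_per_epoch, iter_id = 100, 3, 500, 120
batch_cost_avg = 0.25  # seconds per iteration (assumed)
eta_sec = ((epochs - epoch_id + 1) * iter_per_epoch - iter_id) * batch_cost_avg
print(str(datetime.timedelta(seconds=int(eta_sec))))  # 3:23:40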
@@ -70,8 +70,9 @@ def build_metrics(engine):
     if mode == 'train' and "Metric" in config and "Train" in config[
             "Metric"] and config["Metric"]["Train"]:
         metric_config = config["Metric"]["Train"]
-        if hasattr(engine.train_dataloader, "collate_fn"
-                   ) and engine.train_dataloader.collate_fn is not None:
+        if hasattr(engine.dataloader_dict["Train"],
+                   "collate_fn") and engine.dataloader_dict[
+                       "Train"].collate_fn is not None:
             for m_idx, m in enumerate(metric_config):
                 if "TopkAcc" in m:
                     msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
...
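The guard now inspects dataloader_dict["Train"]: a batch-level collate_fn (i.e. batch_transform_ops such as mixup) yields mixed labels for which per-sample top-k accuracy is undefined, so TopkAcc is removed from the train metrics with a warning. Its effect, sketched on a toy config (OtherMetric is a placeholder name):

metric_config = [{"TopkAcc": {"topk": [1, 5]}}, {"OtherMetric": {}}]
has_batch_transform = True  # stands in for loader.collate_fn is not None
if has_batch_transform:
    # membership test on a dict checks its keys
    metric_config = [m for m in metric_config if "TopkAcc" not in m]
print(metric_config)  # [{'OtherMetric': {}}]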
@@ -45,11 +45,15 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 
 # model_list is None in static graph
-def build_optimizer(config, dataloader, model_list=None):
+def build_optimizer(engine):
+    if engine.mode != "train":
+        return None, None
+    config, iter_per_epoch, model_list = engine.config, engine.dataloader_dict[
+        "Train"].iter_per_epoch, [engine.model, engine.train_loss_func]
     optim_config = copy.deepcopy(config["Optimizer"])
     epochs = config["Global"]["epochs"]
     update_freq = config["Global"].get("update_freq", 1)
-    step_each_epoch = dataloader.iter_per_epoch // update_freq
+    step_each_epoch = iter_per_epoch // update_freq
 
     if isinstance(optim_config, dict):
         # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
         optim_name = optim_config.pop("name")
...
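step_each_epoch feeds the LR scheduler and is derived from the loader's iter_per_epoch: build_dataloader first truncates iter_per_epoch to a multiple of update_freq, and with gradient accumulation the optimizer then steps once every update_freq iterations. Toy numbers:

iter_per_epoch, update_freq = 1002, 4
iter_per_epoch = iter_per_epoch // update_freq * update_freq  # 1002 -> 1000
step_each_epoch = iter_per_epoch // update_freq               # 250 scheduler steps per epoch
print(step_each_epoch)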