Commit a38e42f6 authored by gaotingquan, committed by Wei Shengyu

refactor: iter_per_epoch -> max_iter

Parent 284e2a67
@@ -204,6 +204,8 @@ class DataIterator(object):
         self.dataloader = dataloader
         self.use_dali = use_dali
         self.iterator = iter(dataloader)
+        self.max_iter = dataloader.max_iter
+        self.total_samples = dataloader.total_samples

     def get_batch(self):
         # fetch data batch from dataloader
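For context: DataIterator wraps a finite DataLoader so training code can keep pulling batches across epoch boundaries, and the two new attributes simply mirror the wrapped loader's bookkeeping. A minimal sketch of the wrapper under that assumption (the restart-on-StopIteration behavior in get_batch is typical of such wrappers, not quoted from this commit):

class DataIterator(object):
    def __init__(self, dataloader, use_dali=False):
        self.dataloader = dataloader
        self.use_dali = use_dali
        self.iterator = iter(dataloader)
        # mirror the loader's bookkeeping so callers only touch the wrapper
        self.max_iter = dataloader.max_iter
        self.total_samples = dataloader.total_samples

    def get_batch(self):
        # fetch data batch from dataloader; when the finite loader is
        # exhausted, restart it and continue (assumed behavior)
        try:
            batch = next(self.iterator)
        except StopIteration:
            if self.use_dali:
                self.dataloader.reset()
            self.iterator = iter(self.dataloader)
            batch = next(self.iterator)
        return batch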
@@ -234,7 +236,7 @@ def build_dataloader(engine):
         "epochs": engine.config["Global"]["epochs"]
     })

-    use_dali = engine.config['Global'].get("use_dali", False)
+    use_dali = engine.use_dali
     dataloader_dict = {
         "Train": None,
         "UnLabelTrain": None,
@@ -246,18 +248,15 @@ def build_dataloader(engine):
     if engine.mode == 'train':
         train_dataloader = build(
             engine.config["DataLoader"], "Train", use_dali, seed=None)
-        iter_per_epoch = len(train_dataloader) - 1 if platform.system(
-        ) == "Windows" else len(train_dataloader)
-        if engine.config["Global"].get("iter_per_epoch", None):
-            # TODO(gaotingquan): iter_per_epoch should be set in Dataloader.Train, not Global
-            # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
-            iter_per_epoch = engine.config["Global"].get("iter_per_epoch")
-        iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq
-        # engine.iter_per_epoch = iter_per_epoch
-        train_dataloader.iter_per_epoch = iter_per_epoch
+        if engine.config["DataLoader"]["Train"].get("max_iter", None):
+            # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
+            max_iter = engine.config["Train"].get("max_iter")
+        max_iter = train_dataloader.max_iter // engine.update_freq * engine.update_freq
+        train_dataloader.max_iter = max_iter
+        if engine.config["DataLoader"]["Train"].get("convert_iterator", True):
+            train_dataloader = DataIterator(train_dataloader, use_dali)
         dataloader_dict["Train"] = train_dataloader
-        # TODO(gaotingquan): set the iterator field in config, such as Dataloader.Train.convert_iterator=True
-        dataloader_dict["TrainIter"] = DataIterator(train_dataloader, use_dali)

     if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
         dataloader_dict["UnLabelTrain"] = build(
...
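The net effect of this hunk: the per-epoch iteration budget now lives on the dataloader as max_iter, it is rounded down to a multiple of update_freq so gradient accumulation always sees complete groups, and the DataIterator wrapping is controlled by a Dataloader.Train.convert_iterator switch instead of a separate "TrainIter" entry. A quick check of the rounding with made-up numbers:

update_freq = 4    # accumulate gradients over 4 micro-batches
loader_len = 1250  # batches yielded by the Train dataloader per epoch
max_iter = loader_len // update_freq * update_freq
print(max_iter)    # 1248 -- the trailing 2 batches are dropped so every
                   # optimizer step sees exactly 4 micro-batches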
@@ -72,6 +72,7 @@ class Engine(object):
         self.update_freq = self.config["Global"].get("update_freq", 1)

         # build dataloader
+        self.use_dali = self.config["Global"].get("use_dali", False)
         self.dataloader_dict = build_dataloader(self)

         # build loss
...
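Resolving the flag once in Engine.__init__ turns every nested engine.config["Global"].get("use_dali", False) lookup below into a plain attribute read. A reduced sketch of the pattern (only the config shape shown in this diff is assumed):

class Engine(object):
    def __init__(self, config):
        self.config = config
        # read the flag once; downstream code checks engine.use_dali
        self.use_dali = self.config["Global"].get("use_dali", False)

engine = Engine({"Global": {}})
print(engine.use_dali)  # False -- the default when the key is absent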
@@ -92,8 +92,7 @@ def classification_eval(engine, epoch_id=0):
             paddle.distributed.all_gather(pred_list, out)
             preds = paddle.concat(pred_list, 0)

-            if accum_samples > total_samples and not engine.config[
-                    "Global"].get("use_dali", False):
+            if accum_samples > total_samples and not engine.use_dali:
                 if isinstance(preds, list):
                     preds = [
                         pred[:total_samples + current_samples - accum_samples]
@@ -152,7 +151,7 @@ def classification_eval(engine, epoch_id=0):
                 epoch_id, iter_id, max_iter, metric_msg, time_msg, ips_msg))
             tic = time.time()

-    if engine.config["Global"].get("use_dali", False):
+    if engine.use_dali:
         engine.dataloader_dict["Eval"].reset()

     if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
...
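The trimming guarded by that condition exists because, under a distributed sampler, the last batch is padded so every rank contributes the same number of samples; after all_gather the accumulated count can therefore overshoot the dataset size. A worked example of the slice bound, with assumed numbers:

total_samples = 1000   # images in the eval set
current_samples = 64   # predictions gathered this step (8 ranks x batch 8)
accum_samples = 1024   # running total after the gather -- overshoots by 24
keep = total_samples + current_samples - accum_samples
print(keep)            # 40: keep the first 40 of the 64 gathered rows,
                       # discarding the 24 padded duplicates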
@@ -22,8 +22,8 @@ from ppcls.utils import profiler
 def regular_train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()

-    for iter_id in range(engine.dataloader_dict["Train"].iter_per_epoch):
-        batch = engine.dataloader_dict["TrainIter"].get_batch()
+    for iter_id in range(engine.dataloader_dict["Train"].max_iter):
+        batch = engine.dataloader_dict["Train"].get_batch()

         profiler.add_profiler_step(engine.config["profiler_options"])
         if iter_id == 5:
...
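With the rename, the loop bound and the batch source are the same "Train" object: the epoch is driven purely by the iterator wrapper. A stripped-down sketch of the resulting pattern (train_step is a hypothetical callable standing in for the forward/backward/update logic):

def train_one_epoch(loader, train_step):
    # loader: DataIterator-like object exposing max_iter and get_batch
    for iter_id in range(loader.max_iter):
        batch = loader.get_batch()
        loss = train_step(batch)
        if iter_id % 100 == 0:
            print(f"iter {iter_id}/{loader.max_iter}, loss {loss:.4f}")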
@@ -55,13 +55,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
         batch_size / trainer.time_info["batch_cost"].avg)
     eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
-               ) * trainer.dataloader_dict["Train"].iter_per_epoch - iter_id
+               ) * trainer.dataloader_dict["Train"].max_iter - iter_id
               ) * trainer.time_info["batch_cost"].avg
     eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
     logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
         epoch_id, trainer.config["Global"][
             "epochs"], iter_id, trainer.dataloader_dict["Train"]
-        .iter_per_epoch, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
+        .max_iter, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
     for i, lr in enumerate(trainer.lr_sch):
         logger.scaler(
...
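The ETA formula multiplies the remaining iteration count by the rolling average batch cost. A quick numeric check with assumed values:

import datetime

epochs, epoch_id = 100, 3      # training epoch 3 of 100
max_iter, iter_id = 1248, 200  # iteration 200 of 1248 in this epoch
batch_cost_avg = 0.25          # seconds per iteration (rolling average)

eta_sec = ((epochs - epoch_id + 1) * max_iter - iter_id) * batch_cost_avg
print(datetime.timedelta(seconds=int(eta_sec)))  # 8:28:46 remaining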
@@ -48,12 +48,12 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 def build_optimizer(engine):
     if engine.mode != "train":
         return None, None
-    config, iter_per_epoch, model_list = engine.config, engine.dataloader_dict[
-        "Train"].iter_per_epoch, [engine.mode, engine.train_loss_func]
+    config, max_iter, model_list = engine.config, engine.dataloader_dict[
+        "Train"].max_iter, [engine.model, engine.train_loss_func]
     optim_config = copy.deepcopy(config["Optimizer"])
     epochs = config["Global"]["epochs"]
-    update_freq = config["Global"].get("update_freq", 1)
-    step_each_epoch = iter_per_epoch // update_freq
+    update_freq = engine.update_freq
+    step_each_epoch = max_iter // update_freq
     if isinstance(optim_config, dict):
         # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
         optim_name = optim_config.pop("name")
...
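Because the LR scheduler advances once per optimizer step rather than once per micro-batch, the per-epoch step count divides max_iter by update_freq. Continuing the numbers from the earlier sketch:

max_iter = 1248         # micro-batches per epoch (already a multiple of 4)
update_freq = 4         # gradient-accumulation factor
step_each_epoch = max_iter // update_freq
print(step_each_epoch)  # 312 optimizer steps (and LR updates) per epoch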