Commit daaa72a6 authored by Hui Zhang

resume training with epoch and iteration increase

Parent 3432de43
@@ -183,15 +183,7 @@ class U2Trainer(Trainer):
        # script_model_path = str(self.checkpoint_dir / 'init')
        # paddle.jit.save(script_model, script_model_path)
        from_scratch = self.resume_or_scratch()
        if from_scratch:
            # save init model, i.e. 0 epoch
            self.save(tag='init', infos=None)
        # lr will restore from optimizer ckpt
        # self.lr_scheduler.step(self.iteration)
        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
            self.train_loader.batch_sampler.set_epoch(self.epoch)
        self.before_train()
        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.training.n_epoch:
......
@@ -184,11 +184,7 @@ class U2Trainer(Trainer):
        # script_model_path = str(self.checkpoint_dir / 'init')
        # paddle.jit.save(script_model, script_model_path)
        from_scratch = self.resume_or_scratch()
        if from_scratch:
            # save init model, i.e. 0 epoch
            self.save(tag='init')
        self.lr_scheduler.step(self.iteration)
        self.before_train()
        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.training.n_epoch:
......
@@ -198,14 +198,7 @@ class U2STTrainer(Trainer):
        # script_model_path = str(self.checkpoint_dir / 'init')
        # paddle.jit.save(script_model, script_model_path)
        from_scratch = self.resume_or_scratch()
        if from_scratch:
            # save init model, i.e. 0 epoch
            self.save(tag='init')
        self.lr_scheduler.step(self.iteration)
        if self.parallel:
            self.train_loader.batch_sampler.set_epoch(self.epoch)
        self.before_train()
        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.training.n_epoch:
......
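The three hunks above make the same change in each trainer: the duplicated resume-or-scratch prologue at the top of train() collapses into a single call into the base class. A minimal sketch of the resulting prologue, assuming only names visible in this diff (before_train() itself is defined in the Trainer hunks below; the surrounding trainer is simplified):

    # Sketch only: the real U2Trainer / U2STTrainer train() carries more setup.
    def train(self):
        # scratch: saves the 'init' checkpoint; resume: advances epoch/iteration
        # and reseeds the batch sampler (see Trainer.before_train() below)
        self.before_train()
        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.training.n_epoch:
            ...  # per-epoch loop, unchanged by this commit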
@@ -179,7 +179,8 @@ class Trainer():
            checkpoint_dir=self.checkpoint_dir,
            checkpoint_path=self.args.checkpoint_path)
        if infos:
            # restore from ckpt
            # just restore ckpt
            # lr will restore from optimizer ckpt
            self.iteration = infos["step"]
            self.epoch = infos["epoch"]
            scratch = False
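For orientation, a hedged sketch of the metadata this branch consumes: resume_or_scratch() expects the checkpoint loader to return a dict holding at least 'step' and 'epoch'. The 'val_loss' key mirrors what save() writes later in this diff; whether it round-trips here is an assumption, and the values are illustrative only.

    # Assumed shape of the checkpoint infos dict (values made up)
    infos = {"step": 12345, "epoch": 4, "val_loss": 3.21}
    if infos:
        # lr is restored from the optimizer checkpoint, not from infos
        self.iteration = infos["step"]
        self.epoch = infos["epoch"]
        scratch = False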
@@ -190,14 +191,31 @@ class Trainer():
            logger.info("Restore/Init checkpoint!")
        return scratch

    def maybe_batch_sampler_step(self):
        """Seed the batch_sampler by epoch."""
        if hasattr(self.train_loader, "batch_sampler"):
            batch_sampler = self.train_loader.batch_sampler
            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
                batch_sampler.set_epoch(self.epoch)
    def before_train(self):
        from_scratch = self.resume_or_scratch()
        if from_scratch:
            # scratch: save init model, i.e. 0 epoch
            self.save(tag='init', infos=None)
        else:
            # resume: train next_epoch and next_iteration
            self.epoch += 1
            self.iteration += 1
        self.maybe_batch_sampler_step()
    def new_epoch(self):
        """Reset the train loader seed and increment `epoch`."""
        # `iteration` increased by train step
        self.epoch += 1
        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
            batch_sampler = self.train_loader.batch_sampler
            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
                batch_sampler.set_epoch(self.epoch)
        self.maybe_batch_sampler_step()

    def after_train_batch(self):
        if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step:
@@ -209,15 +227,7 @@ class Trainer():

    def train(self):
        """The training process, controlled by epoch."""
        from_scratch = self.resume_or_scratch()
        if from_scratch:
            # save init model, i.e. 0 epoch
            self.save(tag='init', infos=None)
        # lr will restore from optimizer ckpt
        # self.lr_scheduler.step(self.epoch)
        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
            self.train_loader.batch_sampler.set_epoch(self.epoch)
        self.before_train()
        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.training.n_epoch:
@@ -275,6 +285,7 @@ class Trainer():
                'epoch', {'cv_loss': cv_loss,
                          'lr': self.lr_scheduler()}, self.epoch)
            # after epoch
            self.save(tag=self.epoch, infos={'val_loss': cv_loss})
            # step lr every epoch
            self.lr_scheduler.step()
......
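Putting the pieces together, a hypothetical driver shows how the refactored hooks compose over one run. train_batch() and valid() stand in for the unchanged loop body and are illustrative names only; the counter behaviour follows the diff: a run resumed from a checkpoint saved at epoch 4 / step 12345 re-enters the loop at epoch 5 / step 12346 instead of repeating the saved epoch.

    # Sketch under the assumptions above; only the hook names come from this diff.
    def train(self):
        # scratch: save 'init'; resume: self.epoch += 1, self.iteration += 1
        self.before_train()
        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
        while self.epoch < self.config.training.n_epoch:
            for batch in self.train_loader:
                self.train_batch(batch)      # hypothetical: increments self.iteration
                self.after_train_batch()     # may stop early at benchmark_max_step
            cv_loss = self.valid()           # hypothetical validation helper
            # after epoch: checkpoint with metadata, then step lr every epoch
            self.save(tag=self.epoch, infos={'val_loss': cv_loss})
            self.lr_scheduler.step()
            self.new_epoch()                 # epoch += 1 and reseed the batch sampler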