Unverified · Commit ecb5d4f8 · Author: Hui Zhang · Committed by: GitHub

Merge pull request #847 from PaddlePaddle/resume_train

resume train with epoch and iteration increase
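In short: a checkpoint's `epoch`/`step` record work already completed, so resuming now advances both counters instead of re-running the saved epoch and step; fresh runs still save an `init` checkpoint at epoch 0. A minimal sketch of the new bookkeeping (the `resume_counters` helper is illustrative only; the `infos` keys follow `resume_or_scratch()` in the diff below):

```python
# Minimal sketch of the resume bookkeeping this PR introduces.
# `infos` mirrors the dict resume_or_scratch() reads from a checkpoint;
# the function itself is illustrative, not part of the Trainer.

def resume_counters(infos):
    """Return the (epoch, iteration) to continue training from."""
    if infos:
        # the checkpoint was saved after finishing (epoch, step),
        # so training continues at the next epoch and iteration
        return infos["epoch"] + 1, infos["step"] + 1
    # from scratch: nothing restored yet
    return 0, 0

print(resume_counters({"epoch": 9, "step": 27599}))  # (10, 27600)
print(resume_counters(None))                         # (0, 0)
```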
@@ -183,15 +183,7 @@ class U2Trainer(Trainer):
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-        # lr will restore from optimizer ckpt
-        # self.lr_scheduler.step(self.iteration)
-        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
@@ -207,8 +199,8 @@ class U2Trainer(Trainer):
                     report("Rank", dist.get_rank())
                     report("epoch", self.epoch)
                     report('step', self.iteration)
-                    report('step/total',
-                           (batch_index + 1) / len(self.train_loader))
+                    report('iter', batch_index + 1)
+                    report('total', len(self.train_loader))
                     report("lr", self.lr_scheduler())
                 self.train_batch(batch_index, batch, msg)
                 self.after_train_batch()
......
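The same hunk also changes the per-batch progress report from one fused `step/total` ratio to separate `iter` and `total` fields (mirrored for `Trainer` further down). A toy illustration, with an invented `fmt()` helper standing in for the real log formatting:

```python
# Toy illustration of the log-field change; fmt() is invented for this
# sketch -- only the reported keys and values mirror the diff.
def fmt(fields):
    return " ".join(f"{k}: {v}" for k, v in fields.items())

batch_index, n_batches = 41, 500

# before: one fused ratio, e.g. "step/total: 0.084"
print(fmt({"step/total": (batch_index + 1) / n_batches}))
# after: position and epoch size reported separately, e.g. "iter: 42 total: 500"
print(fmt({"iter": batch_index + 1, "total": n_batches}))
```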
@@ -184,11 +184,7 @@ class U2Trainer(Trainer):
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init')
-        self.lr_scheduler.step(self.iteration)
+        self.before_train()
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
......
@@ -198,14 +198,7 @@ class U2STTrainer(Trainer):
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init')
-        self.lr_scheduler.step(self.iteration)
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
......
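All three trainers drop their explicit `self.lr_scheduler.step(self.iteration)` on startup; per the `# lr will restore from optimizer ckpt` comments, the scheduler position now comes back with the optimizer checkpoint. A minimal sketch of that mechanism, assuming Paddle 2.x's `state_dict()`/`set_state_dict()` behavior for optimizers built on an `LRScheduler`:

```python
import paddle

# Sketch of why the explicit lr_scheduler.step() on resume can go away.
# Assumption: in Paddle 2.x, an optimizer built with an LRScheduler stores
# the scheduler state under "LR_Scheduler" in its state_dict(), so restoring
# the optimizer checkpoint restores the learning-rate position as well.
model = paddle.nn.Linear(4, 4)
sched = paddle.optimizer.lr.ExponentialDecay(learning_rate=1e-3, gamma=0.9)
opt = paddle.optimizer.Adam(learning_rate=sched, parameters=model.parameters())

sched.step()                    # pretend one epoch has already run
state = opt.state_dict()        # carries the scheduler state as well
assert "LR_Scheduler" in state

# a fresh optimizer/scheduler pair, as after a process restart
sched2 = paddle.optimizer.lr.ExponentialDecay(learning_rate=1e-3, gamma=0.9)
opt2 = paddle.optimizer.Adam(learning_rate=sched2, parameters=model.parameters())
opt2.set_state_dict(state)      # lr schedule resumes where it left off
assert sched2.last_epoch == sched.last_epoch
```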
@@ -179,25 +179,47 @@ class Trainer():
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_path=self.args.checkpoint_path)
         if infos:
-            # restore from ckpt
+            # just restore ckpt
+            # lr will restore from optimizer ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
             scratch = False
+            logger.info(
+                f"Restore ckpt: epoch {self.epoch}, step {self.iteration}!")
         else:
             self.iteration = 0
             self.epoch = 0
             scratch = True
-        logger.info("Restore/Init checkpoint!")
+            logger.info("Init from scratch!")
         return scratch
 
+    def maybe_batch_sampler_step(self):
+        """ batch_sampler seed by epoch """
+        if hasattr(self.train_loader, "batch_sampler"):
+            batch_sampler = self.train_loader.batch_sampler
+            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
+                batch_sampler.set_epoch(self.epoch)
+
+    def before_train(self):
+        from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # scratch: save init model, i.e. 0 epoch
+            self.save(tag='init', infos=None)
+        else:
+            # resume: train next_epoch and next_iteration
+            self.epoch += 1
+            self.iteration += 1
+            logger.info(
+                f"Resume train: epoch {self.epoch}, step {self.iteration}!")
+        self.maybe_batch_sampler_step()
+
     def new_epoch(self):
         """Reset the train loader seed and increment `epoch`.
         """
+        # `iteration` increased by train step
         self.epoch += 1
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            batch_sampler = self.train_loader.batch_sampler
-            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
-                batch_sampler.set_epoch(self.epoch)
+        self.maybe_batch_sampler_step()
 
     def after_train_batch(self):
         if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step:
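The new `maybe_batch_sampler_step()` centralizes the `set_epoch` call that `train()` and `new_epoch()` previously duplicated. `DistributedBatchSampler.set_epoch` seeds the shuffle, which is what makes a resumed run shuffle exactly as an uninterrupted one would. A minimal single-process sketch of that property (the `Ints` dataset is invented for the example):

```python
from paddle.io import Dataset, DistributedBatchSampler

# Standalone sketch of what maybe_batch_sampler_step() relies on:
# set_epoch(n) seeds the sampler's shuffle, so a given epoch number always
# yields the same order. Single process here, i.e. one replica.

class Ints(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx

sampler = DistributedBatchSampler(Ints(), batch_size=4, shuffle=True)

sampler.set_epoch(3)
order_a = list(sampler)
sampler.set_epoch(3)            # re-seed with the same epoch
order_b = list(sampler)
assert order_a == order_b       # resuming at epoch 3 replays its shuffle
```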
@@ -209,15 +231,7 @@ class Trainer():
     def train(self):
         """The training process control by epoch."""
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-        # lr will restore from optimizer ckpt
-        # self.lr_scheduler.step(self.epoch)
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
@@ -233,8 +247,8 @@ class Trainer():
                     report("Rank", dist.get_rank())
                     report("epoch", self.epoch)
                     report('step', self.iteration)
-                    report('step/total',
-                           (batch_index + 1) / len(self.train_loader))
+                    report('iter', batch_index + 1)
+                    report('total', len(self.train_loader))
                     report("lr", self.lr_scheduler())
                 self.train_batch(batch_index, batch, msg)
                 self.after_train_batch()
self.after_train_batch()
@@ -275,6 +289,7 @@ class Trainer():
                 'epoch', {'cv_loss': cv_loss,
                           'lr': self.lr_scheduler()}, self.epoch)
 
+        # after epoch
         self.save(tag=self.epoch, infos={'val_loss': cv_loss})
         # step lr every epoch
         self.lr_scheduler.step()
@@ -288,7 +303,6 @@ class Trainer():
         try:
             self.train()
         except KeyboardInterrupt:
-            self.save()
             exit(-1)
         finally:
             self.destory()
......