提交 8873ebe3 编写于 作者: H Hui Zhang

add timer for u2; refactor grad norm type

上级 890a28f9
......@@ -34,6 +34,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2 import U2Model
from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
......@@ -184,40 +185,42 @@ class U2Trainer(Trainer):
self.save(tag='init')
self.lr_scheduler.step(self.iteration)
if self.parallel:
if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
self.train_loader.batch_sampler.set_epoch(self.epoch)
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
self.model.train()
try:
data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch : {}/{}, ".format(batch_index + 1,
len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
try:
data_start_time = time.time()
except Exception as e:
logger.error(e)
raise e
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)
# the default operator in all_reduce function is sum.
dist.all_reduce(num_seen_utts)
total_loss = paddle.to_tensor(total_loss)
dist.all_reduce(total_loss)
cv_loss = total_loss / num_seen_utts
cv_loss = float(cv_loss)
else:
cv_loss = total_loss / num_seen_utts
for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "batch : {}/{}, ".format(batch_index + 1,
len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
data_start_time = time.time()
except Exception as e:
logger.error(e)
raise e
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)
# the default operator in all_reduce function is sum.
dist.all_reduce(num_seen_utts)
total_loss = paddle.to_tensor(total_loss)
dist.all_reduce(total_loss)
cv_loss = total_loss / num_seen_utts
cv_loss = float(cv_loss)
else:
cv_loss = total_loss / num_seen_utts
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
......
......@@ -36,16 +36,16 @@ class CTCLoss(nn.Layer):
f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
# instance for norm_by_times
# batchsize for norm_by_batchsize
# batch for norm_by_batchsize
# frame for norm_by_total_logits_len
assert grad_norm_type in ('instance', 'batchsize', 'frame', None)
assert grad_norm_type in ('instance', 'batch', 'frame', None)
self.norm_by_times = False
self.norm_by_batchsize = False
self.norm_by_total_logits_len = False
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
if grad_norm_type == 'instance':
self.norm_by_times = True
if grad_norm_type == 'batchsize':
if grad_norm_type == 'batch':
self.norm_by_times = True
if grad_norm_type == 'frame':
self.norm_by_total_logits_len = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册