Commit 8873ebe3 authored by: H Hui Zhang

add timer for u2; refactor grad norm type

Parent 890a28f9
@@ -34,6 +34,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
 from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
@@ -184,11 +185,12 @@ class U2Trainer(Trainer):
             self.save(tag='init')
         self.lr_scheduler.step(self.iteration)
-        if self.parallel:
+        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
             self.train_loader.batch_sampler.set_epoch(self.epoch)
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
+            with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
                 try:
                     data_start_time = time.time()
@@ -207,6 +209,7 @@
                     logger.error(e)
                     raise e
+            with Timer("Eval Time Cost: {}"):
                 total_loss, num_seen_utts = self.valid()
                 if dist.get_world_size() > 1:
                     num_seen_utts = paddle.to_tensor(num_seen_utts)
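Note on the `Timer` used above: it acts as a context manager that measures the wall-clock time of the enclosed block and formats the elapsed duration into the given message. A minimal sketch of such a class, assuming this behavior (the actual `deepspeech.training.timer.Timer` may differ, e.g. by logging through the project logger instead of printing):

# Sketch only: assumed behavior of a Timer context manager, not the
# verbatim deepspeech.training.timer implementation.
import datetime
import time


class Timer:
    """Report the wall-clock time spent inside a `with` block.

    `message` is a format string whose `{}` placeholder receives the
    elapsed duration, e.g. Timer("Epoch-Train Time Cost: {}").
    """

    def __init__(self, message: str):
        self.message = message

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = datetime.timedelta(seconds=time.time() - self.start)
        print(self.message.format(elapsed))
        return False  # do not suppress exceptions raised in the block

Wrapped this way, the training phase and the evaluation phase above each report their duration once per epoch.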
@@ -36,16 +36,16 @@ class CTCLoss(nn.Layer):
             f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
         # instance for norm_by_times
-        # batchsize for norm_by_batchsize
+        # batch for norm_by_batchsize
         # frame for norm_by_total_logits_len
-        assert grad_norm_type in ('instance', 'batchsize', 'frame', None)
+        assert grad_norm_type in ('instance', 'batch', 'frame', None)
         self.norm_by_times = False
         self.norm_by_batchsize = False
         self.norm_by_total_logits_len = False
         logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
         if grad_norm_type == 'instance':
             self.norm_by_times = True
-        if grad_norm_type == 'batchsize':
+        if grad_norm_type == 'batch':
             self.norm_by_batchsize = True
         if grad_norm_type == 'frame':
             self.norm_by_total_logits_len = True
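As the comments above state, `grad_norm_type` selects how the CTC gradient is normalized: 'instance' maps to `norm_by_times`, 'batch' to `norm_by_batchsize`, 'frame' to `norm_by_total_logits_len`, and `None` leaves the gradient unnormalized. A hedged usage sketch with the renamed option; the import path and the `blank` argument are assumptions, only `reduction`, `batch_average`, and `grad_norm_type` are taken from the diff above:

# Hypothetical construction of the loss after this commit.
from deepspeech.modules.loss import CTCLoss  # assumed module path

ctc_loss = CTCLoss(
    blank=0,                 # assumed blank-label index argument
    reduction='sum',
    batch_average=True,
    grad_norm_type='batch',  # was 'batchsize' before this commit
)

Any caller or config that previously passed 'batchsize' would now fail the assertion and needs to use 'batch' instead.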