Commit 2480be8e authored by Hui Zhang

Add timer info for the st and u2 kaldi trainers

Parent 28a0a641
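Note: the `deepspeech.training.timer.Timer` class imported below is not shown in this commit. Judging from the call sites (`with Timer("Epoch-Train Time Cost: {}")`), it is a context manager that times its body and substitutes the elapsed time into the `{}` placeholder of the message. A minimal sketch consistent with that usage — the `datetime`-based formatting and the use of `print` instead of the project logger are assumptions, not the actual implementation:

    import datetime
    import time


    class Timer:
        """Context manager that reports wall-clock time spent in its body.

        `message` is a format string with one `{}` placeholder, filled with
        the elapsed time on exit, e.g. Timer("Eval Time Cost: {}").
        """

        def __init__(self, message=None):
            self.message = message

        def __enter__(self):
            # Record the start time when the `with` block is entered.
            self.start = time.time()
            return self

        def __exit__(self, exc_type, exc_value, exc_tb):
            # On exit (normal or via exception), format and report duration.
            if self.message is not None:
                elapsed = datetime.timedelta(seconds=time.time() - self.start)
                # Assumed output path; the real code likely uses the project logger.
                print(self.message.format(elapsed))

Under this sketch, the `with Timer(...)` blocks in the hunks below would emit one timing line per epoch for training and one for validation.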
@@ -32,6 +32,7 @@ from deepspeech.io.dataloader import BatchDataLoader
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
 from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
@@ -190,35 +191,37 @@ class U2Trainer(Trainer):
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
 
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
...
@@ -38,6 +38,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2_st import U2STModel
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.scheduler import WarmupLR
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import bleu_score
 from deepspeech.utils import ctc_utils
@@ -207,35 +208,37 @@ class U2STTrainer(Trainer):
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
 
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
...