未验证 提交 59766313 编写于 作者: J Jackwaterveg 提交者: GitHub

Merge pull request #817 from PaddlePaddle/ctc

export ctc grad norm config
......@@ -128,8 +128,8 @@ class DeepSpeech2Model(nn.Layer):
num_rnn_layers=3, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
use_gru=True, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
))
share_rnn_weights=True, #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
ctc_grad_norm_type='instance', ))
if config is not None:
config.merge_from_other_cfg(default)
return default
......@@ -142,7 +142,8 @@ class DeepSpeech2Model(nn.Layer):
rnn_size=1024,
use_gru=False,
share_rnn_weights=True,
blank_id=0):
blank_id=0,
ctc_grad_norm_type='instance'):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
......@@ -160,7 +161,8 @@ class DeepSpeech2Model(nn.Layer):
blank_id=blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
batch_average=True, # sum / batch_size
grad_norm_type=ctc_grad_norm_type)
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
......
......@@ -289,7 +289,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
blank_id=blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
batch_average=True, # sum / batch_size
grad_norm_type='instance')
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
......
......@@ -864,7 +864,8 @@ class U2Model(U2BaseModel):
blank_id=0,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
batch_average=True, # sum / batch_size
grad_norm_type='instance')
return vocab_size, encoder, decoder, ctc
......
......@@ -649,7 +649,8 @@ class U2STModel(U2STBaseModel):
blank_id=0,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
batch_average=True, # sum / batch_size
grad_norm_type='instance')
return vocab_size, encoder, (st_decoder, decoder, ctc)
else:
......
......@@ -39,7 +39,8 @@ class CTCDecoder(nn.Layer):
blank_id=0,
dropout_rate: float=0.0,
reduction: bool=True,
batch_average: bool=True):
batch_average: bool=True,
grad_norm_type: str="instance"):
"""CTC decoder
Args:
......@@ -48,6 +49,7 @@ class CTCDecoder(nn.Layer):
dropout_rate (float): dropout rate (0.0 ~ 1.0)
reduction (bool): whether to reduce the CTC loss into a scalar: True for 'sum', False for 'none'.
batch_average (bool): do batch dim wise average.
grad_norm_type (str): one of 'instance', 'batchsize', 'frame', None.
"""
assert check_argument_types()
super().__init__()
......@@ -60,7 +62,8 @@ class CTCDecoder(nn.Layer):
self.criterion = CTCLoss(
blank=self.blank_id,
reduction=reduction_type,
batch_average=batch_average)
batch_average=batch_average,
grad_norm_type=grad_norm_type)
# CTCDecoder LM Score handle
self._ext_scorer = None
......
......@@ -23,11 +23,32 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"]
class CTCLoss(nn.Layer):
def __init__(self, blank=0, reduction='sum', batch_average=False):
def __init__(self,
             blank=0,
             reduction='sum',
             batch_average=False,
             grad_norm_type=None):
    """Build the CTC loss layer and select its gradient-normalization mode.

    Args:
        blank: blank label index, forwarded to ``nn.CTCLoss``.
        reduction: reduction mode, forwarded to ``nn.CTCLoss``.
        batch_average: stored on the instance; ``forward`` divides the
            loss by the batch size when it is true.
        grad_norm_type: one of ``'instance'``, ``'batchsize'``,
            ``'frame'`` or ``None``. Exactly one of the three
            ``norm_by_*`` flags is enabled for a non-``None`` value;
            ``None`` leaves all of them disabled.

    Raises:
        AssertionError: if ``grad_norm_type`` is not one of the accepted
            values.
    """
    super().__init__()
    self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
    self.batch_average = batch_average
    logger.info(
        f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")

    # Mapping from grad_norm_type to the flag handed to the loss call:
    #   'instance'  -> norm_by_times
    #   'batchsize' -> norm_by_batchsize
    #   'frame'     -> norm_by_total_logits_len
    assert grad_norm_type in ('instance', 'batchsize', 'frame', None)
    self.norm_by_times = False
    self.norm_by_batchsize = False
    self.norm_by_total_logits_len = False
    logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
    if grad_norm_type == 'instance':
        self.norm_by_times = True
    elif grad_norm_type == 'batchsize':
        # Previously this branch set norm_by_times, which made
        # 'batchsize' indistinguishable from 'instance'.
        self.norm_by_batchsize = True
    elif grad_norm_type == 'frame':
        self.norm_by_total_logits_len = True
def forward(self, logits, ys_pad, hlens, ys_lens):
"""Compute CTC loss.
......@@ -48,7 +69,13 @@ class CTCLoss(nn.Layer):
logits = logits.transpose([1, 0, 2])
ys_pad = ys_pad.astype(paddle.int32)
loss = self.loss(
logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average)
logits,
ys_pad,
hlens,
ys_lens,
norm_by_times=self.norm_by_times,
norm_by_batchsize=self.norm_by_batchsize,
norm_by_total_logits_len=self.norm_by_total_logits_len)
if self.batch_average:
# Batch-size average
loss = loss / B
......
......@@ -41,6 +41,7 @@ model:
use_gru: True
share_rnn_weights: False
blank_id: 0
ctc_grad_norm_type: instance
training:
n_epoch: 80
......
......@@ -43,6 +43,7 @@ model:
fc_layers_size_list: -1,
use_gru: False
blank_id: 0
ctc_grad_norm_type: instance
training:
n_epoch: 50
......
......@@ -41,6 +41,7 @@ model:
use_gru: False
share_rnn_weights: True
blank_id: 0
ctc_grad_norm_type: instance
training:
n_epoch: 50
......
......@@ -43,6 +43,7 @@ model:
fc_layers_size_list: 512, 256
use_gru: False
blank_id: 0
ctc_grad_norm_type: instance
training:
n_epoch: 50
......
......@@ -42,6 +42,7 @@ model:
use_gru: False
share_rnn_weights: True
blank_id: 0
ctc_grad_norm_type: instance
training:
n_epoch: 10
......
......@@ -44,6 +44,7 @@ model:
fc_layers_size_list: 512, 256
use_gru: True
blank_id: 0
ctc_grad_norm_type: instance
training:
n_epoch: 10
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册