Unverified commit e0a87a5a, authored by Hui Zhang, committed by GitHub

batch average ctc loss (#567)

* When the loss is divided by the batch size, with a larger lr and more epochs, the loss can decrease further and the CER is lower than before.

* Since the loss decreases more when it is divided by the batch size, a smaller LM alpha works better.

* Smaller LM alpha, further CER reduction.

* alpha 2.2, CER 0.077478

* alpha 1.9, CER 0.077249

* Larger LibriSpeech lr for the batch-averaged CTC loss.

* Since the loss decreases and the model is more confident, a smaller LM alpha is used.
Parent 258307df
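The change below makes the sum-reduced CTC loss additionally divide by the batch size, so the loss and its gradients no longer grow with the number of utterances per batch; that is why the configs in this commit raise the learning rate and lower the LM weight alpha. A minimal sketch of the idea using plain paddle.nn.CTCLoss (the sizes and tensors here are hypothetical, not taken from the repo):

```python
import paddle
import paddle.nn as nn

# Hypothetical sizes, for illustration only.
B, T, U, V = 4, 50, 10, 29           # batch, input frames, label length, vocab incl. blank
logits = paddle.randn([T, B, V])     # (T, B, V) layout expected by paddle.nn.CTCLoss
labels = paddle.randint(1, V, [B, U], dtype='int32')   # label ids, excluding the blank id 0
feat_lens = paddle.full([B], T, dtype='int64')
label_lens = paddle.full([B], U, dtype='int64')

# reduction='sum': the loss is summed over all utterances, so it scales with B.
sum_loss = nn.CTCLoss(blank=0, reduction='sum')(logits, labels, feat_lens, label_lens)

# "batch average": divide the summed loss by the batch size, keeping its
# scale (and gradient magnitude) roughly independent of B.
batch_avg_loss = sum_loss / B
```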
...
@@ -39,7 +39,6 @@
 from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.dataset import ManifestDataset
-from deepspeech.modules.loss import CTCLoss
 from deepspeech.models.deepspeech2 import DeepSpeech2Model
 from deepspeech.models.deepspeech2 import DeepSpeech2InferModel
...
@@ -170,7 +170,8 @@ class DeepSpeech2Model(nn.Layer):
             odim=dict_size + 1,  # <blank> is append after vocab
             blank_id=dict_size,  # last token is <blank>
             dropout_rate=0.0,
-            reduction=True)
+            reduction=True,  # sum
+            batch_average=True)  # sum / batch_size
 
     def forward(self, audio, text, audio_len, text_len):
         """Compute Model loss
...
@@ -36,14 +36,16 @@ class CTCDecoder(nn.Layer):
                  odim,
                  blank_id=0,
                  dropout_rate: float=0.0,
-                 reduction: bool=True):
+                 reduction: bool=True,
+                 batch_average: bool=False):
         """CTC decoder
 
         Args:
             enc_n_units ([int]): encoder output dimention
             vocab_size ([int]): text vocabulary size
             dropout_rate (float): dropout rate (0.0 ~ 1.0)
-            reduction (bool): reduce the CTC loss into a scalar
+            reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
+            batch_average (bool): do batch dim wise average.
         """
         assert check_argument_types()
         super().__init__()
@@ -53,7 +55,10 @@ class CTCDecoder(nn.Layer):
         self.dropout_rate = dropout_rate
         self.ctc_lo = nn.Linear(enc_n_units, self.odim)
         reduction_type = "sum" if reduction else "none"
-        self.criterion = CTCLoss(blank=self.blank_id, reduction=reduction_type)
+        self.criterion = CTCLoss(
+            blank=self.blank_id,
+            reduction=reduction_type,
+            batch_average=batch_average)
 
         # CTCDecoder LM Score handle
         self._ext_scorer = None
...
@@ -53,10 +53,11 @@ F.ctc_loss = ctc_loss
 class CTCLoss(nn.Layer):
-    def __init__(self, blank=0, reduction='sum'):
+    def __init__(self, blank=0, reduction='sum', batch_average=False):
         super().__init__()
         # last token id as blank id
         self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
+        self.batch_average = batch_average
 
     def forward(self, logits, ys_pad, hlens, ys_lens):
         """Compute CTC loss.
@@ -76,8 +77,7 @@ class CTCLoss(nn.Layer):
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
         loss = self.loss(logits, ys_pad, hlens, ys_lens)
-        # wenet do batch-size average, deepspeech2 not do this
-        # Batch-size average
-        # loss = loss / B
+        if self.batch_average:
+            # Batch-size average
+            loss = loss / B
         return loss
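A hypothetical call of the wrapper modified above, following the batch-first `(B, L, D)` layout noted in the hunk; the sizes, dtypes, and the blank id chosen here are illustrative assumptions, not code from the repository:

```python
import paddle
from deepspeech.modules.loss import CTCLoss

B, L, D, U = 8, 100, 30, 12            # batch, frames, vocab incl. blank, label length
criterion = CTCLoss(blank=D - 1, reduction='sum', batch_average=True)

logits = paddle.randn([B, L, D])                           # model outputs, batch first
ys_pad = paddle.randint(0, D - 1, [B, U], dtype='int32')   # padded label ids (no blank)
hlens = paddle.full([B], L, dtype='int64')                 # encoder output lengths
ys_lens = paddle.full([B], U, dtype='int64')               # label lengths

loss = criterion(logits, ys_pad, hlens, ys_lens)           # summed CTC loss / batch size
```

Note that batch_average is not the same as paddle's built-in reduction='mean', which also normalizes by label length; here the summed loss is only divided by the number of utterances in the batch.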
...
 # Aishell-1
 ## CTC
-| Model | Config | Test set | CER |
-| --- | --- | --- | --- |
-| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.078977 |
-| DeepSpeech2 | release 1.8.5 | test | 0.080447 |
+| Model | Config | Test Set | CER | Valid Loss |
+| --- | --- | --- | --- | --- |
+| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.077249 | 7.036566 |
+| DeepSpeech2 | release 1.8.5 | test | 0.087004 | 8.575452 |
...
@@ -29,8 +29,8 @@ model:
   use_gru: True
   share_rnn_weights: False
 training:
-  n_epoch: 30
-  lr: 5e-4
+  n_epoch: 50
+  lr: 2e-3
   lr_decay: 0.83
   weight_decay: 1e-06
   global_grad_clip: 5.0
@@ -39,7 +39,7 @@ decoding:
   error_rate_type: cer
   decoding_method: ctc_beam_search
   lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
-  alpha: 2.6
+  alpha: 1.9
   beta: 5.0
   beam_size: 300
   cutoff_prob: 0.99
...
@@ -2,7 +2,7 @@
 # train model
 # if you wish to resume from an exists model, uncomment --init_from_pretrained_model
-export FLAGS_sync_nccl_allreduce=0
+#export FLAGS_sync_nccl_allreduce=0
 
 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
 echo "using $ngpu gpus..."
...
 # LibriSpeech
 ## CTC
-| Model | Config | Test set | WER |
-| --- | --- | --- | --- |
-| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.073973 |
-| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 |
+| Model | Config | Test Set | WER | Valid Loss |
+| --- | --- | --- | --- | --- |
+| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.069357 | 15.078561 |
+| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 | 15.351633 |
...
@@ -29,8 +29,8 @@ model:
   use_gru: False
   share_rnn_weights: True
 training:
-  n_epoch: 20
-  lr: 5e-4
+  n_epoch: 50
+  lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
   global_grad_clip: 5.0
@@ -39,7 +39,7 @@ decoding:
   error_rate_type: wer
   decoding_method: ctc_beam_search
   lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
-  alpha: 2.5
+  alpha: 1.9
   beta: 0.3
   beam_size: 500
   cutoff_prob: 1.0
...
 #! /usr/bin/env bash
-export FLAGS_sync_nccl_allreduce=0
+#export FLAGS_sync_nccl_allreduce=0
 # https://github.com/PaddlePaddle/Paddle/pull/28484
-export NCCL_SHM_DISABLE=1
+#export NCCL_SHM_DISABLE=1
 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
 echo "using $ngpu gpus..."
...