From f54dc983b60db54039fc8a62b5ceaf4bfa4c074b Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 2 Sep 2021 09:21:51 +0000 Subject: [PATCH] using bw rnn in ds2 --- README.md | 2 +- README_cn.md | 2 +- deepspeech/models/ds2/rnn.py | 14 +++++++------- deepspeech/modules/ctc.py | 10 +++++++--- examples/aishell/s0/README.md | 5 ++++- utils/avg.sh | 4 ++-- 6 files changed, 22 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index de24abe2..4e2a5685 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ All tested under: * Ubuntu 16.04 * python>=3.7 -* paddlepaddle>=2.1.2 +* paddlepaddle>=2.2.0rc Please see [install](doc/src/install.md). diff --git a/README_cn.md b/README_cn.md index 4b927362..e3ad2009 100644 --- a/README_cn.md +++ b/README_cn.md @@ -20,7 +20,7 @@ * Ubuntu 16.04 * python>=3.7 -* paddlepaddle>=2.1.2 +* paddlepaddle>=2.2.0rc 参看 [安装](doc/src/install.md)。 diff --git a/deepspeech/models/ds2/rnn.py b/deepspeech/models/ds2/rnn.py index 0d8c9fd2..3ff91d0a 100644 --- a/deepspeech/models/ds2/rnn.py +++ b/deepspeech/models/ds2/rnn.py @@ -29,13 +29,13 @@ __all__ = ['RNNStack'] class RNNCell(nn.RNNCellBase): r""" - Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it + Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it computes the outputs and updates states. The formula used is as follows: .. math:: h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) y_{t} & = h_{t} - + where :math:`act` is for :attr:`activation`. """ @@ -92,7 +92,7 @@ class RNNCell(nn.RNNCellBase): class GRUCell(nn.RNNCellBase): r""" - Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, + Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, it computes the outputs and updates states. The formula for GRU used is as follows: .. 
math:: @@ -101,8 +101,8 @@ class GRUCell(nn.RNNCellBase): \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} y_{t} & = h_{t} - - where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise + + where :math:`\sigma` is the sigmoid function, and * is the elementwise multiplication operator. """ @@ -202,7 +202,7 @@ class BiRNNWithBN(nn.Layer): self.fw_rnn = nn.RNN( self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] self.bw_rnn = nn.RNN( - self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] + self.bw_cell, is_reverse=True, time_major=False) #[B, T, D] def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): # x, shape [B, T, D] @@ -246,7 +246,7 @@ class BiGRUWithBN(nn.Layer): self.fw_rnn = nn.RNN( self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] self.bw_rnn = nn.RNN( - self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] + self.bw_cell, is_reverse=True, time_major=False) #[B, T, D] def forward(self, x, x_len): # x, shape [B, T, D] diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index 10e69705..10c04638 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -22,6 +22,13 @@ from deepspeech.utils.log import Log logger = Log(__name__).getlog() +try: + from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401 + from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401 + from deepspeech.decoders.swig_wrapper import Scorer # noqa: F401 +except Exception as e: + logger.info("ctcdecoder not installed!") + __all__ = ['CTCDecoder'] @@ -216,9 +223,6 @@ class CTCDecoder(nn.Layer): def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, decoding_method): - from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401 - from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401 - from
deepspeech.decoders.swig_wrapper import Scorer # noqa: F401 if decoding_method == "ctc_beam_search": self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index e5ebfcba..ee0f1405 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -10,8 +10,11 @@ | Model | Params | Release | Config | Test set | Loss | CER | | --- | --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 | +| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 6.016139030456543 | 0.066549 | +| --- | --- | --- | --- | --- | --- | --- | +| DeepSpeech2 | 58.4M | 7181e427 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 | | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | | DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | +| --- | --- | --- | --- | --- | --- | --- | | DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 | diff --git a/utils/avg.sh b/utils/avg.sh index 399c9574..bde9dd25 100755 --- a/utils/avg.sh +++ b/utils/avg.sh @@ -5,8 +5,8 @@ if [ $# != 3 ]; then exit -1 fi -ckpt_dir=${1} -avg_mode=${2} # best,latest +avg_mode=${1} # best,latest +ckpt_dir=${2} average_num=${3} decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams -- GitLab