提交 f54dc983 编写于 作者: H Hui Zhang

using bw rnn in ds2

上级 0d90d3f9
......@@ -18,7 +18,7 @@
All tested under:
* Ubuntu 16.04
* python>=3.7
* paddlepaddle>=2.1.2
* paddlepaddle>=2.2.0rc
Please see [install](doc/src/install.md).
......
......@@ -20,7 +20,7 @@
* Ubuntu 16.04
* python>=3.7
* paddlepaddle>=2.1.2
* paddlepaddle>=2.2.0rc
参看 [安装](doc/src/install.md)
......
......@@ -202,7 +202,7 @@ class BiRNNWithBN(nn.Layer):
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
# x, shape [B, T, D]
......@@ -246,7 +246,7 @@ class BiGRUWithBN(nn.Layer):
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
self.bw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x, x_len):
# x, shape [B, T, D]
......
......@@ -22,6 +22,13 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
try:
from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401
from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401
from deepspeech.decoders.swig_wrapper import Scorer # noqa: F401
except Exception as e:
logger.info("ctcdecoder not installed!")
__all__ = ['CTCDecoder']
......@@ -216,9 +223,6 @@ class CTCDecoder(nn.Layer):
def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
decoding_method):
from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401
from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401
from deepspeech.decoders.swig_wrapper import Scorer # noqa: F401
if decoding_method == "ctc_beam_search":
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
......
......@@ -10,8 +10,11 @@
| Model | Params | Release | Config | Test set | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 |
| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 6.016139030456543 | 0.066549 |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 7181e427 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 |
......@@ -5,8 +5,8 @@ if [ $# != 3 ]; then
exit -1
fi
ckpt_dir=${1}
avg_mode=${2} # best,latest
avg_mode=${1} # best,latest
ckpt_dir=${2}
average_num=${3}
decode_checkpoint=${ckpt_dir}/avg_${average_num}.pdparams
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册