From ac9fcf7f4a53026bba8efe235d90a0693a70eae6 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Wed, 20 Apr 2022 00:15:37 +0800 Subject: [PATCH] fix the asr infernece model, paddle.no_grad, test=doc --- paddlespeech/server/engine/asr/online/asr_engine.py | 3 +++ paddlespeech/server/engine/asr/online/ctc_search.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 34a028a3..758cbaab 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -356,6 +356,7 @@ class PaddleASRConnectionHanddler: else: raise Exception("invalid model name") + @paddle.no_grad() def decode_one_chunk(self, x_chunk, x_chunk_lens): logger.info("start to decoce one chunk with deepspeech2 model") input_names = self.am_predictor.get_input_names() @@ -397,6 +398,7 @@ class PaddleASRConnectionHanddler: logger.info(f"decode one best result: {trans_best[0]}") return trans_best[0] + @paddle.no_grad() def advance_decoding(self, is_finished=False): logger.info("start to decode with advanced_decoding method") cfg = self.ctc_decode_config @@ -503,6 +505,7 @@ class PaddleASRConnectionHanddler: else: return '' + @paddle.no_grad() def rescoring(self): if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: return diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index b1c80c36..8aee0a50 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict - +import paddle from paddlespeech.cli.log import logger from paddlespeech.s2t.utils.utility import log_add @@ -29,6 +29,7 @@ class CTCPrefixBeamSearch: self.config = config self.reset() + @paddle.no_grad() def search(self, ctc_probs, device, blank_id=0): """ctc prefix beam search method decode a chunk feature -- GitLab