未验证 提交 2fa68123 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #955 from Jackwaterveg/fix

fix the run_test in test_export
...@@ -21,6 +21,11 @@ from typing import Optional ...@@ -21,6 +21,11 @@ from typing import Optional
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.io.collator import SpeechCollator from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
...@@ -32,6 +37,7 @@ from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline ...@@ -32,6 +37,7 @@ from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.reporter import report from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools from deepspeech.utils import layer_tools
...@@ -39,10 +45,6 @@ from deepspeech.utils import mp_tools ...@@ -39,10 +45,6 @@ from deepspeech.utils import mp_tools
from deepspeech.utils.log import Autolog from deepspeech.utils.log import Autolog
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig from deepspeech.utils.utility import UpdateConfig
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
...@@ -441,6 +443,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): ...@@ -441,6 +443,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
return result_transcripts return result_transcripts
def run_test(self):
"""Do Test/Decode"""
try:
with Timer("Test/Decode Done: {}"):
with self.eval():
self.test()
except KeyboardInterrupt:
exit(-1)
def static_forward_online(self, audio, audio_len, def static_forward_online(self, audio, audio_len,
decoder_chunk_size: int=1): decoder_chunk_size: int=1):
""" """
......
...@@ -18,6 +18,9 @@ from contextlib import contextmanager ...@@ -18,6 +18,9 @@ from contextlib import contextmanager
from pathlib import Path from pathlib import Path
import paddle import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
from deepspeech.training.reporter import ObsScope from deepspeech.training.reporter import ObsScope
from deepspeech.training.reporter import report from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer from deepspeech.training.timer import Timer
...@@ -28,8 +31,6 @@ from deepspeech.utils.log import Log ...@@ -28,8 +31,6 @@ from deepspeech.utils.log import Log
from deepspeech.utils.utility import all_version from deepspeech.utils.utility import all_version
from deepspeech.utils.utility import seed_all from deepspeech.utils.utility import seed_all
from deepspeech.utils.utility import UpdateConfig from deepspeech.utils.utility import UpdateConfig
from paddle import distributed as dist
from tensorboardX import SummaryWriter
__all__ = ["Trainer"] __all__ = ["Trainer"]
...@@ -347,10 +348,6 @@ class Trainer(): ...@@ -347,10 +348,6 @@ class Trainer():
try: try:
with Timer("Test/Decode Done: {}"): with Timer("Test/Decode Done: {}"):
with self.eval(): with self.eval():
if hasattr(self,
"apply_static") and self.apply_static is True:
self.test()
else:
self.restore() self.restore()
self.test() self.test()
except KeyboardInterrupt: except KeyboardInterrupt:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册