提交 02083cdb 编写于 作者: H huangyuxin

fix the bug of 'dev/null' and the test_export

上级 fc8a7a15
...@@ -21,11 +21,6 @@ from typing import Optional ...@@ -21,11 +21,6 @@ from typing import Optional
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.io.collator import SpeechCollator from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
...@@ -44,6 +39,10 @@ from deepspeech.utils import mp_tools ...@@ -44,6 +39,10 @@ from deepspeech.utils import mp_tools
from deepspeech.utils.log import Autolog from deepspeech.utils.log import Autolog
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig from deepspeech.utils.utility import UpdateConfig
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
...@@ -412,6 +411,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): ...@@ -412,6 +411,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
class DeepSpeech2ExportTester(DeepSpeech2Tester): class DeepSpeech2ExportTester(DeepSpeech2Tester):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
self.apply_static = True
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
if self.args.model_type == "online": if self.args.model_type == "online":
......
...@@ -18,9 +18,6 @@ from contextlib import contextmanager ...@@ -18,9 +18,6 @@ from contextlib import contextmanager
from pathlib import Path from pathlib import Path
import paddle import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
from deepspeech.training.reporter import ObsScope from deepspeech.training.reporter import ObsScope
from deepspeech.training.reporter import report from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer from deepspeech.training.timer import Timer
...@@ -31,6 +28,8 @@ from deepspeech.utils.log import Log ...@@ -31,6 +28,8 @@ from deepspeech.utils.log import Log
from deepspeech.utils.utility import all_version from deepspeech.utils.utility import all_version
from deepspeech.utils.utility import seed_all from deepspeech.utils.utility import seed_all
from deepspeech.utils.utility import UpdateConfig from deepspeech.utils.utility import UpdateConfig
from paddle import distributed as dist
from tensorboardX import SummaryWriter
__all__ = ["Trainer"] __all__ = ["Trainer"]
...@@ -348,8 +347,12 @@ class Trainer(): ...@@ -348,8 +347,12 @@ class Trainer():
try: try:
with Timer("Test/Decode Done: {}"): with Timer("Test/Decode Done: {}"):
with self.eval(): with self.eval():
self.restore() if hasattr(self,
self.test() "apply_static") and self.apply_static is True:
self.test()
else:
self.restore()
self.test()
except KeyboardInterrupt: except KeyboardInterrupt:
exit(-1) exit(-1)
...@@ -381,6 +384,8 @@ class Trainer(): ...@@ -381,6 +384,8 @@ class Trainer():
elif self.args.checkpoint_path: elif self.args.checkpoint_path:
output_dir = Path( output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent self.args.checkpoint_path).expanduser().parent.parent
elif self.args.export_path:
output_dir = Path(self.args.export_path).expanduser().parent.parent
self.output_dir = output_dir self.output_dir = output_dir
self.output_dir.mkdir(parents=True, exist_ok=True) self.output_dir.mkdir(parents=True, exist_ok=True)
......
...@@ -13,7 +13,7 @@ ckpt_prefix=$2 ...@@ -13,7 +13,7 @@ ckpt_prefix=$2
model_type=$3 model_type=$3
# download language model # download language model
bash local/download_lm_ch.sh > dev/null 2>&1 bash local/download_lm_ch.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
......
...@@ -13,7 +13,7 @@ jit_model_export_path=$2 ...@@ -13,7 +13,7 @@ jit_model_export_path=$2
model_type=$3 model_type=$3
# download language model # download language model
bash local/download_lm_ch.sh > dev/null 2>&1 bash local/download_lm_ch.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册