提交 02083cdb 编写于 作者: H huangyuxin

fix the bug of 'dev/null' and the test_export

上级 fc8a7a15
......@@ -21,11 +21,6 @@ from typing import Optional
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
......@@ -44,6 +39,10 @@ from deepspeech.utils import mp_tools
from deepspeech.utils.log import Autolog
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
logger = Log(__name__).getlog()
......@@ -412,6 +411,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
class DeepSpeech2ExportTester(DeepSpeech2Tester):
def __init__(self, config, args):
super().__init__(config, args)
self.apply_static = True
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
if self.args.model_type == "online":
......
......@@ -18,9 +18,6 @@ from contextlib import contextmanager
from pathlib import Path
import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
from deepspeech.training.reporter import ObsScope
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
......@@ -31,6 +28,8 @@ from deepspeech.utils.log import Log
from deepspeech.utils.utility import all_version
from deepspeech.utils.utility import seed_all
from deepspeech.utils.utility import UpdateConfig
from paddle import distributed as dist
from tensorboardX import SummaryWriter
__all__ = ["Trainer"]
......@@ -348,8 +347,12 @@ class Trainer():
try:
with Timer("Test/Decode Done: {}"):
with self.eval():
self.restore()
self.test()
if hasattr(self,
"apply_static") and self.apply_static is True:
self.test()
else:
self.restore()
self.test()
except KeyboardInterrupt:
exit(-1)
......@@ -381,6 +384,8 @@ class Trainer():
elif self.args.checkpoint_path:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
elif self.args.export_path:
output_dir = Path(self.args.export_path).expanduser().parent.parent
self.output_dir = output_dir
self.output_dir.mkdir(parents=True, exist_ok=True)
......
......@@ -13,7 +13,7 @@ ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_ch.sh > dev/null 2>&1
bash local/download_lm_ch.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then
exit 1
fi
......
......@@ -13,7 +13,7 @@ jit_model_export_path=$2
model_type=$3
# download language model
bash local/download_lm_ch.sh > dev/null 2>&1
bash local/download_lm_ch.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then
exit 1
fi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册