diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 710630a7864a6296a0f0ed4f19ede9f17df136c9..56743629b6b1cfac1e0ca1e57ab3748290789cad 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -21,11 +21,6 @@ from typing import Optional import jsonlines import numpy as np import paddle -from paddle import distributed as dist -from paddle import inference -from paddle.io import DataLoader -from yacs.config import CfgNode - from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset @@ -44,6 +39,10 @@ from deepspeech.utils import mp_tools from deepspeech.utils.log import Autolog from deepspeech.utils.log import Log from deepspeech.utils.utility import UpdateConfig +from paddle import distributed as dist +from paddle import inference +from paddle.io import DataLoader +from yacs.config import CfgNode logger = Log(__name__).getlog() @@ -412,6 +411,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): class DeepSpeech2ExportTester(DeepSpeech2Tester): def __init__(self, config, args): super().__init__(config, args) + self.apply_static = True def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): if self.args.model_type == "online": diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 2da838047206fc34011986daf33aa4203c5e92dd..ddde1e885c2cf33f6dc2af13e90a96315780340c 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -18,9 +18,6 @@ from contextlib import contextmanager from pathlib import Path import paddle -from paddle import distributed as dist -from tensorboardX import SummaryWriter - from deepspeech.training.reporter import ObsScope from deepspeech.training.reporter import report from deepspeech.training.timer import Timer @@ -31,6 +28,8 @@ from deepspeech.utils.log import Log from deepspeech.utils.utility import all_version from deepspeech.utils.utility import seed_all from deepspeech.utils.utility import UpdateConfig +from paddle import distributed as dist +from tensorboardX import SummaryWriter __all__ = ["Trainer"] @@ -348,8 +347,12 @@ class Trainer(): try: with Timer("Test/Decode Done: {}"): with self.eval(): - self.restore() - self.test() + if hasattr(self, + "apply_static") and self.apply_static is True: + self.test() + else: + self.restore() + self.test() except KeyboardInterrupt: exit(-1) @@ -381,6 +384,8 @@ class Trainer(): elif self.args.checkpoint_path: output_dir = Path( self.args.checkpoint_path).expanduser().parent.parent + elif self.args.export_path: + output_dir = Path(self.args.export_path).expanduser().parent.parent self.output_dir = output_dir self.output_dir.mkdir(parents=True, exist_ok=True) diff --git a/examples/aishell/s0/local/test.sh b/examples/aishell/s0/local/test.sh index 64d7250304137e7d658d3bb48d916a346229d876..d539ac4943039fe6c33eb1373985aa98617a587f 100755 --- a/examples/aishell/s0/local/test.sh +++ b/examples/aishell/s0/local/test.sh @@ -13,7 +13,7 @@ ckpt_prefix=$2 model_type=$3 # download language model -bash local/download_lm_ch.sh > dev/null 2>&1 +bash local/download_lm_ch.sh > /dev/null 2>&1 if [ $? -ne 0 ]; then exit 1 fi diff --git a/examples/aishell/s0/local/test_export.sh b/examples/aishell/s0/local/test_export.sh index 71469753db5b2615585851f7f3c37a4119ff5056..f0a30ce56fbc4cc43b559295fba7ef3ac3b3be26 100755 --- a/examples/aishell/s0/local/test_export.sh +++ b/examples/aishell/s0/local/test_export.sh @@ -13,7 +13,7 @@ jit_model_export_path=$2 model_type=$3 # download language model -bash local/download_lm_ch.sh > dev/null 2>&1 +bash local/download_lm_ch.sh > /dev/null 2>&1 if [ $? -ne 0 ]; then exit 1 fi