recog_bin.py 13.0 KB
Newer Older
H
Hui Zhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
H
Hui Zhang 已提交
14
# Reference espnet Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
H
Hui Zhang 已提交
15 16 17 18 19 20 21 22 23 24
"""End-to-end speech recognition model decoding script."""
import logging
import os
import random
import sys
from distutils.util import strtobool

import configargparse
import numpy as np

H
Hui Zhang 已提交
25

H
Hui Zhang 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360
def _str2bool(value):
    """Convert a textual flag value into a real ``bool``.

    Used as an argparse ``type=`` converter. Plain ``type=bool`` must not be
    used for this purpose: ``bool("False")`` is ``True`` because any non-empty
    string is truthy, so a flag could never be switched off from the command
    line or from a config file.

    Args:
        value: Text handed over by the argument parser.

    Returns:
        ``True`` or ``False`` according to ``strtobool``.

    Raises:
        ValueError: Propagated from ``strtobool`` when ``value`` is not one of
            the spellings it recognizes.
    """
    return bool(strtobool(value))


def get_parser():
    """Build the argument parser for the decoding script.

    Returns:
        A ``configargparse.ArgumentParser`` that reads YAML config files
        (``--config``, ``--config2``, ``--config3``) and defines the general,
        task, model, search, transducer, RNNLM, n-gram, streaming, Mask CTC
        and quantization options declared below.
    """
    parser = configargparse.ArgumentParser(
        description="Transcribe text from speech using "
        "a speech recognition model on one CPU or GPU",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter, )
    parser.add(
        '--model-name',
        type=str,
        default='u2_kaldi',
        help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
    # general configuration
    parser.add("--config", is_config_file=True, help="Config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="Second config file path that overwrites the settings in `--config`",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="Third config file path that overwrites the settings "
        "in `--config` and `--config2`", )

    parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
    parser.add_argument(
        "--dtype",
        choices=("float16", "float32", "float64"),
        default="float32",
        help="Float precision (only available in --api v2)", )
    parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
    parser.add_argument("--seed", type=int, default=1, help="Random seed")
    parser.add_argument(
        "--verbose", "-V", type=int, default=2, help="Verbose option")
    parser.add_argument(
        "--batchsize",
        type=int,
        default=1,
        help="Batch size for beam search (0: means no batch processing)", )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing", )
    parser.add_argument(
        "--api",
        default="v2",
        choices=["v2"],
        help="Beam search APIs "
        "v2: Experimental API. It supports any models that implements ScorerInterface.",
    )
    # task related
    parser.add_argument(
        "--recog-json", type=str, help="Filename of recognition data (json)")
    parser.add_argument(
        "--result-label",
        type=str,
        required=True,
        help="Filename of result label data (json)", )
    # model (parameter) related
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model file parameters to read")
    parser.add_argument(
        "--model-conf", type=str, default=None, help="Model config file")
    parser.add_argument(
        "--num-spkrs",
        type=int,
        default=1,
        choices=[1, 2],
        help="Number of speakers in the speech", )
    parser.add_argument(
        "--num-encs",
        default=1,
        type=int,
        help="Number of encoders in the model.")
    # search related
    parser.add_argument(
        "--nbest", type=int, default=1, help="Output N-best hypotheses")
    parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
    parser.add_argument(
        "--penalty", type=float, default=0.0, help="Incertion penalty")
    parser.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="""Input length ratio to obtain max output length.
                        If maxlenratio=0.0 (default), it uses a end-detect function
                        to automatically find maximum hypothesis lengths.
                        If maxlenratio<0.0, its absolute value is interpreted
                        as a constant max output length""", )
    parser.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length", )
    parser.add_argument(
        "--ctc-weight",
        type=float,
        default=0.0,
        help="CTC weight in joint decoding")
    parser.add_argument(
        "--weights-ctc-dec",
        type=float,
        action="append",
        help="ctc weight assigned to each encoder during decoding."
        "[in multi-encoder mode only]", )
    parser.add_argument(
        "--ctc-window-margin",
        type=int,
        default=0,
        help="""Use CTC window with margin parameter to accelerate
                        CTC/attention decoding especially on GPU. Smaller magin
                        makes decoding faster, but may increase search errors.
                        If margin=0 (default), this function is disabled""", )
    # transducer related
    parser.add_argument(
        "--search-type",
        type=str,
        default="default",
        choices=["default", "nsc", "tsd", "alsd", "maes"],
        help="""Type of beam search implementation to use during inference.
        Can be either: default beam search ("default"),
        N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"),
        Alignment-Length Synchronous Decoding ("alsd") or
        modified Adaptive Expansion Search ("maes").""", )
    parser.add_argument(
        "--nstep",
        type=int,
        default=1,
        help="""Number of expansion steps allowed in NSC beam search or mAES
        (nstep > 0 for NSC and nstep > 1 for mAES).""", )
    parser.add_argument(
        "--prefix-alpha",
        type=int,
        default=2,
        help="Length prefix difference allowed in NSC beam search or mAES.", )
    parser.add_argument(
        "--max-sym-exp",
        type=int,
        default=2,
        help="Number of symbol expansions allowed in TSD.", )
    parser.add_argument(
        "--u-max",
        type=int,
        default=400,
        help="Length prefix difference allowed in ALSD.", )
    parser.add_argument(
        "--expansion-gamma",
        type=float,
        default=2.3,
        help="Allowed logp difference for prune-by-value method in mAES.", )
    parser.add_argument(
        "--expansion-beta",
        type=int,
        default=2,
        help="""Number of additional candidates for expanded hypotheses
                selection in mAES.""", )
    # NOTE(review): with nargs="?" and no `const`, passing a bare
    # `--score-norm` (no value) stores None rather than True; left unchanged
    # here to preserve existing behavior.
    parser.add_argument(
        "--score-norm",
        type=strtobool,
        nargs="?",
        default=True,
        help="Normalize final hypotheses' score by length", )
    parser.add_argument(
        "--softmax-temperature",
        type=float,
        default=1.0,
        help="Penalization term for softmax function.", )
    # rnnlm related
    parser.add_argument(
        "--rnnlm", type=str, default=None, help="RNNLM model file to read")
    parser.add_argument(
        "--rnnlm-conf",
        type=str,
        default=None,
        help="RNNLM model config file to read")
    parser.add_argument(
        "--word-rnnlm",
        type=str,
        default=None,
        help="Word RNNLM model file to read")
    parser.add_argument(
        "--word-rnnlm-conf",
        type=str,
        default=None,
        help="Word RNNLM model config file to read", )
    parser.add_argument(
        "--word-dict", type=str, default=None, help="Word list to read")
    parser.add_argument(
        "--lm-weight", type=float, default=0.1, help="RNNLM weight")
    # ngram related
    parser.add_argument(
        "--ngram-model",
        type=str,
        default=None,
        help="ngram model file to read")
    parser.add_argument(
        "--ngram-weight", type=float, default=0.1, help="ngram weight")
    parser.add_argument(
        "--ngram-scorer",
        type=str,
        default="part",
        choices=("full", "part"),
        help="""if the ngram is set as a part scorer, similar with CTC scorer,
                ngram scorer only scores topK hypethesis.
                if the ngram is set as full scorer, ngram scorer scores all hypthesis
                the decoding speed of part scorer is musch faster than full one""",
    )
    # streaming related
    parser.add_argument(
        "--streaming-mode",
        type=str,
        default=None,
        choices=["window", "segment"],
        help="""Use streaming recognizer for inference.
                        `--batchsize` must be set to 0 to enable this mode""", )
    parser.add_argument(
        "--streaming-window", type=int, default=10, help="Window size")
    parser.add_argument(
        "--streaming-min-blank-dur",
        type=int,
        default=10,
        help="Minimum blank duration threshold", )
    parser.add_argument(
        "--streaming-onset-margin", type=int, default=1, help="Onset margin")
    parser.add_argument(
        "--streaming-offset-margin", type=int, default=1, help="Offset margin")
    # non-autoregressive related
    # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
    parser.add_argument(
        "--maskctc-n-iterations",
        type=int,
        default=10,
        help="Number of decoding iterations."
        "For Mask CTC, set 0 to predict 1 mask/iter.", )
    parser.add_argument(
        "--maskctc-probability-threshold",
        type=float,
        default=0.999,
        help="Threshold probability for CTC output", )
    # quantize model related
    parser.add_argument(
        "--quantize-config",
        nargs="*",
        help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]",
    )
    parser.add_argument(
        "--quantize-dtype",
        type=str,
        default="qint8",
        help="Dtype dynamic quantize")
    # The two flags below previously used type=bool, which turned every
    # non-empty value (including "False" and "0") into True.
    parser.add_argument(
        "--quantize-asr-model",
        type=_str2bool,
        default=False,
        help="Quantize asr model", )
    parser.add_argument(
        "--quantize-lm-model",
        type=_str2bool,
        default=False,
        help="Quantize lm model", )
    return parser


def main(args):
    """Run the main decoding function.

    Parses the command line, configures logging, validates the GPU and
    language-model options, seeds the random number generators and then
    dispatches to the recognizer.

    Args:
        args: Command-line arguments to parse (the script entry point passes
            ``sys.argv[1:]``).

    Raises:
        ValueError: If ``--dtype float16`` is combined with ``--ngpu 0``, if
            ``--num-spkrs`` is 2, or if ``--api`` is not ``v2``.
        NotImplementedError: If ``--num-encs`` is not 1 with ``--api v2``.
        SystemExit: With status 1 when ``--ngpu`` does not match
            ``CUDA_VISIBLE_DEVICES``, when ``--ngpu`` is greater than 1, or
            when both ``--rnnlm`` and ``--word-rnnlm`` are given.
    """
    parser = get_parser()
    parser.add_argument(
        "--output", metavar="CKPT_DIR", help="path to save checkpoint.")
    parser.add_argument(
        "--checkpoint_path", type=str, help="path to load checkpoint")
    # NOTE(review): this help text is identical to --checkpoint_path's and
    # looks copy-pasted; confirm what --dict-path should describe.
    parser.add_argument("--dict-path", type=str, help="path to load checkpoint")
    args = parser.parse_args(args)

    if args.ngpu == 0 and args.dtype == "float16":
        raise ValueError(
            f"--dtype {args.dtype} does not support the CPU backend.")

    # logging info: 1 -> INFO, 2 -> DEBUG, anything else -> WARN
    verbose_levels = {1: logging.INFO, 2: logging.DEBUG}
    logging.basicConfig(
        level=verbose_levels.get(args.verbose, logging.WARN),
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if args.verbose not in verbose_levels:
        logging.warning("Skip DEBUG/INFO messages")
    logging.info(args)

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif args.ngpu != len(cvd.split(",")):
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

        # TODO(mn5k): support of multiple GPUs
        if args.ngpu > 1:
            logging.error("The program only supports ngpu=1.")
            sys.exit(1)

    # display PYTHONPATH
    logging.info("python path = %s", os.environ.get("PYTHONPATH", "(None)"))

    # seed setting
    random.seed(args.seed)
    np.random.seed(args.seed)
    logging.info("set random seed = %d", args.seed)

    # validate rnn options
    if args.rnnlm is not None and args.word_rnnlm is not None:
        logging.error(
            "It seems that both --rnnlm and --word-rnnlm are specified. "
            "Please use either option.")
        sys.exit(1)

    # recog
    if args.num_spkrs == 2:
        raise ValueError("asr_mix not supported.")
    if args.num_spkrs == 1:
        if args.num_encs == 1:
            # Experimental API that supports custom LMs
            if args.api != "v2":
                raise ValueError("Only support --api v2")
            # Imported lazily, only on the path that actually decodes.
            from paddlespeech.s2t.decoders.recog import recog_v2
            recog_v2(args)
        elif args.api == "v2":
            raise NotImplementedError(
                f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
            )


if __name__ == "__main__":
    # Script entry point: hand everything after the program name to main().
    cli_arguments = sys.argv[1:]
    main(cli_arguments)