recog.py 6.8 KB
Newer Older
H
Hui Zhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
H
Hui Zhang 已提交
14
# Reference espnet Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
H
Hui Zhang 已提交
15 16
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
import jsonlines
H
Hui Zhang 已提交
17
import paddle
H
Hui Zhang 已提交
18
from yacs.config import CfgNode
H
Hui Zhang 已提交
19

H
Hui Zhang 已提交
20
from .beam_search import BatchBeamSearch
H
Hui Zhang 已提交
21 22
from .beam_search import BeamSearch
from .scorers.length_bonus import LengthBonus
H
Hui Zhang 已提交
23 24
from .scorers.scorer_interface import BatchScorerInterface
from .utils import add_results_to_json
25 26 27 28 29
from paddlespeech.s2t.exps import dynamic_import_tester
from paddlespeech.s2t.io.reader import LoadInputsAndTargets
from paddlespeech.s2t.models.asr_interface import ASRInterface
from paddlespeech.s2t.models.lm_interface import dynamic_import_lm
from paddlespeech.s2t.utils.log import Log
H
Hui Zhang 已提交
30

H
Hui Zhang 已提交
31 32
# Module-level logger; recog_v2 reports warnings and decoding progress through it.
logger = Log(__name__).getlog()

H
Hui Zhang 已提交
33
# NOTE: this function is needed to generate our sphinx docs.


def get_config(config_path):
    """Load a config file into a ``CfgNode`` that accepts undeclared keys.

    Args:
        config_path (str): path of the file handed to ``merge_from_file``.

    Returns:
        CfgNode: the merged configuration node (created with
        ``new_allowed=True``).
    """
    node = CfgNode(new_allowed=True)
    node.merge_from_file(config_path)
    return node


H
Hui Zhang 已提交
42 43
def load_trained_model(args):
    """Restore a trained ASR model through its experiment (tester) class.

    Args:
        args (namespace): program arguments. Reads ``ngpu``, ``model_conf``
            and ``model_name``; sets ``args.nprocs`` to ``args.ngpu``.

    Returns:
        tuple: ``(model, char_list, exp, confs)`` -- the restored model, the
        token list taken from ``exp.args.char_list``, the experiment object,
        and the model config node loaded from ``args.model_conf``.
    """
    # presumably the experiment class reads args.nprocs during setup -- TODO confirm
    args.nprocs = args.ngpu
    model_confs = get_config(args.model_conf)
    tester_class = dynamic_import_tester(args.model_name)
    tester = tester_class(model_confs, args)
    # setup() and restore() both run inside the experiment's eval() context.
    with tester.eval():
        tester.setup()
        tester.restore()
    token_list = tester.args.char_list
    asr_model = tester.model
    return asr_model, token_list, tester, model_confs

H
Hui Zhang 已提交
54

H
Hui Zhang 已提交
55
def load_trained_lm(args):
    """Build the language model described by ``args.rnnlm_conf`` and load its weights.

    Args:
        args (namespace): program arguments. Reads ``rnnlm_conf`` (a config
            providing ``model_module`` and ``model`` entries) and ``rnnlm``
            (the checkpoint path passed to ``paddle.load``).

    Returns:
        The LM instance with its state dict restored from ``args.rnnlm``.
    """
    lm_conf = get_config(args.rnnlm_conf)
    lm_cls = dynamic_import_lm(lm_conf.model_module)
    # The ``model`` section is expanded into the LM constructor's kwargs.
    lm = lm_cls(**lm_conf.model)
    state_dict = paddle.load(args.rnnlm)
    lm.set_state_dict(state_dict)
    return lm

H
Hui Zhang 已提交
64

H
Hui Zhang 已提交
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
def recog_v2(args):
    """Decode with custom models that implement ScorerInterface.

    Reads the utterances listed in ``args.recog_json`` (jsonlines, one entry
    per line, keyed by each entry's ``'utt'`` field), decodes each utterance
    with beam search, and writes one jsonlines record per utterance to
    ``args.result_label`` with the keys ``utt``, ``refs``, ``hyps`` and
    ``hyps_tokenid``.

    Args:
        args (namespace): The program arguments.
            See :py:func:`bin.asr_recog.get_parser` for details.

    Raises:
        NotImplementedError: if ``args.batchsize > 1``, ``args.streaming_mode``
            is not None, ``args.word_rnnlm`` is set, or ``args.ngpu > 1``.
    """
    logger.warning("experimental API for custom LMs is selected by --api v2")
    # Reject unsupported options before any expensive model loading.
    _check_supported_options(args)

    # set_deterministic(args)
    model, char_list, _, confs = load_trained_model(args)
    assert isinstance(model, ASRInterface)

    # An explicit args.preprocess_conf overrides the model's augmentation config.
    if args.preprocess_conf is None:
        preprocess_conf = confs.collator.augmentation_config
    else:
        preprocess_conf = args.preprocess_conf
    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=preprocess_conf,
        preprocess_args={"train": False}, )

    lm = None
    if args.rnnlm:
        lm = load_trained_lm(args)
        lm.eval()
    ngram = _load_ngram_scorer(args, char_list)

    beam_search = _build_beam_search(args, model, char_list, lm, ngram)

    # ngpu > 1 was already rejected above, so only one GPU or the CPU remains.
    device = "gpu:0" if args.ngpu == 1 else "cpu"
    paddle.set_device(device)
    dtype = getattr(paddle, args.dtype)
    logger.info(f"Decoding device={device}, dtype={dtype}")
    model.to(device=device, dtype=dtype)
    model.eval()
    beam_search.to(device=device, dtype=dtype)
    beam_search.eval()

    js = _read_recog_json(args.recog_json)
    num_utts = len(js)
    with paddle.no_grad(), jsonlines.open(args.result_label, "w") as writer:
        for idx, name in enumerate(js, 1):
            logger.info(f"({idx}/{num_utts}) decoding {name}")
            feat = load_inputs_and_targets([(name, js[name])])[0][0]
            logger.info(f'feat: {feat.shape}')
            enc = model.encode(paddle.to_tensor(feat).to(dtype))
            logger.info(f'eout: {enc.shape}')
            nbest_hyps = beam_search(
                x=enc,
                maxlenratio=args.maxlenratio,
                minlenratio=args.minlenratio)
            # Slicing already clamps to the number of hypotheses returned.
            nbest_hyps = [h.asdict() for h in nbest_hyps[:args.nbest]]
            # Only this utterance's result is needed; nothing is accumulated
            # across iterations, so memory does not grow with the dataset.
            utt_result = add_results_to_json(js[name], nbest_hyps, char_list)
            best = utt_result['output'][0]  # 1-best
            # Replace '▁' with spaces and drop any '<eos>' text.
            rec_text = best['rec_text'].replace('▁', ' ').replace(
                '<eos>', '').strip()
            rec_tokenid = list(map(int, best['rec_tokenid'].split()))
            writer.write({
                "utt": name,
                "refs": [best['text']],
                "hyps": [rec_text],
                "hyps_tokenid": [rec_tokenid],
            })


def _check_supported_options(args):
    """Raise NotImplementedError for options the v2 backend does not support."""
    if args.batchsize > 1:
        raise NotImplementedError("multi-utt batch decoding is not implemented")
    if args.streaming_mode is not None:
        raise NotImplementedError("streaming mode is not implemented")
    if args.word_rnnlm:
        raise NotImplementedError("word LM is not implemented")
    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")


def _load_ngram_scorer(args, char_list):
    """Return an n-gram scorer for ``args.ngram_model``, or None when it is unset.

    ``args.ngram_scorer == "full"`` selects NgramFullScorer; any other value
    selects NgramPartScorer.
    """
    if not args.ngram_model:
        return None
    # Imported here so the n-gram scorers are only loaded when requested.
    from .scorers.ngram import NgramFullScorer
    from .scorers.ngram import NgramPartScorer

    if args.ngram_scorer == "full":
        return NgramFullScorer(args.ngram_model, char_list)
    return NgramPartScorer(args.ngram_model, char_list)


def _build_beam_search(args, model, char_list, lm, ngram):
    """Combine the model's scorers with the optional LM/n-gram scorers into a beam search.

    ``lm`` and ``ngram`` may be None; they are stored in the scorer dict as-is.
    When ``args.batchsize == 1`` and every full scorer implements
    BatchScorerInterface, the returned object is switched to BatchBeamSearch.
    """
    scorers = model.scorers()  # decoder
    scorers["lm"] = lm
    scorers["ngram"] = ngram
    scorers["length_bonus"] = LengthBonus(len(char_list))
    weights = dict(
        decoder=1.0 - args.ctc_weight,
        ctc=args.ctc_weight,
        lm=args.lm_weight,
        ngram=args.ngram_weight,
        length_bonus=args.penalty, )
    beam_search = BeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(char_list),
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        token_list=char_list,
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", )

    # TODO(karita): make all scorers batchfied
    if args.batchsize == 1:
        non_batch = [
            k for k, v in beam_search.full_scorers.items()
            if not isinstance(v, BatchScorerInterface)
        ]
        if not non_batch:
            # Re-class the already-built instance to the batched implementation.
            beam_search.__class__ = BatchBeamSearch
            logger.info("BatchBeamSearch implementation is selected.")
        else:
            logger.warning(f"As non-batch scorers {non_batch} are found, "
                           f"fall back to non-batch implementation.")
    return beam_search


def _read_recog_json(path):
    """Read a jsonlines file into a dict keyed by each entry's ``'utt'`` field.

    If several entries share the same ``'utt'``, the last one wins.
    """
    with jsonlines.open(path, "r") as reader:
        return {item['utt']: item for item in reader}