未验证 提交 05288cd3 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #928 from PaddlePaddle/join_ctc

[decoder]Join ctc: decoder, ctc
......@@ -233,7 +233,8 @@ def is_broadcastable(shp1, shp2):
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert is_broadcastable(xs.shape, mask.shape) is True
assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape,
mask.shape)
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
......@@ -315,7 +316,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
assert len(args) == 1
if isinstance(args[0], str): # dtype
return x.astype(args[0])
elif isinstance(args[0], paddle.Tensor): #Tensor
elif isinstance(args[0], paddle.Tensor): # Tensor
return x.astype(args[0].dtype)
else: # Device
return x
......@@ -364,6 +365,7 @@ from typing import Tuple
from typing import Iterator
from collections import OrderedDict, abc as container_abcs
class LayerDict(paddle.nn.Layer):
r"""Holds submodules in a dictionary.
......@@ -408,7 +410,7 @@ class LayerDict(paddle.nn.Layer):
return x
"""
def __init__(self, modules: Optional[Mapping[str, Layer]] = None) -> None:
def __init__(self, modules: Optional[Mapping[str, Layer]]=None) -> None:
super(LayerDict, self).__init__()
if modules is not None:
self.update(modules)
......@@ -475,10 +477,11 @@ class LayerDict(paddle.nn.Layer):
"""
if not isinstance(modules, container_abcs.Iterable):
raise TypeError("LayerDict.update should be called with an "
"iterable of key/value pairs, but got " +
type(modules).__name__)
"iterable of key/value pairs, but got " + type(
modules).__name__)
if isinstance(modules, (OrderedDict, LayerDict, container_abcs.Mapping)):
if isinstance(modules,
(OrderedDict, LayerDict, container_abcs.Mapping)):
for key, module in modules.items():
self[key] = module
else:
......@@ -490,14 +493,15 @@ class LayerDict(paddle.nn.Layer):
type(m).__name__)
if not len(m) == 2:
raise ValueError("LayerDict update sequence element "
"#" + str(j) + " has length " + str(len(m)) +
"; 2 is required")
"#" + str(j) + " has length " + str(
len(m)) + "; 2 is required")
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
# that's too cumbersome to type correctly with overloads, so we add an ignore here
self[m[0]] = m[1] # type: ignore[assignment]
# remove forward alltogether to fallback on Module's _forward_unimplemented
if not hasattr(paddle.nn, 'LayerDict'):
logger.debug(
"register user LayerDict to paddle.nn, remove this when fixed!")
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .batch_beam_search import BatchBeamSearch
from .beam_search import beam_search
from .beam_search import BeamSearch
from .beam_search import Hypothesis
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class BatchBeamSearch():
pass
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Beam search module."""
from itertools import chain
import logger
from typing import Any
from typing import Dict
from typing import List
......@@ -11,18 +22,18 @@ from typing import Union
import paddle
from .utils import end_detect
from .scorers.scorer_interface import PartialScorerInterface
from .scorers.scorer_interface import ScorerInterface
from ..scorers.scorer_interface import PartialScorerInterface
from ..scorers.scorer_interface import ScorerInterface
from ..utils import end_detect
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class Hypothesis(NamedTuple):
"""Hypothesis data type."""
yseq: paddle.Tensor # (T,)
yseq: paddle.Tensor # (T,)
score: Union[float, paddle.Tensor] = 0
scores: Dict[str, Union[float, paddle.Tensor]] = dict()
states: Dict[str, Any] = dict()
......@@ -32,25 +43,24 @@ class Hypothesis(NamedTuple):
return self._replace(
yseq=self.yseq.tolist(),
score=float(self.score),
scores={k: float(v) for k, v in self.scores.items()},
)._asdict()
scores={k: float(v)
for k, v in self.scores.items()}, )._asdict()
class BeamSearch(paddle.nn.Layer):
"""Beam search implementation."""
def __init__(
self,
scorers: Dict[str, ScorerInterface],
weights: Dict[str, float],
beam_size: int,
vocab_size: int,
sos: int,
eos: int,
token_list: List[str] = None,
pre_beam_ratio: float = 1.5,
pre_beam_score_key: str = None,
):
self,
scorers: Dict[str, ScorerInterface],
weights: Dict[str, float],
beam_size: int,
vocab_size: int,
sos: int,
eos: int,
token_list: List[str]=None,
pre_beam_ratio: float=1.5,
pre_beam_score_key: str=None, ):
"""Initialize beam search.
Args:
......@@ -72,12 +82,12 @@ class BeamSearch(paddle.nn.Layer):
super().__init__()
# set scorers
self.weights = weights
self.scorers = dict() # all = full + partial
self.full_scorers = dict() # full tokens
self.part_scorers = dict() # partial tokens
self.scorers = dict() # all = full + partial
self.full_scorers = dict() # full tokens
self.part_scorers = dict() # partial tokens
# this module dict is required for recursive cast
# `self.to(device, dtype)` in `recog.py`
self.nn_dict = paddle.nn.LayerDict() # nn.Layer
self.nn_dict = paddle.nn.LayerDict() # nn.Layer
for k, v in scorers.items():
w = weights.get(k, 0)
if w == 0 or v is None:
......@@ -101,20 +111,16 @@ class BeamSearch(paddle.nn.Layer):
self.pre_beam_size = int(pre_beam_ratio * beam_size)
self.beam_size = beam_size
self.n_vocab = vocab_size
if (
pre_beam_score_key is not None
and pre_beam_score_key != "full"
and pre_beam_score_key not in self.full_scorers
):
raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
if (pre_beam_score_key is not None and pre_beam_score_key != "full" and
pre_beam_score_key not in self.full_scorers):
raise KeyError(
f"{pre_beam_score_key} is not found in {self.full_scorers}")
# selected `key` scorer to do pre beam search
self.pre_beam_score_key = pre_beam_score_key
# do_pre_beam when need, valid and has part_scorers
self.do_pre_beam = (
self.pre_beam_score_key is not None
and self.pre_beam_size < self.n_vocab
and len(self.part_scorers) > 0
)
self.do_pre_beam = (self.pre_beam_score_key is not None and
self.pre_beam_size < self.n_vocab and
len(self.part_scorers) > 0)
def init_hyp(self, x: paddle.Tensor) -> List[Hypothesis]:
"""Get an initial hypothesis data.
......@@ -136,12 +142,12 @@ class BeamSearch(paddle.nn.Layer):
yseq=paddle.to_tensor([self.sos], place=x.place),
score=0.0,
scores=init_scores,
states=init_states,
)
states=init_states, )
]
@staticmethod
def append_token(xs: paddle.Tensor, x: int) -> paddle.Tensor:
def append_token(xs: paddle.Tensor,
x: Union[int, paddle.Tensor]) -> paddle.Tensor:
"""Append new token to prefix tokens.
Args:
......@@ -152,12 +158,11 @@ class BeamSearch(paddle.nn.Layer):
paddle.Tensor: (T+1,), New tensor contains: xs + [x] with xs.dtype and xs.device
"""
x = paddle.to_tensor([x], dtype=xs.dtype, place=xs.place)
return paddle.cat((xs, x))
x = paddle.to_tensor([x], dtype=xs.dtype) if isinstance(x, int) else x
return paddle.concat((xs, x))
def score_full(
self, hyp: Hypothesis, x: paddle.Tensor
) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]:
def score_full(self, hyp: Hypothesis, x: paddle.Tensor
) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
......@@ -179,9 +184,11 @@ class BeamSearch(paddle.nn.Layer):
scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
return scores, states
def score_partial(
self, hyp: Hypothesis, ids: paddle.Tensor, x: paddle.Tensor
) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]:
def score_partial(self,
hyp: Hypothesis,
ids: paddle.Tensor,
x: paddle.Tensor
) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.part_scorers`.
Args:
......@@ -202,12 +209,12 @@ class BeamSearch(paddle.nn.Layer):
states = dict()
for k, d in self.part_scorers.items():
# scores[k] shape (len(ids),)
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k],
x)
return scores, states
def beam(
self, weighted_scores: paddle.Tensor, ids: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor]:
def beam(self, weighted_scores: paddle.Tensor,
ids: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute topk full token ids and partial token ids.
Args:
......@@ -224,7 +231,8 @@ class BeamSearch(paddle.nn.Layer):
"""
# no pre beam performed, `ids` equal to `weighted_scores`
if weighted_scores.size(0) == ids.size(0):
top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab
top_ids = weighted_scores.topk(
self.beam_size)[1] # index in n_vocab
return top_ids, top_ids
# mask pruned in pre-beam not to select in topk
......@@ -232,18 +240,18 @@ class BeamSearch(paddle.nn.Layer):
weighted_scores[:] = -float("inf")
weighted_scores[ids] = tmp
# top_ids no equal to local_ids, since ids shape not same
top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab
local_ids = weighted_scores[ids].topk(self.beam_size)[1] # index in len(ids)
top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab
local_ids = weighted_scores[ids].topk(
self.beam_size)[1] # index in len(ids)
return top_ids, local_ids
@staticmethod
def merge_scores(
prev_scores: Dict[str, float],
next_full_scores: Dict[str, paddle.Tensor],
full_idx: int,
next_part_scores: Dict[str, paddle.Tensor],
part_idx: int,
) -> Dict[str, paddle.Tensor]:
prev_scores: Dict[str, float],
next_full_scores: Dict[str, paddle.Tensor],
full_idx: int,
next_part_scores: Dict[str, paddle.Tensor],
part_idx: int, ) -> Dict[str, paddle.Tensor]:
"""Merge scores for new hypothesis.
Args:
......@@ -289,9 +297,8 @@ class BeamSearch(paddle.nn.Layer):
new_states[k] = d.select_state(part_states[k], part_idx)
return new_states
def search(
self, running_hyps: List[Hypothesis], x: paddle.Tensor
) -> List[Hypothesis]:
def search(self, running_hyps: List[Hypothesis],
x: paddle.Tensor) -> List[Hypothesis]:
"""Search new tokens for running hypotheses and encoded speech x.
Args:
......@@ -306,17 +313,15 @@ class BeamSearch(paddle.nn.Layer):
part_ids = paddle.arange(self.n_vocab) # no pre-beam
for hyp in running_hyps:
# scoring
weighted_scores = paddle.zeros(self.n_vocab, dtype=x.dtype)
weighted_scores = paddle.zeros([self.n_vocab], dtype=x.dtype)
scores, states = self.score_full(hyp, x)
for k in self.full_scorers:
weighted_scores += self.weights[k] * scores[k]
# partial scoring
if self.do_pre_beam:
pre_beam_scores = (
weighted_scores
if self.pre_beam_score_key == "full"
else scores[self.pre_beam_score_key]
)
pre_beam_scores = (weighted_scores
if self.pre_beam_score_key == "full" else
scores[self.pre_beam_score_key])
part_ids = paddle.topk(pre_beam_scores, self.pre_beam_size)[1]
part_scores, part_states = self.score_partial(hyp, part_ids, x)
for k in self.part_scorers:
......@@ -332,22 +337,21 @@ class BeamSearch(paddle.nn.Layer):
Hypothesis(
score=weighted_scores[j],
yseq=self.append_token(hyp.yseq, j),
scores=self.merge_scores(
hyp.scores, scores, j, part_scores, part_j
),
scores=self.merge_scores(hyp.scores, scores, j,
part_scores, part_j),
states=self.merge_states(states, part_states, part_j),
)
)
))
# sort and prune 2 x beam -> beam
best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
: min(len(best_hyps), self.beam_size)
]
best_hyps = sorted(
best_hyps, key=lambda x: x.score,
reverse=True)[:min(len(best_hyps), self.beam_size)]
return best_hyps
def forward(
self, x: paddle.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
) -> List[Hypothesis]:
def forward(self,
x: paddle.Tensor,
maxlenratio: float=0.0,
minlenratio: float=0.0) -> List[Hypothesis]:
"""Perform beam search.
Args:
......@@ -382,9 +386,11 @@ class BeamSearch(paddle.nn.Layer):
logger.debug("position " + str(i))
best = self.search(running_hyps, x)
# post process of one iteration
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
running_hyps = self.post_process(i, maxlen, maxlenratio, best,
ended_hyps)
# end detection
if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
if maxlenratio == 0.0 and end_detect(
[h.asdict() for h in ended_hyps], i):
logger.info(f"end detected at {i}")
break
if len(running_hyps) == 0:
......@@ -396,41 +402,39 @@ class BeamSearch(paddle.nn.Layer):
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
# check the number of hypotheses reaching to eos
if len(nbest_hyps) == 0:
logger.warning(
"there is no N-best results, perform recognition "
"again with smaller minlenratio."
)
return (
[]
if minlenratio < 0.1
else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
)
logger.warning("there is no N-best results, perform recognition "
"again with smaller minlenratio.")
return ([] if minlenratio < 0.1 else
self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)))
# report the best result
best = nbest_hyps[0]
for k, v in best.scores.items():
logger.info(
f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
f"{float(v):6.2f} * {self.weights[k]:3} = {float(v) * self.weights[k]:6.2f} for {k}"
)
logger.info(f"total log probability: {best.score:.2f}")
logger.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
logger.info(f"total log probability: {float(best.score):.2f}")
logger.info(
f"normalized log probability: {float(best.score) / len(best.yseq):.2f}"
)
logger.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
if self.token_list is not None:
logger.info(
"best hypo: "
+ "".join([self.token_list[x] for x in best.yseq[1:-1]])
+ "\n"
)
# logger.info(
# "best hypo: "
# + "".join([self.token_list[x] for x in best.yseq[1:-1]])
# + "\n"
# )
logger.info("best hypo: " + "".join(
[self.token_list[x] for x in best.yseq[1:]]) + "\n")
return nbest_hyps
def post_process(
self,
i: int,
maxlen: int,
maxlenratio: float,
running_hyps: List[Hypothesis],
ended_hyps: List[Hypothesis],
) -> List[Hypothesis]:
self,
i: int,
maxlen: int,
maxlenratio: float,
running_hyps: List[Hypothesis],
ended_hyps: List[Hypothesis], ) -> List[Hypothesis]:
"""Perform post-processing of beam search iterations.
Args:
......@@ -446,10 +450,8 @@ class BeamSearch(paddle.nn.Layer):
"""
logger.debug(f"the number of running hypotheses: {len(running_hyps)}")
if self.token_list is not None:
logger.debug(
"best hypo: "
+ "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
)
logger.debug("best hypo: " + "".join(
[self.token_list[x] for x in running_hyps[0].yseq[1:]]))
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logger.info("adding <eos> in the last position in the loop")
......@@ -464,7 +466,8 @@ class BeamSearch(paddle.nn.Layer):
for hyp in running_hyps:
if hyp.yseq[-1] == self.eos:
# e.g., Word LM needs to add final <eos> score
for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
for k, d in chain(self.full_scorers.items(),
self.part_scorers.items()):
s = d.final_score(hyp.states[k])
hyp.scores[k] += s
hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
......@@ -475,19 +478,18 @@ class BeamSearch(paddle.nn.Layer):
def beam_search(
x: paddle.Tensor,
sos: int,
eos: int,
beam_size: int,
vocab_size: int,
scorers: Dict[str, ScorerInterface],
weights: Dict[str, float],
token_list: List[str] = None,
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
pre_beam_ratio: float = 1.5,
pre_beam_score_key: str = "full",
) -> list:
x: paddle.Tensor,
sos: int,
eos: int,
beam_size: int,
vocab_size: int,
scorers: Dict[str, ScorerInterface],
weights: Dict[str, float],
token_list: List[str]=None,
maxlenratio: float=0.0,
minlenratio: float=0.0,
pre_beam_ratio: float=1.5,
pre_beam_score_key: str="full", ) -> list:
"""Perform beam search with scorers.
Args:
......@@ -523,6 +525,6 @@ def beam_search(
pre_beam_score_key=pre_beam_score_key,
sos=sos,
eos=eos,
token_list=token_list,
).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
token_list=token_list, ).forward(
x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
return [h.asdict() for h in ret]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
import json
from pathlib import Path
import jsonlines
import paddle
import yaml
from yacs.config import CfgNode
from .beam_search import BatchBeamSearch
from .beam_search import BeamSearch
from .scorers.length_bonus import LengthBonus
from .scorers.scorer_interface import BatchScorerInterface
from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.utils.log import Log
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
# from espnet.nets.lm_interface import dynamic_import_lm
logger = Log(__name__).getlog()
# NOTE: you need this func to generate our sphinx doc
def load_trained_model(args):
args.nprocs = args.ngpu
confs = CfgNode()
confs.set_new_allowed(True)
confs.merge_from_file(args.model_conf)
class_obj = dynamic_import_tester(args.model_name)
exp = class_obj(confs, args)
with exp.eval():
exp.setup()
exp.restore()
char_list = exp.args.char_list
model = exp.model
return model, char_list, exp, confs
def recog_v2(args):
"""Decode with custom models that implements ScorerInterface.
Args:
args (namespace): The program arguments.
See py:func:`bin.asr_recog.get_parser` for details
"""
logger.warning("experimental API for custom LMs is selected by --api v2")
if args.batchsize > 1:
raise NotImplementedError("multi-utt batch decoding is not implemented")
if args.streaming_mode is not None:
raise NotImplementedError("streaming mode is not implemented")
if args.word_rnnlm:
raise NotImplementedError("word LM is not implemented")
# set_deterministic(args)
model, char_list, exp, confs = load_trained_model(args)
assert isinstance(model, ASRInterface)
load_inputs_and_targets = LoadInputsAndTargets(
mode="asr",
load_output=False,
sort_in_input_length=False,
preprocess_conf=confs.collator.augmentation_config
if args.preprocess_conf is None else args.preprocess_conf,
preprocess_args={"train": False},
)
if args.rnnlm:
lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
# NOTE: for a compatibility with less than 0.5.0 version models
lm_model_module = getattr(lm_args, "model_module", "default")
lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
lm = lm_class(len(char_list), lm_args)
torch_load(args.rnnlm, lm)
lm.eval()
else:
lm = None
if args.ngram_model:
from .scorers.ngram import NgramFullScorer
from .scorers.ngram import NgramPartScorer
if args.ngram_scorer == "full":
ngram = NgramFullScorer(args.ngram_model, char_list)
else:
ngram = NgramPartScorer(args.ngram_model, char_list)
else:
ngram = None
scorers = model.scorers() # decoder
scorers["lm"] = lm
scorers["ngram"] = ngram
scorers["length_bonus"] = LengthBonus(len(char_list))
weights = dict(
decoder=1.0 - args.ctc_weight,
ctc=args.ctc_weight,
lm=args.lm_weight,
ngram=args.ngram_weight,
length_bonus=args.penalty,
)
beam_search = BeamSearch(
beam_size=args.beam_size,
vocab_size=len(char_list),
weights=weights,
scorers=scorers,
sos=model.sos,
eos=model.eos,
token_list=char_list,
pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
)
# TODO(karita): make all scorers batchfied
if args.batchsize == 1:
non_batch = [
k for k, v in beam_search.full_scorers.items()
if not isinstance(v, BatchScorerInterface)
]
if len(non_batch) == 0:
beam_search.__class__ = BatchBeamSearch
logger.info("BatchBeamSearch implementation is selected.")
else:
logger.warning(f"As non-batch scorers {non_batch} are found, "
f"fall back to non-batch implementation.")
if args.ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
if args.ngpu == 1:
device = "gpu:0"
else:
device = "cpu"
paddle.set_device(device)
dtype = getattr(paddle, args.dtype)
logger.info(f"Decoding device={device}, dtype={dtype}")
model.to(device=device, dtype=dtype)
model.eval()
beam_search.to(device=device, dtype=dtype)
beam_search.eval()
# read json data
js = []
with jsonlines.open(args.recog_json, "r") as reader:
for item in reader:
js.append(item)
# jsonlines to dict, key by 'utt', value by jsonline
js = {item['utt']: item for item in js}
new_js = {}
with paddle.no_grad():
with jsonlines.open(args.result_label, "w") as f:
for idx, name in enumerate(js.keys(), 1):
logger.info(f"({idx}/{len(js.keys())}) decoding " + name)
batch = [(name, js[name])]
feat = load_inputs_and_targets(batch)[0][0]
logger.info(f'feat: {feat.shape}')
enc = model.encode(paddle.to_tensor(feat).to(dtype))
logger.info(f'eout: {enc.shape}')
nbest_hyps = beam_search(x=enc,
maxlenratio=args.maxlenratio,
minlenratio=args.minlenratio)
nbest_hyps = [
h.asdict()
for h in nbest_hyps[:min(len(nbest_hyps), args.nbest)]
]
new_js[name] = add_results_to_json(js[name], nbest_hyps,
char_list)
item = new_js[name]['output'][0] # 1-best
ref = item['text']
rec_text = item['rec_text'].replace('▁',
' ').replace('<eos>',
'').strip()
rec_tokenid = list(map(int, item['rec_tokenid'].split()))
f.write({
"utt": name,
"refs": [ref],
"hyps": [rec_text],
"hyps_tokenid": [rec_tokenid],
})
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end speech recognition model decoding script."""
import logging
import os
import random
import sys
from distutils.util import strtobool
import configargparse
import numpy as np
from .recog import recog_v2
def get_parser():
"""Get default arguments."""
parser = configargparse.ArgumentParser(
description="Transcribe text from speech using "
"a speech recognition model on one CPU or GPU",
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter, )
parser.add(
'--model-name',
type=str,
default='u2_kaldi',
help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
# general configuration
parser.add("--config", is_config_file=True, help="Config file path")
parser.add(
"--config2",
is_config_file=True,
help="Second config file path that overwrites the settings in `--config`",
)
parser.add(
"--config3",
is_config_file=True,
help="Third config file path that overwrites the settings "
"in `--config` and `--config2`", )
parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
parser.add_argument(
"--dtype",
choices=("float16", "float32", "float64"),
default="float32",
help="Float precision (only available in --api v2)", )
parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
parser.add_argument("--seed", type=int, default=1, help="Random seed")
parser.add_argument(
"--verbose", "-V", type=int, default=2, help="Verbose option")
parser.add_argument(
"--batchsize",
type=int,
default=1,
help="Batch size for beam search (0: means no batch processing)", )
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing", )
parser.add_argument(
"--api",
default="v2",
choices=["v2"],
help="Beam search APIs "
"v2: Experimental API. It supports any models that implements ScorerInterface.",
)
# task related
parser.add_argument(
"--recog-json", type=str, help="Filename of recognition data (json)")
parser.add_argument(
"--result-label",
type=str,
required=True,
help="Filename of result label data (json)", )
# model (parameter) related
parser.add_argument(
"--model",
type=str,
required=True,
help="Model file parameters to read")
parser.add_argument(
"--model-conf", type=str, default=None, help="Model config file")
parser.add_argument(
"--num-spkrs",
type=int,
default=1,
choices=[1, 2],
help="Number of speakers in the speech", )
parser.add_argument(
"--num-encs",
default=1,
type=int,
help="Number of encoders in the model.")
# search related
parser.add_argument(
"--nbest", type=int, default=1, help="Output N-best hypotheses")
parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
parser.add_argument(
"--penalty", type=float, default=0.0, help="Incertion penalty")
parser.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="""Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses a end-detect function
to automatically find maximum hypothesis lengths.
If maxlenratio<0.0, its absolute value is interpreted
as a constant max output length""", )
parser.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length", )
parser.add_argument(
"--ctc-weight",
type=float,
default=0.0,
help="CTC weight in joint decoding")
parser.add_argument(
"--weights-ctc-dec",
type=float,
action="append",
help="ctc weight assigned to each encoder during decoding."
"[in multi-encoder mode only]", )
parser.add_argument(
"--ctc-window-margin",
type=int,
default=0,
help="""Use CTC window with margin parameter to accelerate
CTC/attention decoding especially on GPU. Smaller magin
makes decoding faster, but may increase search errors.
If margin=0 (default), this function is disabled""", )
# transducer related
parser.add_argument(
"--search-type",
type=str,
default="default",
choices=["default", "nsc", "tsd", "alsd", "maes"],
help="""Type of beam search implementation to use during inference.
Can be either: default beam search ("default"),
N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"),
Alignment-Length Synchronous Decoding ("alsd") or
modified Adaptive Expansion Search ("maes").""", )
parser.add_argument(
"--nstep",
type=int,
default=1,
help="""Number of expansion steps allowed in NSC beam search or mAES
(nstep > 0 for NSC and nstep > 1 for mAES).""", )
parser.add_argument(
"--prefix-alpha",
type=int,
default=2,
help="Length prefix difference allowed in NSC beam search or mAES.", )
parser.add_argument(
"--max-sym-exp",
type=int,
default=2,
help="Number of symbol expansions allowed in TSD.", )
parser.add_argument(
"--u-max",
type=int,
default=400,
help="Length prefix difference allowed in ALSD.", )
parser.add_argument(
"--expansion-gamma",
type=float,
default=2.3,
help="Allowed logp difference for prune-by-value method in mAES.", )
parser.add_argument(
"--expansion-beta",
type=int,
default=2,
help="""Number of additional candidates for expanded hypotheses
selection in mAES.""", )
parser.add_argument(
"--score-norm",
type=strtobool,
nargs="?",
default=True,
help="Normalize final hypotheses' score by length", )
parser.add_argument(
"--softmax-temperature",
type=float,
default=1.0,
help="Penalization term for softmax function.", )
# rnnlm related
parser.add_argument(
"--rnnlm", type=str, default=None, help="RNNLM model file to read")
parser.add_argument(
"--rnnlm-conf",
type=str,
default=None,
help="RNNLM model config file to read")
parser.add_argument(
"--word-rnnlm",
type=str,
default=None,
help="Word RNNLM model file to read")
parser.add_argument(
"--word-rnnlm-conf",
type=str,
default=None,
help="Word RNNLM model config file to read", )
parser.add_argument(
"--word-dict", type=str, default=None, help="Word list to read")
parser.add_argument(
"--lm-weight", type=float, default=0.1, help="RNNLM weight")
# ngram related
parser.add_argument(
"--ngram-model",
type=str,
default=None,
help="ngram model file to read")
parser.add_argument(
"--ngram-weight", type=float, default=0.1, help="ngram weight")
parser.add_argument(
"--ngram-scorer",
type=str,
default="part",
choices=("full", "part"),
help="""if the ngram is set as a part scorer, similar with CTC scorer,
ngram scorer only scores topK hypethesis.
if the ngram is set as full scorer, ngram scorer scores all hypthesis
the decoding speed of part scorer is musch faster than full one""",
)
# streaming related
parser.add_argument(
"--streaming-mode",
type=str,
default=None,
choices=["window", "segment"],
help="""Use streaming recognizer for inference.
`--batchsize` must be set to 0 to enable this mode""", )
parser.add_argument(
"--streaming-window", type=int, default=10, help="Window size")
parser.add_argument(
"--streaming-min-blank-dur",
type=int,
default=10,
help="Minimum blank duration threshold", )
parser.add_argument(
"--streaming-onset-margin", type=int, default=1, help="Onset margin")
parser.add_argument(
"--streaming-offset-margin", type=int, default=1, help="Offset margin")
# non-autoregressive related
# Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
parser.add_argument(
"--maskctc-n-iterations",
type=int,
default=10,
help="Number of decoding iterations."
"For Mask CTC, set 0 to predict 1 mask/iter.", )
parser.add_argument(
"--maskctc-probability-threshold",
type=float,
default=0.999,
help="Threshold probability for CTC output", )
# quantize model related
parser.add_argument(
"--quantize-config",
nargs="*",
help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]",
)
parser.add_argument(
"--quantize-dtype",
type=str,
default="qint8",
help="Dtype dynamic quantize")
parser.add_argument(
"--quantize-asr-model",
type=bool,
default=False,
help="Quantize asr model", )
parser.add_argument(
"--quantize-lm-model",
type=bool,
default=False,
help="Quantize lm model", )
return parser
def main(args):
"""Run the main decoding function."""
parser = get_parser()
parser.add_argument(
"--output", metavar="CKPT_DIR", help="path to save checkpoint.")
parser.add_argument(
"--checkpoint_path", type=str, help="path to load checkpoint")
parser.add_argument("--dict-path", type=str, help="path to load checkpoint")
args = parser.parse_args(args)
if args.ngpu == 0 and args.dtype == "float16":
raise ValueError(
f"--dtype {args.dtype} does not support the CPU backend.")
# logging info
if args.verbose == 1:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
elif args.verbose == 2:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.warning("Skip DEBUG/INFO messages")
logging.info(args)
# check CUDA_VISIBLE_DEVICES
if args.ngpu > 0:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is None:
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
elif args.ngpu != len(cvd.split(",")):
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
sys.exit(1)
# TODO(mn5k): support of multiple GPUs
if args.ngpu > 1:
logging.error("The program only supports ngpu=1.")
sys.exit(1)
# display PYTHONPATH
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
# seed setting
random.seed(args.seed)
np.random.seed(args.seed)
logging.info("set random seed = %d" % args.seed)
# validate rnn options
if args.rnnlm is not None and args.word_rnnlm is not None:
logging.error(
"It seems that both --rnnlm and --word-rnnlm are specified. "
"Please use either option.")
sys.exit(1)
# recog
if args.num_spkrs == 1:
if args.num_encs == 1:
# Experimental API that supports custom LMs
if args.api == "v2":
from deepspeech.decoders.recog import recog_v2
recog_v2(args)
else:
raise ValueError("Only support --api v2")
else:
if args.api == "v2":
raise NotImplementedError(
f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
)
elif args.num_spkrs == 2:
raise ValueError("asr_mix not supported.")
if __name__ == "__main__":
main(sys.argv[1:])
......@@ -15,8 +15,8 @@
import numpy as np
import paddle
from .ctc_prefix_score import CTCPrefixScorer
from .ctc_prefix_score import CTCPrefixScorerPD
from .ctc_prefix_score import CTCPrefixScore
from .ctc_prefix_score import CTCPrefixScorePD
from .scorer_interface import BatchPartialScorerInterface
......
......@@ -6,7 +6,7 @@ import paddle
import six
class CTCPrefixScorerPD():
class CTCPrefixScorePD():
"""Batch processing of CTCPrefixScore
which is based on Algorithm 2 in WATANABE et al.
......@@ -267,7 +267,7 @@ class CTCPrefixScorerPD():
return (r_prev_new, s_prev, f_min_prev, f_max_prev)
class CTCPrefixScorer():
class CTCPrefixScore():
"""Compute CTC label sequence scores
which is based on Algorithm 2 in WATANABE et al.
......
......@@ -85,8 +85,9 @@ class NgramFullScorer(Ngrambase, BatchScorerInterface):
and next state list for ys.
"""
return self.score_partial_(
y, paddle.to_tensor(range(self.charlen)), state, x)
return self.score_partial_(y,
paddle.to_tensor(range(self.charlen)), state,
x)
class NgramPartScorer(Ngrambase, PartialScorerInterface):
......
......@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
__all__ = ["end_detect"]
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["end_detect", "parse_hypothesis", "add_results_to_json"]
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
......@@ -47,3 +51,79 @@ def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
return True
else:
return False
# * ------------------ recognition related ------------------ *
def parse_hypothesis(hyp, char_list):
"""Parse hypothesis.
Args:
hyp (list[dict[str, Any]]): Recognition hypothesis.
char_list (list[str]): List of characters.
Returns:
tuple(str, str, str, float)
"""
# remove sos and get results
tokenid_as_list = list(map(int, hyp["yseq"][1:]))
token_as_list = [char_list[idx] for idx in tokenid_as_list]
score = float(hyp["score"])
# convert to string
tokenid = " ".join([str(idx) for idx in tokenid_as_list])
token = " ".join(token_as_list)
text = "".join(token_as_list).replace("<space>", " ")
return text, token, tokenid, score
def add_results_to_json(js, nbest_hyps, char_list):
"""Add N-best results to json.
Args:
js (dict[str, Any]): Groundtruth utterance dict.
nbest_hyps_sd (list[dict[str, Any]]):
List of hypothesis for multi_speakers: nutts x nspkrs.
char_list (list[str]): List of characters.
Returns:
dict[str, Any]: N-best results added utterance dict.
"""
# copy old json info
new_js = dict()
new_js["utt2spk"] = js["utt2spk"]
new_js["output"] = []
for n, hyp in enumerate(nbest_hyps, 1):
# parse hypothesis
rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp,
char_list)
# copy ground-truth
if len(js["output"]) > 0:
out_dic = dict(js["output"][0].items())
else:
# for no reference case (e.g., speech translation)
out_dic = {"name": ""}
# update name
out_dic["name"] += "[%d]" % n
# add recognition results
out_dic["rec_text"] = rec_text
out_dic["rec_token"] = rec_token
out_dic["rec_tokenid"] = rec_tokenid
out_dic["score"] = score
# add to list of N-best result dicts
new_js["output"].append(out_dic)
# show 1-best result
if n == 1:
if "text" in out_dic.keys():
logger.info("groundtruth: %s" % out_dic["text"])
logger.info("prediction : %s" % out_dic["rec_text"])
return new_js
......@@ -11,3 +11,52 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepspeech.training.trainer import Trainer
from deepspeech.utils.dynamic_import import dynamic_import
model_trainer_alias = {
"ds2": "deepspeech.exp.deepspeech2.model:DeepSpeech2Trainer",
"u2": "deepspeech.exps.u2.model:U2Trainer",
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer",
"u2_st": "deepspeech.exps.u2_st.model:U2STTrainer",
}
def dynamic_import_trainer(module):
"""Import Trainer dynamically.
Args:
module (str): trainer name. e.g., ds2, u2, u2_kaldi
Returns:
type: Trainer class
"""
model_class = dynamic_import(module, model_trainer_alias)
assert issubclass(model_class,
Trainer), f"{module} does not implement Trainer"
return model_class
model_tester_alias = {
"ds2": "deepspeech.exp.deepspeech2.model:DeepSpeech2Tester",
"u2": "deepspeech.exps.u2.model:U2Tester",
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester",
"u2_st": "deepspeech.exps.u2_st.model:U2STTester",
}
def dynamic_import_tester(module):
"""Import Tester dynamically.
Args:
module (str): tester name. e.g., ds2, u2, u2_kaldi
Returns:
type: Tester class
"""
model_class = dynamic_import(module, model_tester_alias)
assert issubclass(model_class,
Trainer), f"{module} does not implement Tester"
return model_class
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from deepspeech.decoders.recog_bin import main
if __name__ == "__main__":
main(sys.argv[1:])
......@@ -317,11 +317,9 @@ class U2Trainer(Trainer):
with UpdateConfig(model_conf):
model_conf.input_dim = self.train_loader.feat_dim
model_conf.output_dim = self.train_loader.vocab_size
model = U2Model.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model)
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
# lr
......@@ -436,8 +434,9 @@ class U2Tester(U2Trainer):
simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time
for i, (utt, target, result, rec_tids) in enumerate(zip(
utts, target_transcripts, result_transcripts, result_tokenids)):
for i, (utt, target, result, rec_tids) in enumerate(
zip(utts, target_transcripts, result_transcripts,
result_tokenids)):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
......
......@@ -140,7 +140,7 @@ class TextFeaturizer():
Returns:
str: text string.
"""
tokens = [t.replace(SPACE, " ") for t in tokens ]
tokens = [t.replace(SPACE, " ") for t in tokens]
return "".join(tokens)
def word_tokenize(self, text):
......@@ -207,7 +207,7 @@ class TextFeaturizer():
"""Load vocabulary from file."""
vocab_list = load_dict(vocab_filepath, maskctc)
assert vocab_list is not None
logger.info(f"Vocab: {pformat(vocab_list)}")
logger.debug(f"Vocab: {pformat(vocab_list)}")
id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)])
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ASR Interface module."""
import argparse
from deepspeech.utils.dynamic_import import dynamic_import
class ASRInterface:
"""ASR Interface for ESPnet model implementation."""
@staticmethod
def add_arguments(parser):
"""Add arguments to parser."""
return parser
@classmethod
def build(cls, idim: int, odim: int, **kwargs):
"""Initialize this class with python-level args.
Args:
idim (int): The number of an input feature dim.
odim (int): The number of output vocab.
Returns:
ASRinterface: A new instance of ASRInterface.
"""
args = argparse.Namespace(**kwargs)
return cls(idim, odim, args)
def forward(self, xs, ilens, ys, olens):
"""Compute loss for training.
:param xs: batch of padded source sequences paddle.Tensor (B, Tmax, idim)
:param ilens: batch of lengths of source sequences (B), paddle.Tensor
:param ys: batch of padded target sequences paddle.Tensor (B, Lmax)
:param olens: batch of lengths of target sequences (B), paddle.Tensor
:return: loss value
:rtype: paddle.Tensor
"""
raise NotImplementedError("forward method is not implemented")
def recognize(self, x, recog_args, char_list=None, rnnlm=None):
"""Recognize x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
:param namespace recog_args: argment namespace contraining options
:param list char_list: list of characters
:param paddle.nn.Layer rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("recognize method is not implemented")
def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
"""Beam search implementation for batch.
:param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace recog_args: argument namespace containing options
:param list char_list: list of characters
:param paddle.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("Batch decoding is not supported yet.")
def calculate_all_attentions(self, xs, ilens, ys):
"""Calculate attention.
:param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: attention weights (B, Lmax, Tmax)
:rtype: float ndarray
"""
raise NotImplementedError(
"calculate_all_attentions method is not implemented")
def calculate_all_ctc_probs(self, xs, ilens, ys):
"""Calculate CTC probability.
:param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: CTC probabilities (B, Tmax, vocab)
:rtype: float ndarray
"""
raise NotImplementedError(
"calculate_all_ctc_probs method is not implemented")
@property
def attention_plot_class(self):
"""Get attention plot class."""
from espnet.asr.asr_utils import PlotAttentionReport
return PlotAttentionReport
@property
def ctc_plot_class(self):
"""Get CTC plot class."""
from espnet.asr.asr_utils import PlotCTCReport
return PlotCTCReport
def get_total_subsampling_factor(self):
"""Get total subsampling factor."""
raise NotImplementedError(
"get_total_subsampling_factor method is not implemented")
def encode(self, feat):
"""Encode feature in `beam_search` (optional).
Args:
x (numpy.ndarray): input feature (T, D)
Returns:
paddle.Tensor: encoded feature (T, D)
"""
raise NotImplementedError("encode method is not implemented")
def scorers(self):
"""Get scorers for `beam_search` (optional).
Returns:
dict[str, ScorerInterface]: dict of `ScorerInterface` objects
"""
raise NotImplementedError("decoders method is not implemented")
predefined_asr = {
"transformer": "deepspeech.models.u2:U2Model",
"conformer": "deepspeech.models.u2:U2Model",
}
def dynamic_import_asr(module):
"""Import ASR models dynamically.
Args:
module (str): asr name. e.g., transformer, conformer
Returns:
type: ASR class
"""
model_class = dynamic_import(module, predefined_asr)
assert issubclass(model_class,
ASRInterface), f"{module} does not implement ASRInterface"
return model_class
......@@ -28,8 +28,10 @@ from paddle import jit
from paddle import nn
from yacs.config import CfgNode
from deepspeech.decoders.scorers.ctc import CTCPrefixScorer
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.frontend.utility import load_cmvn
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.modules.cmvn import GlobalCMVN
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder
......@@ -55,7 +57,7 @@ __all__ = ["U2Model", "U2InferModel"]
logger = Log(__name__).getlog()
class U2BaseModel(nn.Layer):
class U2BaseModel(ASRInterface, nn.Layer):
"""CTC-Attention hybrid Encoder-Decoder model"""
@classmethod
......@@ -120,7 +122,7 @@ class U2BaseModel(nn.Layer):
**kwargs):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
super().__init__()
nn.Layer.__init__(self)
# note that eos is the same as sos (equivalent ID)
self.sos = vocab_size - 1
self.eos = vocab_size - 1
......@@ -813,7 +815,27 @@ class U2BaseModel(nn.Layer):
return res, res_tokenids
class U2Model(U2BaseModel):
class U2DecodeModel(U2BaseModel):
def scorers(self):
"""Scorers."""
return dict(
decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
def encode(self, x):
"""Encode acoustic features.
:param ndarray x: source acoustic feature (T, D)
:return: encoder outputs
:rtype: paddle.Tensor
"""
self.eval()
x = paddle.to_tensor(x).unsqueeze(0)
ilen = x.size(1)
enc_output, _ = self._forward_encoder(x, ilen)
return enc_output.squeeze(0)
class U2Model(U2DecodeModel):
def __init__(self, configs: dict):
vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder definition."""
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
......@@ -20,10 +21,12 @@ import paddle
from paddle import nn
from typeguard import check_argument_types
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.decoder_layer import DecoderLayer
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.modules.mask import make_xs_mask
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
from deepspeech.utils.log import Log
......@@ -33,7 +36,7 @@ logger = Log(__name__).getlog()
__all__ = ["TransformerDecoder"]
class TransformerDecoder(nn.Layer):
class TransformerDecoder(BatchScorerInterface, nn.Layer):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
......@@ -71,7 +74,8 @@ class TransformerDecoder(nn.Layer):
concat_after: bool=False, ):
assert check_argument_types()
super().__init__()
nn.Layer.__init__(self)
self.selfattention_layer_type = 'selfattn'
attention_dim = encoder_output_size
if input_layer == "embed":
......@@ -180,3 +184,63 @@ class TransformerDecoder(nn.Layer):
if self.use_output_layer:
y = paddle.log_softmax(self.output_layer(y), axis=-1)
return y, new_cache
# beam search API (see ScorerInterface)
def score(self, ys, state, x):
"""Score.
ys: (ylen,)
x: (xlen, n_feat)
"""
ys_mask = subsequent_mask(len(ys)).unsqueeze(0) # (B,L,L)
x_mask = make_xs_mask(x.unsqueeze(0)).unsqueeze(1) # (B,1,T)
if self.selfattention_layer_type != "selfattn":
# TODO(karita): implement cache
logging.warning(
f"{self.selfattention_layer_type} does not support cached decoding."
)
state = None
logp, state = self.forward_one_step(
x.unsqueeze(0), x_mask, ys.unsqueeze(0), ys_mask, cache=state)
return logp.squeeze(0), state
# batch beam search API (see BatchScorerInterface)
def batch_score(self,
ys: paddle.Tensor,
states: List[Any],
xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
"""Score new token batch (required).
Args:
ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (paddle.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[paddle.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.decoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
paddle.stack([states[b][i] for b in range(n_batch)])
for i in range(n_layers)
]
# batch decoding
ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0) # (B,L,L)
xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T)
logp, states = self.forward_one_step(
xs, xs_mask, ys, ys_mask, cache=batch_state)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)]
for b in range(n_batch)]
return logp, state_list
......@@ -18,12 +18,25 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"make_pad_mask", "make_non_pad_mask", "subsequent_mask",
"make_xs_mask", "make_pad_mask", "make_non_pad_mask", "subsequent_mask",
"subsequent_chunk_mask", "add_optional_chunk_mask", "mask_finished_scores",
"mask_finished_preds"
]
def make_xs_mask(xs: paddle.Tensor, pad_value=0.0) -> paddle.Tensor:
"""Maks mask tensor containing indices of non-padded part.
Args:
xs (paddle.Tensor): (B, T, D), zeros for pad.
Returns:
paddle.Tensor: Mask Tensor indices of non-padded part. (B, T)
"""
pad_frame = paddle.full([1, 1, xs.shape[-1]], pad_value, dtype=xs.dtype)
mask = xs != pad_frame
mask = mask.all(axis=-1)
return mask
def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
......@@ -31,6 +44,7 @@ def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
lengths (paddle.Tensor): Batch of lengths (B,).
Returns:
paddle.Tensor: Mask tensor containing indices of padded part.
(B, T)
Examples:
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
......@@ -62,6 +76,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
lengths (paddle.Tensor): Batch of lengths (B,).
Returns:
paddle.Tensor: mask tensor containing indices of padded part.
(B, T)
Examples:
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
......
......@@ -35,7 +35,7 @@ class LoadFromFile(argparse.Action):
parser.parse_args(f.read().split(), namespace)
def default_argument_parser():
def default_argument_parser(parser=None):
r"""A simple yet genral argument parser for experiments with parakeet.
This is used in examples with parakeet. And it is intended to be used by
......@@ -62,7 +62,9 @@ def default_argument_parser():
argparse.ArgumentParser
the parser
"""
parser = argparse.ArgumentParser()
if parser is None:
parser = argparse.ArgumentParser()
parser.register('action', 'extend', ExtendAction)
parser.add_argument(
'--conf', type=open, action=LoadFromFile, help="config file.")
......
......@@ -126,7 +126,8 @@ class Trainer():
logger.info(f"Set seed {args.seed}")
# profiler and benchmark options
if self.args.benchmark_batch_size:
if hasattr(self.args,
"benchmark_batch_size") and self.args.benchmark_batch_size:
with UpdateConfig(self.config):
self.config.collator.batch_size = self.args.benchmark_batch_size
self.config.training.log_interval = 1
......@@ -326,12 +327,24 @@ class Trainer():
finally:
self.destory()
def restore(self):
"""Resume from latest checkpoint at checkpoints in the output
directory or load a specified checkpoint.
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
assert self.args.checkpoint_path
infos = self.checkpoint.load_latest_parameters(
self.model, checkpoint_path=self.args.checkpoint_path)
return infos
def run_test(self):
"""Do Test/Decode"""
try:
with Timer("Test/Decode Done: {}"):
with self.eval():
self.resume_or_scratch()
self.restore()
self.test()
except KeyboardInterrupt:
exit(-1)
......@@ -341,6 +354,7 @@ class Trainer():
try:
with Timer("Export Done: {}"):
with self.eval():
self.restore()
self.export()
except KeyboardInterrupt:
exit(-1)
......@@ -350,7 +364,7 @@ class Trainer():
try:
with Timer("Align Done: {}"):
with self.eval():
self.resume_or_scratch()
self.restore()
self.align()
except KeyboardInterrupt:
sys.exit(-1)
......
# ASR
* s0 is for deepspeech2 offline
* s1 is for transformer/conformer/U2
* s2 is for transformer/conformer/U2 w/ kaldi feat, need install Kaldi
* s0 is for deepspeech2 offline
* s1 is for transformer/conformer/U2
* s2 is for transformer/conformer/U2 w/ kaldi feat, need install Kaldi
## Data
| Data Subset | Duration in Seconds |
......
# LibriSpeech
## Transformer
| Model | Params | Config | Augmentation| Test Set | Decode Method | Loss | WER % |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.395054340362549 | 4.2 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.395054340362549 | 5.0 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.395054340362549 | |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescore | 6.395054340362549 | |
| Model | Params | Config | Augmentation| Loss |
| --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | 6.3197922706604 |
| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 |
| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |
batchsize: 0
beam-size: 60
ctc-weight: 0.0
lm-weight: 0.0
maxlenratio: 0.0
minlenratio: 0.0
penalty: 0.0
batchsize: 0
beam-size: 60
ctc-weight: 0.4
lm-weight: 0.6
maxlenratio: 0.0
minlenratio: 0.0
penalty: 0.0
\ No newline at end of file
batchsize: 0
beam-size: 60
ctc-weight: 0.4
lm-weight: 0.0
maxlenratio: 0.0
minlenratio: 0.0
penalty: 0.0
\ No newline at end of file
#!/bin/bash
set -e
expdir=exp
datadir=data
nj=32
tag=
# decode config
decode_config=conf/decode/decode.yaml
# lm params
lang_model=rnnlm.model.best
lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/
lmtag='nolm'
recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean"
# bpemode (unigram or bpe)
nbpe=5000
bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpemodel=${bpeprefix}.model
# bin params
config_path=conf/transformer.yaml
dict=data/bpe_unigram_5000_units.txt
ckpt_prefix=
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
if [ -z ${ckpt_prefix} ]; then
echo "usage: $0 --ckpt_prefix ckpt_prefix"
exit 1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
ckpt_dir=$(dirname `dirname ${ckpt_prefix}`)
echo "ckpt dir: ${ckpt_dir}"
ckpt_tag=$(basename ${ckpt_prefix})
echo "ckpt tag: ${ckpt_tag}"
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
echo "chunk mode: ${chunk_mode}"
echo "decode conf: ${decode_config}"
# download language model
#bash local/download_lm_en.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
pids=() # initialize pids
for dmethd in join_ctc; do
(
echo "${dmethd} decoding"
for rtask in ${recog_set}; do
(
echo "${rtask} dataset"
decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag}_${tag}
feat_recog_dir=${datadir}
mkdir -p ${decode_dir}
mkdir -p ${feat_recog_dir}
# split data
split_json.sh manifest.${rtask} ${nj}
#### use CPU for decoding
ngpu=0
# set batchsize 0 to disable batch decoding
${decode_cmd} JOB=1:${nj} ${decode_dir}/log/decode.JOB.log \
python3 -u ${BIN_DIR}/recog.py \
--api v2 \
--config ${decode_config} \
--ngpu ${ngpu} \
--batchsize 0 \
--checkpoint_path ${ckpt_prefix} \
--dict-path ${dict} \
--recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \
--result-label ${decode_dir}/data.JOB.json \
--model-conf ${config_path} \
--model ${ckpt_prefix}.pdparams
#--rnnlm ${lmexpdir}/${lang_model} \
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict}
) &
pids+=($!) # store background pids
i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." || true
done
)
done
echo "Finished"
exit 0
......@@ -6,7 +6,7 @@ expdir=exp
datadir=data
nj=32
lmtag=
lmtag='nolm'
recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean"
......@@ -17,23 +17,31 @@ bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpemodel=${bpeprefix}.model
if [ $# != 3 ];then
echo "usage: ${0} config_path dict_path ckpt_path_prefix"
exit -1
config_path=conf/transformer.yaml
dict=data/bpe_unigram_5000_units.txt
ckpt_prefix=
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
if [ -z ${ckpt_prefix} ]; then
echo "usage: $0 --ckpt_prefix ckpt_prefix"
exit 1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
dict=$2
ckpt_prefix=$3
ckpt_dir=$(dirname `dirname ${ckpt_prefix}`)
echo "ckpt dir: ${ckpt_dir}"
ckpt_tag=$(basename ${ckpt_prefix})
echo "ckpt tag: ${ckpt_tag}"
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
echo "chunk mode ${chunk_mode}"
echo "chunk mode: ${chunk_mode}"
# download language model
......@@ -46,13 +54,13 @@ pids=() # initialize pids
for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
(
echo "${dmethd} decoding"
echo "decode method: ${dmethd}"
for rtask in ${recog_set}; do
(
echo "${rtask} dataset"
decode_dir=decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}
echo "dataset: ${rtask}"
decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag}
feat_recog_dir=${datadir}
mkdir -p ${expdir}/${decode_dir}
mkdir -p ${decode_dir}
mkdir -p ${feat_recog_dir}
# split data
......@@ -63,7 +71,7 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco
# set batchsize 0 to disable batch decoding
batch_size=1
${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
${decode_cmd} JOB=1:${nj} ${decode_dir}/log/decode.JOB.log \
python3 -u ${BIN_DIR}/test.py \
--model-name u2_kaldi \
--run-mode test \
......@@ -71,12 +79,12 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco
--dict-path ${dict} \
--config ${config_path} \
--checkpoint_path ${ckpt_prefix} \
--result-file ${expdir}/${decode_dir}/data.JOB.json \
--result-file ${decode_dir}/data.JOB.json \
--opts decoding.decoding_method ${dmethd} \
--opts decoding.batch_size ${batch_size} \
--opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask}
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${expdir}/${decode_dir} ${dict}
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict}
) &
pids+=($!) # store background pids
......
#!/bin/bash
set -e
. ./path.sh || exit 1;
......@@ -7,8 +8,9 @@ set -e
stage=0
stop_stage=100
conf_path=conf/transformer.yaml
dict_path=data/train_960_unigram5000_units.txt
dict_path=data/bpe_unigram_5000_units.txt
avg_num=10
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
......
......@@ -11,20 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import io
import os
import re
import subprocess as sp
import sys
from pathlib import Path
import contextlib
import inspect
from setuptools import Command
from setuptools import find_packages
from setuptools import setup
from setuptools import Command
from setuptools.command.develop import develop
from setuptools.command.install import install
import subprocess as sp
HERE = Path(os.path.abspath(os.path.dirname(__file__)))
......@@ -40,16 +40,18 @@ def pushd(new_dir):
def read(*names, **kwargs):
with io.open(os.path.join(os.path.dirname(__file__), *names),
encoding=kwargs.get("encoding", "utf8")) as fp:
with io.open(
os.path.join(os.path.dirname(__file__), *names),
encoding=kwargs.get("encoding", "utf8")) as fp:
return fp.read()
def check_call(cmd: str, shell=False, executable=None):
try:
sp.check_call(cmd.split(),
shell=shell,
executable="/bin/bash" if shell else executable)
sp.check_call(
cmd.split(),
shell=shell,
executable="/bin/bash" if shell else executable)
except sp.CalledProcessError as e:
print(
f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:",
......@@ -189,7 +191,6 @@ setup_info = dict(
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
)
], )
setup(**setup_info)
......@@ -25,8 +25,8 @@ def main(args):
paddle.set_device('cpu')
val_scores = []
beat_val_scores = []
selected_epochs = []
beat_val_scores = None
selected_epochs = None
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
......@@ -80,9 +80,10 @@ def main(args):
data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model,
"ckpt": path_list,
"epoch": selected_epochs.tolist(),
"val_loss": beat_val_scores.tolist(),
"val_loss_mean": np.mean(beat_val_scores),
"ckpts": path_list,
"epochs": selected_epochs.tolist(),
"val_losses": beat_val_scores.tolist(),
})
f.write(data + "\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册