未验证 提交 c67bf7b4 编写于 作者: Z Zth9730 提交者: GitHub

[ASR] support wav2vec2-zh cli, test=asr (#2697)

* support wav2vec2-zh cli, test=asr

* support wav2vec2-zh cli, test=asr

* support wav2vec2-zh cli, test=asr

* support wav2vec2-zh cli, test=asr
上级 a01c163d
......@@ -25,6 +25,7 @@ import librosa
import numpy as np
import paddle
import soundfile
from paddlenlp.transformers import AutoTokenizer
from yacs.config import CfgNode
from ..executor import BaseExecutor
......@@ -50,7 +51,7 @@ class SSLExecutor(BaseExecutor):
self.parser.add_argument(
'--model',
type=str,
default='wav2vec2ASR_librispeech',
default=None,
choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
......@@ -123,7 +124,7 @@ class SSLExecutor(BaseExecutor):
help='Increase logger verbosity of current task.')
def _init_from_path(self,
model_type: str='wav2vec2ASR_librispeech',
model_type: str=None,
task: str='asr',
lang: str='en',
sample_rate: int=16000,
......@@ -134,6 +135,18 @@ class SSLExecutor(BaseExecutor):
Init model and other resources from a specific path.
"""
logger.debug("start to init the model")
if model_type is None:
if lang == 'en':
model_type = 'wav2vec2ASR_librispeech'
elif lang == 'zh':
model_type = 'wav2vec2ASR_aishell1'
else:
logger.error(
"invalid lang, please input --lang en or --lang zh")
logger.debug(
"Model type had not been specified, default {} was used.".
format(model_type))
# default max_len: unit:second
self.max_len = 50
if hasattr(self, 'model'):
......@@ -167,9 +180,13 @@ class SSLExecutor(BaseExecutor):
self.config.merge_from_file(self.cfg_path)
if task == 'asr':
with UpdateConfig(self.config):
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath)
if lang == 'en':
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath)
elif lang == 'zh':
self.text_feature = AutoTokenizer.from_pretrained(
self.config.tokenizer)
self.config.decode.decoding_method = decode_method
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
......@@ -253,7 +270,8 @@ class SSLExecutor(BaseExecutor):
audio,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size)
beam_size=cfg.beam_size,
tokenizer=getattr(self.config, 'tokenizer', None))
self._outputs["result"] = result_transcripts[0][0]
except Exception as e:
logger.exception(e)
......@@ -413,7 +431,7 @@ class SSLExecutor(BaseExecutor):
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,
model: str='wav2vec2ASR_librispeech',
model: str=None,
task: str='asr',
lang: str='en',
sample_rate: int=16000,
......
......@@ -70,6 +70,38 @@ ssl_dynamic_pretrained_models = {
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
"wav2vec2-zh-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2-large-wenetspeech-self_ckpt_1.3.0.model.tar.gz',
'md5':
'00ea4975c05d1bb58181205674052fe1',
'cfg_path':
'model.yaml',
'ckpt_path':
'chinese-wav2vec2-large',
'model':
'chinese-wav2vec2-large.pdparams',
'params':
'chinese-wav2vec2-large.pdparams',
},
},
"wav2vec2ASR_aishell1-zh-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2ASR-large-aishell1_ckpt_1.3.0.model.tar.gz',
'md5':
'ac8fa0a6345e6a7535f6fabb5e59e218',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/wav2vec2ASR/checkpoints/avg_1',
'model':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
'params':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
}
# ---------------------------------
......
......@@ -1173,10 +1173,6 @@ class Wav2Vec2ConfigPure():
self.proj_codevector_dim = config.proj_codevector_dim
self.diversity_loss_weight = config.diversity_loss_weight
# ctc loss
self.ctc_loss_reduction = config.ctc_loss_reduction
self.ctc_zero_infinity = config.ctc_zero_infinity
# adapter
self.add_adapter = config.add_adapter
self.adapter_kernel_size = config.adapter_kernel_size
......
......@@ -76,28 +76,66 @@ class Wav2vec2ASR(nn.Layer):
feats: paddle.Tensor,
text_feature: Dict[str, int],
decoding_method: str,
beam_size: int):
beam_size: int,
tokenizer: str=None):
batch_size = feats.shape[0]
if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
logger.error(
f'decoding mode {decoding_method} must be running with batch_size == 1'
raise ValueError(
f"decoding mode {decoding_method} must be running with batch_size == 1"
)
logger.error(f"current batch_size is {batch_size}")
sys.exit(1)
if decoding_method == 'ctc_greedy_search':
hyps = self.ctc_greedy_search(feats)
res = [text_feature.defeaturize(hyp) for hyp in hyps]
res_tokenids = [hyp for hyp in hyps]
if tokenizer is None:
hyps = self.ctc_greedy_search(feats)
res = [text_feature.defeaturize(hyp) for hyp in hyps]
res_tokenids = [hyp for hyp in hyps]
else:
hyps = self.ctc_greedy_search(feats)
res = []
res_tokenids = []
for sequence in hyps:
# Decode token terms to words
predicted_tokens = text_feature.convert_ids_to_tokens(
sequence)
tmp_res = []
tmp_res_tokenids = []
for c in predicted_tokens:
if c == "[CLS]":
continue
elif c == "[SEP]" or c == "[PAD]":
break
else:
tmp_res.append(c)
tmp_res_tokenids.append(text_feature.vocab[c])
res.append(''.join(tmp_res))
res_tokenids.append(tmp_res_tokenids)
# ctc_prefix_beam_search and attention_rescoring only return one
# result in List[int], change it to List[List[int]] for compatible
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp]
if tokenizer is None:
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp]
else:
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = []
res_tokenids = []
predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
tmp_res = []
tmp_res_tokenids = []
for c in predicted_tokens:
if c == "[CLS]":
continue
elif c == "[SEP]" or c == "[PAD]":
break
else:
tmp_res.append(c)
tmp_res_tokenids.append(text_feature.vocab[c])
res.append(''.join(tmp_res))
res_tokenids.append(tmp_res_tokenids)
else:
raise ValueError(
f"wav2vec2 not support decoding method: {decoding_method}")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册