未验证 提交 c67bf7b4 编写于 作者: Z Zth9730 提交者: GitHub

[ASR] support wav2vec2-zh cli, test=asr (#2697)

* support wav2vec2-zh cli, test=asr

* support wav2vec2-zh cli, test=asr

* support wav2vec2-zh cli, test=asr

* support wav2vec2-zh cli, test=asr
上级 a01c163d
...@@ -25,6 +25,7 @@ import librosa ...@@ -25,6 +25,7 @@ import librosa
import numpy as np import numpy as np
import paddle import paddle
import soundfile import soundfile
from paddlenlp.transformers import AutoTokenizer
from yacs.config import CfgNode from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
...@@ -50,7 +51,7 @@ class SSLExecutor(BaseExecutor): ...@@ -50,7 +51,7 @@ class SSLExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
default='wav2vec2ASR_librispeech', default=None,
choices=[ choices=[
tag[:tag.index('-')] tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys() for tag in self.task_resource.pretrained_models.keys()
...@@ -123,7 +124,7 @@ class SSLExecutor(BaseExecutor): ...@@ -123,7 +124,7 @@ class SSLExecutor(BaseExecutor):
help='Increase logger verbosity of current task.') help='Increase logger verbosity of current task.')
def _init_from_path(self, def _init_from_path(self,
model_type: str='wav2vec2ASR_librispeech', model_type: str=None,
task: str='asr', task: str='asr',
lang: str='en', lang: str='en',
sample_rate: int=16000, sample_rate: int=16000,
...@@ -134,6 +135,18 @@ class SSLExecutor(BaseExecutor): ...@@ -134,6 +135,18 @@ class SSLExecutor(BaseExecutor):
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
logger.debug("start to init the model") logger.debug("start to init the model")
if model_type is None:
if lang == 'en':
model_type = 'wav2vec2ASR_librispeech'
elif lang == 'zh':
model_type = 'wav2vec2ASR_aishell1'
else:
logger.error(
"invalid lang, please input --lang en or --lang zh")
logger.debug(
"Model type had not been specified, default {} was used.".
format(model_type))
# default max_len: unit:second # default max_len: unit:second
self.max_len = 50 self.max_len = 50
if hasattr(self, 'model'): if hasattr(self, 'model'):
...@@ -167,9 +180,13 @@ class SSLExecutor(BaseExecutor): ...@@ -167,9 +180,13 @@ class SSLExecutor(BaseExecutor):
self.config.merge_from_file(self.cfg_path) self.config.merge_from_file(self.cfg_path)
if task == 'asr': if task == 'asr':
with UpdateConfig(self.config): with UpdateConfig(self.config):
self.text_feature = TextFeaturizer( if lang == 'en':
unit_type=self.config.unit_type, self.text_feature = TextFeaturizer(
vocab=self.config.vocab_filepath) unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath)
elif lang == 'zh':
self.text_feature = AutoTokenizer.from_pretrained(
self.config.tokenizer)
self.config.decode.decoding_method = decode_method self.config.decode.decoding_method = decode_method
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset} '_')] # model_type: {model_name}_{dataset}
...@@ -253,7 +270,8 @@ class SSLExecutor(BaseExecutor): ...@@ -253,7 +270,8 @@ class SSLExecutor(BaseExecutor):
audio, audio,
text_feature=self.text_feature, text_feature=self.text_feature,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size) beam_size=cfg.beam_size,
tokenizer=getattr(self.config, 'tokenizer', None))
self._outputs["result"] = result_transcripts[0][0] self._outputs["result"] = result_transcripts[0][0]
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
...@@ -413,7 +431,7 @@ class SSLExecutor(BaseExecutor): ...@@ -413,7 +431,7 @@ class SSLExecutor(BaseExecutor):
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
audio_file: os.PathLike, audio_file: os.PathLike,
model: str='wav2vec2ASR_librispeech', model: str=None,
task: str='asr', task: str='asr',
lang: str='en', lang: str='en',
sample_rate: int=16000, sample_rate: int=16000,
......
...@@ -70,6 +70,38 @@ ssl_dynamic_pretrained_models = { ...@@ -70,6 +70,38 @@ ssl_dynamic_pretrained_models = {
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams', 'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
}, },
}, },
"wav2vec2-zh-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2-large-wenetspeech-self_ckpt_1.3.0.model.tar.gz',
'md5':
'00ea4975c05d1bb58181205674052fe1',
'cfg_path':
'model.yaml',
'ckpt_path':
'chinese-wav2vec2-large',
'model':
'chinese-wav2vec2-large.pdparams',
'params':
'chinese-wav2vec2-large.pdparams',
},
},
"wav2vec2ASR_aishell1-zh-16k": {
'1.3': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2ASR-large-aishell1_ckpt_1.3.0.model.tar.gz',
'md5':
'ac8fa0a6345e6a7535f6fabb5e59e218',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/wav2vec2ASR/checkpoints/avg_1',
'model':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
'params':
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
} }
# --------------------------------- # ---------------------------------
......
...@@ -1173,10 +1173,6 @@ class Wav2Vec2ConfigPure(): ...@@ -1173,10 +1173,6 @@ class Wav2Vec2ConfigPure():
self.proj_codevector_dim = config.proj_codevector_dim self.proj_codevector_dim = config.proj_codevector_dim
self.diversity_loss_weight = config.diversity_loss_weight self.diversity_loss_weight = config.diversity_loss_weight
# ctc loss
self.ctc_loss_reduction = config.ctc_loss_reduction
self.ctc_zero_infinity = config.ctc_zero_infinity
# adapter # adapter
self.add_adapter = config.add_adapter self.add_adapter = config.add_adapter
self.adapter_kernel_size = config.adapter_kernel_size self.adapter_kernel_size = config.adapter_kernel_size
......
...@@ -76,28 +76,66 @@ class Wav2vec2ASR(nn.Layer): ...@@ -76,28 +76,66 @@ class Wav2vec2ASR(nn.Layer):
feats: paddle.Tensor, feats: paddle.Tensor,
text_feature: Dict[str, int], text_feature: Dict[str, int],
decoding_method: str, decoding_method: str,
beam_size: int): beam_size: int,
tokenizer: str=None):
batch_size = feats.shape[0] batch_size = feats.shape[0]
if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1: if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
logger.error( raise ValueError(
f'decoding mode {decoding_method} must be running with batch_size == 1' f"decoding mode {decoding_method} must be running with batch_size == 1"
) )
logger.error(f"current batch_size is {batch_size}")
sys.exit(1)
if decoding_method == 'ctc_greedy_search': if decoding_method == 'ctc_greedy_search':
hyps = self.ctc_greedy_search(feats) if tokenizer is None:
res = [text_feature.defeaturize(hyp) for hyp in hyps] hyps = self.ctc_greedy_search(feats)
res_tokenids = [hyp for hyp in hyps] res = [text_feature.defeaturize(hyp) for hyp in hyps]
res_tokenids = [hyp for hyp in hyps]
else:
hyps = self.ctc_greedy_search(feats)
res = []
res_tokenids = []
for sequence in hyps:
# Decode token terms to words
predicted_tokens = text_feature.convert_ids_to_tokens(
sequence)
tmp_res = []
tmp_res_tokenids = []
for c in predicted_tokens:
if c == "[CLS]":
continue
elif c == "[SEP]" or c == "[PAD]":
break
else:
tmp_res.append(c)
tmp_res_tokenids.append(text_feature.vocab[c])
res.append(''.join(tmp_res))
res_tokenids.append(tmp_res_tokenids)
# ctc_prefix_beam_search and attention_rescoring only return one # ctc_prefix_beam_search and attention_rescoring only return one
# result in List[int], change it to List[List[int]] for compatible # result in List[int], change it to List[List[int]] for compatible
# with other batch decoding mode # with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search': elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1 assert feats.shape[0] == 1
hyp = self.ctc_prefix_beam_search(feats, beam_size) if tokenizer is None:
res = [text_feature.defeaturize(hyp)] hyp = self.ctc_prefix_beam_search(feats, beam_size)
res_tokenids = [hyp] res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp]
else:
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = []
res_tokenids = []
predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
tmp_res = []
tmp_res_tokenids = []
for c in predicted_tokens:
if c == "[CLS]":
continue
elif c == "[SEP]" or c == "[PAD]":
break
else:
tmp_res.append(c)
tmp_res_tokenids.append(text_feature.vocab[c])
res.append(''.join(tmp_res))
res_tokenids.append(tmp_res_tokenids)
else: else:
raise ValueError( raise ValueError(
f"wav2vec2 not support decoding method: {decoding_method}") f"wav2vec2 not support decoding method: {decoding_method}")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册