提交 2368d3b6 编写于 作者: P pfZhu

init ernie-sat

上级 e25b1cfa
文件已添加
## 使用说明
### 1.安装飞桨
我们的代码基于 Paddle(version>=2.0)
### 2.预训练模型
预训练模型ERNIE-SAT的模型如下所示(链接暂无):
- [ERNIE-SAT_ZH](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-zh.tar.gz)
- [ERNIE-SAT_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en.tar.gz)
- [ERNIE-SAT_ZH_and_EN](http://bj.bcebos.com/wenxin-models/model-ernie-sat-base-en_zh.tar.gz)
### 3.下载
1. 我们使用parallel wavegan作为声码器(vocoder):
- [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)
创建download文件夹,下载上述预训练的声码器(vocoder)模型并将其解压
```bash
mkdir download
cd download
unzip pwg_aishell3_ckpt_0.5.zip
```
2. 我们使用[FastSpeech2](https://arxiv.org/abs/2006.04558) 作为音素(phoneme)的持续时间预测器:
- [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip) 中文场景下使用
- [fastspeech2_nosil_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip) 英文场景下使用
下载上述预训练的fastspeech2模型并将其解压
```bash
cd download
unzip fastspeech2_conformer_baker_ckpt_0.5.zip
unzip fastspeech2_nosil_ljspeech_ckpt_0.5.zip
```
### 4.推理
我们目前只开源了语音编辑、个性化语音合成、跨语言语音合成的推理代码,后续会逐步开源。
注:当前采用的声码器版本与模型训练时版本(https://github.com/kan-bayashi/ParallelWaveGAN)在英文上存在差异,您可使用模型训练时版本作为您的声码器,模型将在后续更新中升级。
我们提供特定音频文件, 以及其对应的文本、音素相关文件:
- prompt_wav: 提供的音频文件
- prompt/dev: 基于上述特定音频对应的文本、音素相关文件
```text
prompt_wav
├── p299_096.wav # 样例语音文件1
├── SSB03540428.wav # 样例语音文件2
└── ...
```
```text
prompt/dev
├── text # 样例语音对应文本
├── wav.scp # 样例语音路径
├── mfa_text # 样例语音对应音素
├── mfa_start # 样例语音中各个音素的开始时间
└── mfa_end # 样例语音中各个音素的结束时间
```
1. `--am` 声学模型格式符合 {model_name}_{dataset}
2. `--am_config`, `--am_checkpoint`, `--am_stat``--phones_dict` 是声学模型的参数,对应于 fastspeech2 预训练模型中的 4 个文件。
3. `--voc` 声码器(vocoder)格式是否符合 {model_name}_{dataset}
4. `--voc_config`, `--voc_checkpoint`, `--voc_stat` 是声码器的参数,对应于 parallel wavegan 预训练模型中的 3 个文件。
5. `--lang` 对应模型的语言可以是 `zh``en`
6. `--ngpu` 要使用的GPU数,如果 ngpu==0,则使用 cpu。
7. ` --model_name` 模型名称
8. ` --uid` 特定提示(prompt)语音的id
9. ` --new_str` 输入的文本(本次开源暂时先设置特定的文本)
10. ` --prefix` 特定音频对应的文本、音素相关文件的地址
11. ` --source_language` , 源语言
12. ` --target_language` , 目标语言
13. ` --output_name` , 合成语音名称
14. ` --task_name` , 任务名称, 包括:语音编辑任务、个性化语音合成任务、跨语言语音合成任务
运行以下脚本即可进行实验
```shell
sh run_sedit_en.sh # 语音编辑任务(英文)
sh run_gen_en.sh # 个性化语音合成任务(英文)
sh run_clone_en_to_zh.sh # 跨语言语音合成任务(英文到中文的克隆)
```
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import _locale
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
# PaddleSpeech Command Line
([简体中文](./README_cn.md)|English)
The simplest approach to use PaddleSpeech models.
## Help
```bash
paddlespeech help
```
## Audio Classification
```bash
paddlespeech cls --input input.wav
```
## Speaker Verification
```bash
paddlespeech vector --task spk --input input_16k.wav
```
## Automatic Speech Recognition
```
paddlespeech asr --lang zh --input input_16k.wav
```
## Speech Translation (English to Chinese)
(not support for Windows now)
```bash
paddlespeech st --input input_16k.wav
```
## Text-to-Speech
```bash
paddlespeech tts --input "你好,欢迎使用百度飞桨深度学习框架!" --output output.wav
```
## Text Post-precessing
- Punctuation Restoration
```bash
paddlespeech text --task punc --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭
```
# PaddleSpeech 命令行工具
(简体中文|[English](./README.md))
`paddlespeech.cli` 模块是 PaddleSpeech 的命令行工具,它提供了最简便的方式调用 PaddleSpeech 提供的不同语音应用场景的预训练模型,用一行命令就可以进行模型预测:
## 命令行使用帮助
```bash
paddlespeech help
```
## 声音分类
```bash
paddlespeech cls --input input.wav
```
## 声纹识别
```bash
paddlespeech vector --task spk --input input_16k.wav
```
## 语音识别
```
paddlespeech asr --lang zh --input input_16k.wav
```
## 语音翻译(英-中)
(暂不支持Windows系统)
```bash
paddlespeech st --input input_16k.wav
```
## 语音合成
```bash
paddlespeech tts --input "你好,欢迎使用百度飞桨深度学习框架!" --output output.wav
```
## 文本后处理
- 标点恢复
```bash
paddlespeech text --task punc --input 今天的天气真不错啊你下午有空吗我想约你一起去吃饭
```
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import _locale
from .asr import ASRExecutor
from .base_commands import BaseCommand
from .base_commands import HelpCommand
from .cls import CLSExecutor
from .st import STExecutor
from .stats import StatsExecutor
from .text import TextExecutor
from .tts import TTSExecutor
from .vector import VectorExecutor
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import ASRExecutor
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from .entry import commands
from .utils import cli_register
from .utils import get_command
__all__ = [
'BaseCommand',
'HelpCommand',
]
@cli_register(name='paddlespeech')
class BaseCommand:
def execute(self, argv: List[str]) -> bool:
help = get_command('paddlespeech.help')
return help().execute(argv)
@cli_register(name='paddlespeech.help', description='Show help for commands.')
class HelpCommand:
def execute(self, argv: List[str]) -> bool:
msg = 'Usage:\n'
msg += ' paddlespeech <command> <options>\n\n'
msg += 'Commands:\n'
for command, detail in commands['paddlespeech'].items():
if command.startswith('_'):
continue
if '_description' not in detail:
continue
msg += ' {:<15} {}\n'.format(command,
detail['_description'])
print(msg)
return True
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import CLSExecutor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import paddle
import yaml
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__ = ['CLSExecutor']
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"panns_cnn6-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
'md5': '4cf09194a95df024fd12f84712cf0f9c',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn6.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn10-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
'md5': 'cb8427b22176cc2116367d14847f5413',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn10.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn14-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
'md5': 'e3b9b5614a1595001161d0ab95edee97',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn14.pdparams',
'label_file': 'audioset_labels.txt',
},
}
model_alias = {
"panns_cnn6": "paddlespeech.cls.models.panns:CNN6",
"panns_cnn10": "paddlespeech.cls.models.panns:CNN10",
"panns_cnn14": "paddlespeech.cls.models.panns:CNN14",
}
@cli_register(
name='paddlespeech.cls', description='Audio classification infer command.')
class CLSExecutor(BaseExecutor):
def __init__(self):
super(CLSExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech.cls', add_help=True)
self.parser.add_argument(
'--input', type=str, default=None, help='Audio file to classify.')
self.parser.add_argument(
'--model',
type=str,
default='panns_cnn14',
choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()],
help='Choose model type of cls task.')
self.parser.add_argument(
'--config',
type=str,
default=None,
help='Config of cls task. Use deault config when it is None.')
self.parser.add_argument(
'--ckpt_path',
type=str,
default=None,
help='Checkpoint file of model.')
self.parser.add_argument(
'--label_file',
type=str,
default=None,
help='Label file of cls task.')
self.parser.add_argument(
'--topk',
type=int,
default=1,
help='Return topk scores of classification result.')
self.parser.add_argument(
'--device',
type=str,
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'-d',
'--job_dump_result',
action='store_true',
help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(self,
model_type: str='panns_cnn14',
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None,
label_file: Optional[os.PathLike]=None):
"""
Init model and other resources from a specific path.
"""
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
if label_file is None or ckpt_path is None:
tag = model_type + '-' + '32k' # panns_cnn14-32k
self.res_path = self._get_pretrained_path(tag)
self.cfg_path = os.path.join(self.res_path,
pretrained_models[tag]['cfg_path'])
self.label_file = os.path.join(self.res_path,
pretrained_models[tag]['label_file'])
self.ckpt_path = os.path.join(self.res_path,
pretrained_models[tag]['ckpt_path'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.label_file = os.path.abspath(label_file)
self.ckpt_path = os.path.abspath(ckpt_path)
# config
with open(self.cfg_path, 'r') as f:
self._conf = yaml.safe_load(f)
# labels
self._label_list = []
with open(self.label_file, 'r') as f:
for line in f:
self._label_list.append(line.strip())
# model
model_class = dynamic_import(model_type, model_alias)
model_dict = paddle.load(self.ckpt_path)
self.model = model_class(extract_embedding=False)
self.model.set_state_dict(model_dict)
self.model.eval()
def preprocess(self, audio_file: Union[str, os.PathLike]):
"""
Input preprocess and return paddle.Tensor stored in self.input.
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
"""
feat_conf = self._conf['feature']
logger.info(feat_conf)
waveform, _ = load(
file=audio_file,
sr=feat_conf['sample_rate'],
mono=True,
dtype='float32')
if isinstance(audio_file, (str, os.PathLike)):
logger.info("Preprocessing audio_file:" + audio_file)
# Feature extraction
feature_extractor = LogMelSpectrogram(
sr=feat_conf['sample_rate'],
n_fft=feat_conf['n_fft'],
hop_length=feat_conf['hop_length'],
window=feat_conf['window'],
win_length=feat_conf['window_length'],
f_min=feat_conf['f_min'],
f_max=feat_conf['f_max'],
n_mels=feat_conf['n_mels'], )
feats = feature_extractor(
paddle.to_tensor(paddle.to_tensor(waveform).unsqueeze(0)))
self._inputs['feats'] = paddle.transpose(feats, [0, 2, 1]).unsqueeze(
1) # [B, N, T] -> [B, 1, T, N]
@paddle.no_grad()
def infer(self):
"""
Model inference and result stored in self.output.
"""
self._outputs['logits'] = self.model(self._inputs['feats'])
def _generate_topk_label(self, result: np.ndarray, topk: int) -> str:
assert topk <= len(
self._label_list), 'Value of topk is larger than number of labels.'
topk_idx = (-result).argsort()[:topk]
ret = ''
for idx in topk_idx:
label, score = self._label_list[idx], result[idx]
ret += f'{label} {score} '
return ret
def postprocess(self, topk: int) -> Union[str, os.PathLike]:
"""
Output postprocess and return human-readable results such as texts and audio files.
"""
return self._generate_topk_label(
result=self._outputs['logits'].squeeze(0).numpy(), topk=topk)
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
parser_args = self.parser.parse_args(argv)
model_type = parser_args.model
label_file = parser_args.label_file
cfg_path = parser_args.config
ckpt_path = parser_args.ckpt_path
topk = parser_args.topk
device = parser_args.device
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(input_, model_type, cfg_path, ckpt_path, label_file,
topk, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
parser_args.job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,
model: str='panns_cnn14',
config: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None,
label_file: Optional[os.PathLike]=None,
topk: int=1,
device: str=paddle.get_device()):
"""
Python API to call an executor.
"""
audio_file = os.path.abspath(os.path.expanduser(audio_file))
paddle.set_device(device)
self._init_from_path(model, config, ckpt_path, label_file)
self.preprocess(audio_file)
self.infer()
res = self.postprocess(topk) # Retrieve result of cls.
return res
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import os
import os.path as osp
import shutil
import subprocess
import tarfile
import time
import zipfile
import requests
from tqdm import tqdm
from .log import logger
__all__ = ['get_path_from_url']
DOWNLOAD_RETRY_LIMIT = 3
def _is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def _get_unique_endpoints(trainer_endpoints):
# Sorting is to avoid different environmental variables for each card
trainer_endpoints.sort()
ips = set()
unique_endpoints = set()
for endpoint in trainer_endpoints:
ip = endpoint.split(":")[0]
if ip in ips:
continue
ips.add(ip)
unique_endpoints.add(endpoint)
logger.info("unique_endpoints {}".format(unique_endpoints))
return unique_endpoints
def get_path_from_url(url,
root_dir,
md5sum=None,
check_exist=True,
decompress=True,
method='get'):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
from paddle.fluid.dygraph.parallel import ParallelEnv
assert _is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)
# Mainly used to solve the problem of downloading data from different
# machines in the case of multiple machines. Different ips will download
# data, and the same ip will only download data once.
unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum, method=method)
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().current_endpoint in unique_endpoints:
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
def _get_download(url, fullname):
# using requests.get method
fname = osp.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _wget_download(url, fullname):
# using wget to download url
tmp_fullname = fullname + "_tmp"
# –user-agent
command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT,
url)
subprc = subprocess.Popen(
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
_ = subprc.communicate()
if subprc.returncode != 0:
raise RuntimeError(
'{} failed. Please make sure `wget` is installed or {} exists'.
format(command, url))
shutil.move(tmp_fullname, fullname)
return fullname
_download_methods = {
'get': _get_download,
'wget': _wget_download,
}
def _download(url, path, md5sum=None, method='get'):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
md5sum (str): md5 sum of download package
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
"""
assert method in _download_methods, 'make sure `{}` implemented'.format(
method)
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _download_methods[method](url, fullname):
time.sleep(1)
continue
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# successed, move decompress files to fpath and delete
# fpath_tmp and remove download compress file.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
def _uncompress_file_zip(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _uncompress_file_tar(filepath, mode="r:*"):
files = tarfile.open(filepath, mode)
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections import defaultdict
__all__ = ['commands']
def _CommandDict():
return defaultdict(_CommandDict)
def _execute():
com = commands
idx = 0
for _argv in (['paddlespeech'] + sys.argv[1:]):
if _argv not in com:
break
idx += 1
com = com[_argv]
# The method 'execute' of a command instance returns 'True' for a success
# while 'False' for a failure. Here converts this result into a exit status
# in bash: 0 for a success and 1 for a failure.
status = 0 if com['_entry']().execute(sys.argv[idx:]) else 1
return status
commands = _CommandDict()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
from abc import ABC
from abc import abstractmethod
from collections import OrderedDict
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import paddle
from .log import logger
class BaseExecutor(ABC):
"""
An abstract executor of paddlespeech tasks.
"""
def __init__(self):
self._inputs = OrderedDict()
self._outputs = OrderedDict()
@abstractmethod
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
Args:
tag (str): A tag of pretrained model.
Returns:
os.PathLike: The path on which resources of pretrained model locate.
"""
pass
@abstractmethod
def _init_from_path(self, *args, **kwargs):
"""
Init model and other resources from arguments. This method should be called by `__call__()`.
"""
pass
@abstractmethod
def preprocess(self, input: Any, *args, **kwargs):
"""
Input preprocess and return paddle.Tensor stored in self._inputs.
Input content can be a text(tts), a file(asr, cls), a stream(not supported yet) or anything needed.
Args:
input (Any): Input text/file/stream or other content.
"""
pass
@paddle.no_grad()
@abstractmethod
def infer(self, *args, **kwargs):
"""
Model inference and put results into self._outputs.
This method get input tensors from self._inputs, and write output tensors into self._outputs.
"""
pass
@abstractmethod
def postprocess(self, *args, **kwargs) -> Union[str, os.PathLike]:
"""
Output postprocess and return results.
This method get model output from self._outputs and convert it into human-readable results.
Returns:
Union[str, os.PathLike]: Human-readable results such as texts and audio files.
"""
pass
@abstractmethod
def execute(self, argv: List[str]) -> bool:
"""
Command line entry. This method can only be accessed by a command line such as `paddlespeech asr`.
Args:
argv (List[str]): Arguments from command line.
Returns:
int: Result of the command execution. `True` for a success and `False` for a failure.
"""
pass
@abstractmethod
def __call__(self, *arg, **kwargs):
"""
Python API to call an executor.
"""
pass
def get_task_source(self, input_: Union[str, os.PathLike, None]
) -> Dict[str, Union[str, os.PathLike]]:
"""
Get task input source from command line input.
Args:
input_ (Union[str, os.PathLike, None]): Input from command line.
Returns:
Dict[str, Union[str, os.PathLike]]: A dict with ids and inputs.
"""
if self._is_job_input(input_):
ret = self._get_job_contents(input_)
else:
ret = OrderedDict()
if input_ is None: # Take input from stdin
for i, line in enumerate(sys.stdin):
line = line.strip()
if len(line.split(' ')) == 1:
ret[str(i + 1)] = line
elif len(line.split(' ')) == 2:
id_, info = line.split(' ')
ret[id_] = info
else: # No valid input info from one line.
continue
else:
ret[1] = input_
return ret
def process_task_results(self,
input_: Union[str, os.PathLike, None],
results: Dict[str, os.PathLike],
job_dump_result: bool=False):
"""
Handling task results and redirect stdout if needed.
Args:
input_ (Union[str, os.PathLike, None]): Input from command line.
results (Dict[str, os.PathLike]): Task outputs.
job_dump_result (bool, optional): if True, dumps job results into file. Defaults to False.
"""
if not self._is_job_input(input_) and len(
results) == 1: # Only one input sample
raw_text = list(results.values())[0]
else:
raw_text = self._format_task_results(results)
print(raw_text, end='') # Stdout
if self._is_job_input(
input_) and job_dump_result: # Dump to *.job.done
try:
job_output_file = os.path.abspath(input_) + '.done'
sys.stdout = open(job_output_file, 'w')
print(raw_text, end='')
logger.info(f'Results had been saved to: {job_output_file}')
finally:
sys.stdout.close()
def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool:
"""
Check if current input file is a job input or not.
Args:
input_ (Union[str, os.PathLike]): Input file of current task.
Returns:
bool: return `True` for job input, `False` otherwise.
"""
return input_ and os.path.isfile(input_) and (input_.endswith('.job') or
input_.endswith('.txt'))
def _get_job_contents(
self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]:
"""
Read a job input file and return its contents in a dictionary.
Args:
job_input (os.PathLike): The job input file.
Returns:
Dict[str, str]: Contents of job input.
"""
job_contents = OrderedDict()
with open(job_input) as f:
for line in f:
line = line.strip()
if not line:
continue
k, v = line.split(' ')
job_contents[k] = v
return job_contents
def _format_task_results(
self, results: Dict[str, Union[str, os.PathLike]]) -> str:
"""
Convert task results to raw text.
Args:
results (Dict[str, str]): A dictionary of task results.
Returns:
str: A string object contains task results.
"""
ret = ''
for k, v in results.items():
ret += f'{k} {v}\n'
return ret
def disable_task_loggers(self):
"""
Disable all loggers in current task.
"""
loggers = [
logging.getLogger(name) for name in logging.root.manager.loggerDict
]
for l in loggers:
l.disabled = True
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
__all__ = [
'logger',
]
class Logger(object):
def __init__(self, name: str=None):
name = 'PaddleSpeech' if not name else name
self.logger = logging.getLogger(name)
log_config = {
'DEBUG': 10,
'INFO': 20,
'TRAIN': 21,
'EVAL': 22,
'WARNING': 30,
'ERROR': 40,
'CRITICAL': 50,
'EXCEPTION': 100,
}
for key, level in log_config.items():
logging.addLevelName(level, key)
if key == 'EXCEPTION':
self.__dict__[key.lower()] = self.logger.exception
else:
self.__dict__[key.lower()] = functools.partial(self.__call__,
level)
self.format = logging.Formatter(
fmt='[%(asctime)-15s] [%(levelname)8s] - %(message)s')
self.handler = logging.StreamHandler()
self.handler.setFormatter(self.format)
self.logger.addHandler(self.handler)
self.logger.setLevel(logging.DEBUG)
self.logger.propagate = False
def __call__(self, log_level: str, msg: str):
self.logger.log(log_level, msg)
logger = Logger()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import STExecutor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import subprocess
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
import kaldiio
import numpy as np
import paddle
import soundfile
from kaldiio import WriteHelper
from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ["STExecutor"]
pretrained_models = {
"fat_st_ted-en-zh": {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
"md5":
"d62063f35a16d91210a71081bd2dd557",
"cfg_path":
"model.yaml",
"ckpt_path":
"exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
}
}
model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"}
kaldi_bins = {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
"md5":
"c0682303b3f3393dbf6ed4c4e35a53eb",
}
@cli_register(
name="paddlespeech.st", description="Speech translation infer command.")
class STExecutor(BaseExecutor):
def __init__(self):
super(STExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog="paddlespeech.st", add_help=True)
self.parser.add_argument(
"--input", type=str, default=None, help="Audio file to translate.")
self.parser.add_argument(
"--model",
type=str,
default="fat_st_ted",
choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()],
help="Choose model type of st task.")
self.parser.add_argument(
"--src_lang",
type=str,
default="en",
help="Choose model source language.")
self.parser.add_argument(
"--tgt_lang",
type=str,
default="zh",
help="Choose model target language.")
self.parser.add_argument(
"--sample_rate",
type=int,
default=16000,
choices=[16000],
help='Choose the audio sample rate of the model. 8000 or 16000')
self.parser.add_argument(
"--config",
type=str,
default=None,
help="Config of st task. Use deault config when it is None.")
self.parser.add_argument(
"--ckpt_path",
type=str,
default=None,
help="Checkpoint file of model.")
self.parser.add_argument(
"--device",
type=str,
default=paddle.get_device(),
help="Choose device to execute model inference.")
self.parser.add_argument(
'-d',
'--job_dump_result',
action='store_true',
help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
"Use pretrained model stored in: {}".format(decompressed_path))
return decompressed_path
def _set_kaldi_bins(self) -> os.PathLike:
"""
Download and returns kaldi_bins resources path of current task.
"""
decompressed_path = download_and_decompress(kaldi_bins, MODEL_HOME)
decompressed_path = os.path.abspath(decompressed_path)
logger.info("Kaldi_bins stored in: {}".format(decompressed_path))
if "LD_LIBRARY_PATH" in os.environ:
os.environ["LD_LIBRARY_PATH"] += f":{decompressed_path}"
else:
os.environ["LD_LIBRARY_PATH"] = f"{decompressed_path}"
os.environ["PATH"] += f":{decompressed_path}"
return decompressed_path
def _init_from_path(self,
model_type: str="fat_st_ted",
src_lang: str="en",
tgt_lang: str="zh",
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None):
"""
Init model and other resources from a specific path.
"""
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
if cfg_path is None or ckpt_path is None:
tag = model_type + "-" + src_lang + "-" + tgt_lang
res_path = self._get_pretrained_path(tag)
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]["cfg_path"])
self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]["ckpt_path"])
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
self.config.decode.decoding_method = "fullsentence"
with UpdateConfig(self.config):
self.config.cmvn_path = os.path.join(res_path,
self.config.cmvn_path)
self.config.spm_model_prefix = os.path.join(
res_path, self.config.spm_model_prefix)
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
model_conf = self.config
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, model_alias)
self.model = model_class.from_config(model_conf)
self.model.eval()
# load model
params_path = self.ckpt_path
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict)
# set kaldi bins
self._set_kaldi_bins()
def _check(self, audio_file: str, sample_rate: int):
_, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
if audio_sample_rate != sample_rate:
raise Exception("invalid sample rate")
sys.exit(-1)
def preprocess(self, wav_file: Union[str, os.PathLike], model_type: str):
"""
Input preprocess and return paddle.Tensor stored in self.input.
Input content can be a file(wav).
"""
audio_file = os.path.abspath(wav_file)
logger.info("Preprocess audio_file:" + audio_file)
if "fat_st" in model_type:
cmvn = self.config.cmvn_path
utt_name = "_tmp"
# Get the object for feature extraction
fbank_extract_command = [
"compute-fbank-feats", "--num-mel-bins=80", "--verbose=2",
"--sample-frequency=16000", "scp:-", "ark:-"
]
fbank_extract_process = subprocess.Popen(
fbank_extract_command,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
fbank_extract_process.stdin.write(
f"{utt_name} {wav_file}".encode("utf8"))
fbank_extract_process.stdin.close()
fbank_feat = dict(
kaldiio.load_ark(fbank_extract_process.stdout))[utt_name]
extract_command = ["compute-kaldi-pitch-feats", "scp:-", "ark:-"]
pitch_extract_process = subprocess.Popen(
extract_command,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
pitch_extract_process.stdin.write(
f"{utt_name} {wav_file}".encode("utf8"))
process_command = ["process-kaldi-pitch-feats", "ark:", "ark:-"]
pitch_process = subprocess.Popen(
process_command,
stdin=pitch_extract_process.stdout,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
pitch_extract_process.stdin.close()
pitch_feat = dict(kaldiio.load_ark(pitch_process.stdout))[utt_name]
concated_feat = np.concatenate((fbank_feat, pitch_feat), axis=1)
raw_feat = f"{utt_name}.raw"
with WriteHelper(
f"ark,scp:{raw_feat}.ark,{raw_feat}.scp") as writer:
writer(utt_name, concated_feat)
cmvn_command = [
"apply-cmvn", "--norm-vars=true", cmvn, f"scp:{raw_feat}.scp",
"ark:-"
]
cmvn_process = subprocess.Popen(
cmvn_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
process_command = [
"copy-feats", "--compress=true", "ark:-", "ark:-"
]
process = subprocess.Popen(
process_command,
stdin=cmvn_process.stdout,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
norm_feat = dict(kaldiio.load_ark(process.stdout))[utt_name]
self._inputs["audio"] = paddle.to_tensor(norm_feat).unsqueeze(0)
self._inputs["audio_len"] = paddle.to_tensor(
self._inputs["audio"].shape[1], dtype="int64")
else:
raise ValueError("Wrong model type.")
@paddle.no_grad()
def infer(self, model_type: str):
"""
Model inference and result stored in self.output.
"""
cfg = self.config.decode
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
if model_type == "fat_st_ted":
hyps = self.model.decode(
audio,
audio_len,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size,
word_reward=cfg.word_reward,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
self._outputs["result"] = hyps
else:
raise ValueError("Wrong model type.")
def postprocess(self, model_type: str) -> Union[str, os.PathLike]:
"""
Output postprocess and return human-readable results such as texts and audio files.
"""
if model_type == "fat_st_ted":
return self._outputs["result"]
else:
raise ValueError("Wrong model type.")
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
parser_args = self.parser.parse_args(argv)
model = parser_args.model
src_lang = parser_args.src_lang
tgt_lang = parser_args.tgt_lang
sample_rate = parser_args.sample_rate
config = parser_args.config
ckpt_path = parser_args.ckpt_path
device = parser_args.device
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(input_, model, src_lang, tgt_lang, sample_rate,
config, ckpt_path, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
parser_args.job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,
model: str='fat_st_ted',
src_lang: str='en',
tgt_lang: str='zh',
sample_rate: int=16000,
config: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None,
device: str=paddle.get_device()):
"""
Python API to call an executor.
"""
audio_file = os.path.abspath(audio_file)
self._check(audio_file, sample_rate)
paddle.set_device(device)
self._init_from_path(model, src_lang, tgt_lang, config, ckpt_path)
self.preprocess(audio_file, model)
self.infer(model)
res = self.postprocess(model)
return res
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import StatsExecutor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import List
from prettytable import PrettyTable
from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper
__all__ = ['StatsExecutor']
model_name_format = {
'asr': 'Model-Language-Sample Rate',
'cls': 'Model-Sample Rate',
'st': 'Model-Source language-Target language',
'text': 'Model-Task-Language',
'tts': 'Model-Language'
}
@cli_register(
name='paddlespeech.stats',
description='Get speech tasks support models list.')
class StatsExecutor():
def __init__(self):
super(StatsExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech.stats', add_help=True)
self.parser.add_argument(
'--task',
type=str,
default='asr',
choices=['asr', 'cls', 'st', 'text', 'tts'],
help='Choose speech task.',
required=True)
self.task_choices = ['asr', 'cls', 'st', 'text', 'tts']
def show_support_models(self, pretrained_models: dict):
fields = model_name_format[self.task].split("-")
table = PrettyTable(fields)
for key in pretrained_models:
table.add_row(key.split("-"))
print(table)
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
parser_args = self.parser.parse_args(argv)
self.task = parser_args.task
if self.task not in self.task_choices:
logger.error(
"Please input correct speech task, choices = ['asr', 'cls', 'st', 'text', 'tts']"
)
return False
elif self.task == 'asr':
try:
from ..asr.infer import pretrained_models
logger.info(
"Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error("Failed to get the list of ASR pretrained models.")
return False
elif self.task == 'cls':
try:
from ..cls.infer import pretrained_models
logger.info(
"Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error("Failed to get the list of CLS pretrained models.")
return False
elif self.task == 'st':
try:
from ..st.infer import pretrained_models
logger.info(
"Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error("Failed to get the list of ST pretrained models.")
return False
elif self.task == 'text':
try:
from ..text.infer import pretrained_models
logger.info(
"Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error(
"Failed to get the list of TEXT pretrained models.")
return False
elif self.task == 'tts':
try:
from ..tts.infer import pretrained_models
logger.info(
"Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error("Failed to get the list of TTS pretrained models.")
return False
@stats_wrapper
def __call__(
self,
task: str=None, ):
"""
Python API to call an executor.
"""
self.task = task
if self.task not in self.task_choices:
print(
"Please input correct speech task, choices = ['asr', 'cls', 'st', 'text', 'tts']"
)
elif self.task == 'asr':
try:
from ..asr.infer import pretrained_models
print(
"Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of ASR pretrained models.")
elif self.task == 'cls':
try:
from ..cls.infer import pretrained_models
print(
"Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of CLS pretrained models.")
elif self.task == 'st':
try:
from ..st.infer import pretrained_models
print(
"Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of ST pretrained models.")
elif self.task == 'text':
try:
from ..text.infer import pretrained_models
print(
"Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of TEXT pretrained models.")
elif self.task == 'tts':
try:
from ..tts.infer import pretrained_models
print(
"Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of TTS pretrained models.")
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import TextExecutor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union
import paddle
from ...s2t.utils.dynamic_import import dynamic_import
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
__all__ = ['TextExecutor']
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"ernie_linear_p7_wudao-punc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz',
'md5':
'12283e2ddde1797c5d1e57036b512746',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
"ernie_linear_p3_wudao-punc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz',
'md5':
'448eb2fdf85b6a997e7e652e80c51dd2',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
}
model_alias = {
"ernie_linear_p7": "paddlespeech.text.models:ErnieLinear",
"ernie_linear_p3": "paddlespeech.text.models:ErnieLinear",
}
tokenizer_alias = {
"ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer",
"ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer",
}
@cli_register(name='paddlespeech.text', description='Text infer command.')
class TextExecutor(BaseExecutor):
def __init__(self):
super(TextExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech.text', add_help=True)
self.parser.add_argument(
'--input', type=str, default=None, help='Input text.')
self.parser.add_argument(
'--task',
type=str,
default='punc',
choices=['punc'],
help='Choose text task.')
self.parser.add_argument(
'--model',
type=str,
default='ernie_linear_p7_wudao',
choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()],
help='Choose model type of text task.')
self.parser.add_argument(
'--lang',
type=str,
default='zh',
choices=['zh', 'en'],
help='Choose model language.')
self.parser.add_argument(
'--config',
type=str,
default=None,
help='Config of cls task. Use deault config when it is None.')
self.parser.add_argument(
'--ckpt_path',
type=str,
default=None,
help='Checkpoint file of model.')
self.parser.add_argument(
'--punc_vocab',
type=str,
default=None,
help='Vocabulary file of punctuation restoration task.')
self.parser.add_argument(
'--device',
type=str,
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'-d',
'--job_dump_result',
action='store_true',
help='Save job result into file.')
self.parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(self,
task: str='punc',
model_type: str='ernie_linear_p7_wudao',
lang: str='zh',
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None,
vocab_file: Optional[os.PathLike]=None):
"""
Init model and other resources from a specific path.
"""
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
self.task = task
if cfg_path is None or ckpt_path is None or vocab_file is None:
tag = '-'.join([model_type, task, lang])
self.res_path = self._get_pretrained_path(tag)
self.cfg_path = os.path.join(self.res_path,
pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(self.res_path,
pretrained_models[tag]['ckpt_path'])
self.vocab_file = os.path.join(self.res_path,
pretrained_models[tag]['vocab_file'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
self.vocab_file = os.path.abspath(vocab_file)
model_name = model_type[:model_type.rindex('_')]
if self.task == 'punc':
# punc list
self._punc_list = []
with open(self.vocab_file, 'r') as f:
for line in f:
self._punc_list.append(line.strip())
# model
model_class = dynamic_import(model_name, model_alias)
tokenizer_class = dynamic_import(model_name, tokenizer_alias)
self.model = model_class(
cfg_path=self.cfg_path, ckpt_path=self.ckpt_path)
self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
else:
raise NotImplementedError
self.model.eval()
def _clean_text(self, text):
text = text.lower()
text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
text = re.sub(f'[{"".join([p for p in self._punc_list][1:])}]', '',
text)
return text
def preprocess(self, text: Union[str, os.PathLike]):
"""
Input preprocess and return paddle.Tensor stored in self.input.
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
"""
if self.task == 'punc':
clean_text = self._clean_text(text)
assert len(clean_text) > 0, f'Invalid input string: {text}'
tokenized_input = self.tokenizer(
list(clean_text), return_length=True, is_split_into_words=True)
self._inputs['input_ids'] = tokenized_input['input_ids']
self._inputs['seg_ids'] = tokenized_input['token_type_ids']
self._inputs['seq_len'] = tokenized_input['seq_len']
else:
raise NotImplementedError
@paddle.no_grad()
def infer(self):
"""
Model inference and result stored in self.output.
"""
if self.task == 'punc':
input_ids = paddle.to_tensor(self._inputs['input_ids']).unsqueeze(0)
seg_ids = paddle.to_tensor(self._inputs['seg_ids']).unsqueeze(0)
logits, _ = self.model(input_ids, seg_ids)
preds = paddle.argmax(logits, axis=-1).squeeze(0)
self._outputs['preds'] = preds
else:
raise NotImplementedError
def postprocess(self) -> Union[str, os.PathLike]:
"""
Output postprocess and return human-readable results such as texts and audio files.
"""
if self.task == 'punc':
input_ids = self._inputs['input_ids']
seq_len = self._inputs['seq_len']
preds = self._outputs['preds']
tokens = self.tokenizer.convert_ids_to_tokens(
input_ids[1:seq_len - 1])
labels = preds[1:seq_len - 1].tolist()
assert len(tokens) == len(labels)
text = ''
for t, l in zip(tokens, labels):
text += t
if l != 0: # Non punc.
text += self._punc_list[l]
return text
else:
raise NotImplementedError
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
parser_args = self.parser.parse_args(argv)
task = parser_args.task
model_type = parser_args.model
lang = parser_args.lang
cfg_path = parser_args.config
ckpt_path = parser_args.ckpt_path
punc_vocab = parser_args.punc_vocab
device = parser_args.device
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
for id_, input_ in task_source.items():
try:
res = self(input_, task, model_type, lang, cfg_path, ckpt_path,
punc_vocab, device)
task_results[id_] = res
except Exception as e:
has_exceptions = True
task_results[id_] = f'{e.__class__.__name__}: {e}'
self.process_task_results(parser_args.input, task_results,
parser_args.job_dump_result)
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(
self,
text: str,
task: str='punc',
model: str='ernie_linear_p7_wudao',
lang: str='zh',
config: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None,
punc_vocab: Optional[os.PathLike]=None,
device: str=paddle.get_device(), ):
"""
Python API to call an executor.
"""
paddle.set_device(device)
self._init_from_path(task, model, lang, config, ckpt_path, punc_vocab)
self.preprocess(text)
self.infer()
res = self.postprocess() # Retrieve result of text task.
return res
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import TTSExecutor
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import inspect
import json
import os
import tarfile
import threading
import time
import uuid
import zipfile
from typing import Any
from typing import Dict
import paddle
import requests
import yaml
from paddle.framework import load
import paddleaudio
from . import download
from .entry import commands
try:
from .. import __version__
except ImportError:
__version__ = "0.0.0" # for develop branch
requests.adapters.DEFAULT_RETRIES = 3
__all__ = [
'cli_register',
'get_command',
'download_and_decompress',
'load_state_dict_from_url',
'stats_wrapper',
]
def cli_register(name: str, description: str='') -> Any:
def _warpper(command):
items = name.split('.')
com = commands
for item in items:
com = com[item]
com['_entry'] = command
if description:
com['_description'] = description
return command
return _warpper
def get_command(name: str) -> Any:
items = name.split('.')
com = commands
for item in items:
com = com[item]
return com['_entry']
def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
file_dir = os.path.dirname(filepath)
is_zip_file = False
if tarfile.is_tarfile(filepath):
files = tarfile.open(filepath, "r:*")
file_list = files.getnames()
elif zipfile.is_zipfile(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
is_zip_file = True
else:
return file_dir
if download._is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
elif download._is_a_single_dir(file_list):
if is_zip_file:
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
else:
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.close()
return uncompressed_path
def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
"""
Download archieves and decompress to specific path.
"""
if not os.path.isdir(path):
os.makedirs(path)
assert 'url' in archive and 'md5' in archive, \
'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys()))
filepath = os.path.join(path, os.path.basename(archive['url']))
if os.path.isfile(filepath) and download._md5check(filepath,
archive['md5']):
uncompress_path = _get_uncompress_path(filepath)
if not os.path.isdir(uncompress_path):
download._decompress(filepath)
else:
StatsWorker(
task='download',
version=__version__,
extra_info={
'download_url': archive['url'],
'paddle_version': paddle.__version__
}).start()
uncompress_path = download.get_path_from_url(archive['url'], path,
archive['md5'])
return uncompress_path
def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike:
"""
Download and load a state dict from url
"""
if not os.path.isdir(path):
os.makedirs(path)
download.get_path_from_url(url, path, md5)
return load(os.path.join(path, os.path.basename(url)))
def _get_user_home():
return os.path.expanduser('~')
def _get_paddlespcceh_home():
if 'PPSPEECH_HOME' in os.environ:
home_path = os.environ['PPSPEECH_HOME']
if os.path.exists(home_path):
if os.path.isdir(home_path):
return home_path
else:
raise RuntimeError(
'The environment variable PPSPEECH_HOME {} is not a directory.'.
format(home_path))
else:
return home_path
return os.path.join(_get_user_home(), '.paddlespeech')
def _get_sub_home(directory):
home = os.path.join(_get_paddlespcceh_home(), directory)
if not os.path.exists(home):
os.makedirs(home)
return home
PPSPEECH_HOME = _get_paddlespcceh_home()
MODEL_HOME = _get_sub_home('models')
CONF_HOME = _get_sub_home('conf')
def _md5(text: str):
'''Calculate the md5 value of the input text.'''
md5code = hashlib.md5(text.encode())
return md5code.hexdigest()
class ConfigCache:
def __init__(self):
self._data = {}
self._initialize()
self.file = os.path.join(CONF_HOME, 'cache.yaml')
if not os.path.exists(self.file):
self.flush()
return
with open(self.file, 'r') as file:
try:
cfg = yaml.load(file, Loader=yaml.FullLoader)
self._data.update(cfg)
except Exception as e:
self.flush()
@property
def cache_info(self):
return self._data['cache_info']
def _initialize(self):
# Set default configuration values.
cache_info = _md5(str(uuid.uuid1())[-12:]) + "-" + str(int(time.time()))
self._data['cache_info'] = cache_info
def flush(self):
'''Flush the current configuration into the configuration file.'''
with open(self.file, 'w') as file:
cfg = json.loads(json.dumps(self._data))
yaml.dump(cfg, file)
stats_api = "http://paddlepaddle.org.cn/paddlehub/stat"
cache_info = ConfigCache().cache_info
class StatsWorker(threading.Thread):
def __init__(self,
task="asr",
model=None,
version=__version__,
extra_info={}):
threading.Thread.__init__(self)
self._task = task
self._model = model
self._version = version
self._extra_info = extra_info
def run(self):
params = {
'task': self._task,
'version': self._version,
'from': 'ppspeech'
}
if self._model:
params['model'] = self._model
self._extra_info.update({
'cache_info': cache_info,
})
params.update({"extra": json.dumps(self._extra_info)})
try:
requests.get(stats_api, params)
except Exception:
pass
return
def _note_one_stat(cls_name, params={}):
task = cls_name.replace('Executor', '').lower() # XXExecutor
extra_info = {
'paddle_version': paddle.__version__,
}
if 'model' in params:
model = params['model']
else:
model = None
if 'audio_file' in params:
try:
_, sr = paddleaudio.load(params['audio_file'])
except Exception:
sr = -1
if task == 'asr':
extra_info.update({
'lang': params['lang'],
'inp_sr': sr,
'model_sr': params['sample_rate'],
})
elif task == 'st':
extra_info.update({
'lang':
params['src_lang'] + '-' + params['tgt_lang'],
'inp_sr':
sr,
'model_sr':
params['sample_rate'],
})
elif task == 'tts':
model = params['am']
extra_info.update({
'lang': params['lang'],
'vocoder': params['voc'],
})
elif task == 'cls':
extra_info.update({
'inp_sr': sr,
})
elif task == 'text':
extra_info.update({
'sub_task': params['task'],
'lang': params['lang'],
})
else:
return
StatsWorker(
task=task,
model=model,
version=__version__,
extra_info=extra_info, ).start()
def _parse_args(func, *args, **kwargs):
# FullArgSpec(args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations)
argspec = inspect.getfullargspec(func)
keys = argspec[0]
if keys[0] == 'self': # Remove self pointer.
keys = keys[1:]
default_values = argspec[3]
values = [None] * (len(keys) - len(default_values))
values.extend(list(default_values))
params = dict(zip(keys, values))
for idx, v in enumerate(args):
params[keys[idx]] = v
for k, v in kwargs.items():
params[k] = v
return params
def stats_wrapper(executor_func):
def _warpper(self, *args, **kwargs):
try:
_note_one_stat(
type(self).__name__, _parse_args(executor_func, *args,
**kwargs))
except Exception:
pass
return executor_func(self, *args, **kwargs)
return _warpper
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import VectorExecutor
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册