“abc167338a99f1b644e1a5d4bb324a566d2fd87f”上不存在“develop/doc/dev/new_layer_en.html”
提交 c552c087 编写于 作者: H huangyuxin

fix ds2 server config

......@@ -39,6 +39,9 @@ tools/env.sh
tools/openfst-1.8.1/
tools/libsndfile/
tools/python-soundfile/
tools/onnx
tools/onnxruntime
tools/Paddle2ONNX
speechx/fc_patch/
......
......@@ -52,7 +52,7 @@ pull_request_rules:
add: ["T2S"]
- name: "auto add label=Audio"
conditions:
- files~=^paddleaudio/
- files~=^paddlespeech/audio/
actions:
label:
add: ["Audio"]
......@@ -100,7 +100,7 @@ pull_request_rules:
add: ["README"]
- name: "auto add label=Documentation"
conditions:
- files~=^(docs/|CHANGELOG.md|paddleaudio/CHANGELOG.md)
- files~=^(docs/|CHANGELOG.md)
actions:
label:
add: ["Documentation"]
......
# Changelog
Date: 2022-3-15, Author: Xiaojie Chen.
- kaldi and librosa mfcc, fbank, spectrogram.
- unit test and benchmark.
Date: 2022-2-25, Author: Hui Zhang.
- Refactor architecture.
- dtw distance and mcd style dtw.
# PaddleAudio
PaddleAudio is an audio library for PaddlePaddle.
## Install
`pip install .`
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
\ No newline at end of file
# Build docs for PaddleAudio
Execute the following steps in **current directory**.
## 1. Install
`pip install Sphinx sphinx_rtd_theme`
## 2. Generate API docs
Generate API docs from doc string.
`sphinx-apidoc -fMeT -o source ../paddleaudio ../paddleaudio/utils --templatedir source/_templates`
## 3. Build
`sphinx-build source _html`
## 4. Preview
Open `_html/index.html` for page preview.
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
:end
popd
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
This module is used to store environmental variables in PaddleAudio.
PPAUDIO_HOME --> the root directory for storing PaddleAudio related data. Default to ~/.paddleaudio. Users can change the
├ default value through the PPAUDIO_HOME environment variable.
├─ MODEL_HOME --> Store model files.
└─ DATA_HOME --> Store automatically downloaded datasets.
'''
import os
__all__ = [
'USER_HOME',
'PPAUDIO_HOME',
'MODEL_HOME',
'DATA_HOME',
]
def _get_user_home():
return os.path.expanduser('~')
def _get_ppaudio_home():
if 'PPAUDIO_HOME' in os.environ:
home_path = os.environ['PPAUDIO_HOME']
if os.path.exists(home_path):
if os.path.isdir(home_path):
return home_path
else:
raise RuntimeError(
'The environment variable PPAUDIO_HOME {} is not a directory.'.
format(home_path))
else:
return home_path
return os.path.join(_get_user_home(), '.paddleaudio')
def _get_sub_home(directory):
home = os.path.join(_get_ppaudio_home(), directory)
if not os.path.exists(home):
os.makedirs(home)
return home
USER_HOME = _get_user_home()
PPAUDIO_HOME = _get_ppaudio_home()
MODEL_HOME = _get_sub_home('models')
DATA_HOME = _get_sub_home('datasets')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
import setuptools
from setuptools.command.install import install
from setuptools.command.test import test
# set the version here
VERSION = '0.0.0'
# Inspired by the example at https://pytest.org/latest/goodpractises.html
class TestCommand(test):
def finalize_options(self):
test.finalize_options(self)
self.test_args = []
self.test_suite = True
def run(self):
self.run_benchmark()
super(TestCommand, self).run()
def run_tests(self):
# Run nose ensuring that argv simulates running nosetests directly
import nose
nose.run_exit(argv=['nosetests', '-w', 'tests'])
def run_benchmark(self):
for benchmark_item in glob.glob('tests/benchmark/*py'):
os.system(f'pytest {benchmark_item}')
class InstallCommand(install):
def run(self):
install.run(self)
def write_version_py(filename='paddleaudio/__init__.py'):
with open(filename, "a") as f:
f.write(f"__version__ = '{VERSION}'")
def remove_version_py(filename='paddleaudio/__init__.py'):
with open(filename, "r") as f:
lines = f.readlines()
with open(filename, "w") as f:
for line in lines:
if "__version__" not in line:
f.write(line)
remove_version_py()
write_version_py()
setuptools.setup(
name="paddleaudio",
version=VERSION,
author="",
author_email="",
description="PaddleAudio, in development",
long_description="",
long_description_content_type="text/markdown",
url="",
packages=setuptools.find_packages(include=['paddleaudio*']),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires='>=3.6',
install_requires=[
'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2',
'soundfile >= 0.9.0', 'colorlog', 'pathos == 0.2.8'
],
extras_require={
'test': [
'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1',
'torchaudio==0.10.2', 'pytest-benchmark'
],
},
cmdclass={
'install': InstallCommand,
'test': TestCommand,
}, )
remove_version_py()
......@@ -89,7 +89,7 @@ Then to start the system server, and it provides HTTP backend services.
Then start the server with Fastapi.
```bash
export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
export PYTHONPATH=$PYTHONPATH:./src
python src/audio_search.py
```
......
......@@ -91,7 +91,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
启动用 Fastapi 构建的服务
```bash
export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
export PYTHONPATH=$PYTHONPATH:./src
python src/audio_search.py
```
......
......@@ -7,11 +7,11 @@ host: 0.0.0.0
port: 8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online']
# task choices = ['asr_online-inference', 'asr_online-onnx']
# protocol = ['websocket'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
engine_list: ['asr_online-inference']
#################################################################################
......@@ -19,8 +19,8 @@ engine_list: ['asr_online']
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
################### speech task: asr; engine_type: online-inference #######################
asr_online-inference:
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
......@@ -47,3 +47,38 @@ asr_online:
shift_n: 4 # frame
window_ms: 25 # ms
shift_ms: 10 # ms
################################### ASR #########################################
################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx:
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of onnx am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
num_decoding_left_chunks:
force_yes: True
device: 'cpu' # cpu or gpu:id
# https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
am_predictor_conf:
device: 'cpu' # set 'gpu:id' or 'cpu'
graph_optimization_level: 0
intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
chunk_buffer_conf:
frame_duration_ms: 80
shift_ms: 40
sample_rate: 16000
sample_width: 2
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 20 # ms
shift_ms: 10 # ms
......@@ -4,6 +4,6 @@ export CUDA_VISIBLE_DEVICE=0,1,2,3
# nohup python3 punc_server.py --config_file conf/punc_application.yaml > punc.log 2>&1 &
paddlespeech_server start --config_file conf/punc_application.yaml &> punc.log &
# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log &
# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_wenetspeech_application.yaml > streaming_asr.log 2>&1 &
paddlespeech_server start --config_file conf/ws_conformer_wenetspeech_application.yaml &> streaming_asr.log &
......@@ -4,7 +4,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
# read the wav and pass it to only streaming asr service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wav
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
# read the wav and call streaming and punc service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
......
# Customize Dataset for Audio Classification
Following this tutorial you can customize your dataset for audio classification task by using `paddlespeech` and `paddleaudio`.
Following this tutorial you can customize your dataset for audio classification task by using `paddlespeech`.
A base class of classification dataset is `paddleaudio.dataset.AudioClassificationDataset`. To customize your dataset you should write a dataset class derived from `AudioClassificationDataset`.
A base class of classification dataset is `paddlespeech.audio.dataset.AudioClassificationDataset`. To customize your dataset you should write a dataset class derived from `AudioClassificationDataset`.
Assuming you have some wave files that stored in your own directory. You should prepare a meta file with the information of filepaths and labels. For example the absolute path of it is `/PATH/TO/META_FILE.txt`:
```
......@@ -14,7 +14,7 @@ Assuming you have some wave files that stored in your own directory. You should
Here is an example to build your custom dataset in `custom_dataset.py`:
```python
from paddleaudio.datasets.dataset import AudioClassificationDataset
from paddlespeech.audio.datasets.dataset import AudioClassificationDataset
class CustomDataset(AudioClassificationDataset):
meta_file = '/PATH/TO/META_FILE.txt'
......@@ -48,7 +48,7 @@ class CustomDataset(AudioClassificationDataset):
Then you can build dataset and data loader from `CustomDataset`:
```python
import paddle
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.audio.features import LogMelSpectrogram
from custom_dataset import CustomDataset
......
......@@ -8,13 +8,13 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
:-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----:
[Ds2 Online Wenetspeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz) | Wenetspeech Dataset | Char-based | 1.2 GB | 2 Conv + 5 LSTM layers | 0.152 (test\_net, w/o LM) <br> 0.2417 (test\_meeting, w/o LM) <br> 0.053 (aishell, w/ LM) |-| 10000 h |-
[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz) | Aishell Dataset | Char-based | 491 MB | 2 Conv + 5 LSTM layers | 0.0666 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0)
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.064 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0)
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0)
[Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |-
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1)
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0464 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1)
[Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1)
[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz)| Librispeech Dataset | Char-based | 518 MB | 2 Conv + 3 bidirectional LSTM layers| - |0.0725| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0)
[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0337 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1)
[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0)
[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0338 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1)
[Transformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0381 | 960 h | [Transformer Librispeech ASR1](../../examples/librispeech/asr1)
[Transformer Librispeech ASR2 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: JoinCTC w/ LM |-| 0.0240 | 960 h | [Transformer Librispeech ASR2](../../examples/librispeech/asr2)
......
......@@ -13,6 +13,7 @@
| Model | Number of Params | Release | Config | Test set | Valid Loss | CER |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 122.3M | r1.0.1 | conf/deepspeech2.yaml + U2 Data pipline and spec aug + fbank161 | test | 5.780756044387817 | 0.055400 |
| DeepSpeech2 | 58.4M | v2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.738585948944092 | 0.064000 |
| DeepSpeech2 | 58.4M | v2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
| DeepSpeech2 | 58.4M | v2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
......
data:
dataset: 'paddleaudio.datasets:ESC50'
dataset: 'paddlespeech.audio.datasets:ESC50'
num_classes: 50
train:
mode: 'train'
......
......@@ -2,7 +2,7 @@
###########################################
# Data #
###########################################
dataset: 'paddleaudio.datasets:HeySnips'
dataset: 'paddlespeech.audio.datasets:HeySnips'
data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'
############################################
......
......@@ -3,6 +3,7 @@
## Deepspeech2 Non-Streaming
| Model | Params | release | Config | Test set | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 113.96M | r1.0.1 | conf/deepspeech2.yaml + U2 Data pipline and spec aug + fbank161 | test-clean | 10.76069622039795 | 0.046700 |
| DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | test-clean | 14.49190807 | 0.067283 |
| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 |
| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 |
......
......@@ -14,9 +14,9 @@
import argparse
import paddle
from paddleaudio.datasets.voxceleb import VoxCeleb
from yacs.config import CfgNode
from paddlespeech.audio.datasets.voxceleb import VoxCeleb
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.training.seeding import seed_everything
......
......@@ -21,9 +21,9 @@ import os
from typing import List
import tqdm
from paddleaudio import load as load_audio
from yacs.config import CfgNode
from paddlespeech.audio import load as load_audio
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.vector_utils import get_chunks
......
......@@ -22,9 +22,9 @@ import os
import random
import tqdm
from paddleaudio import load as load_audio
from yacs.config import CfgNode
from paddlespeech.audio import load as load_audio
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.utils.vector_utils import get_chunks
......
......@@ -16,8 +16,8 @@ import os
from typing import List
from typing import Tuple
from ..utils import DATA_HOME
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from .dataset import AudioClassificationDataset
__all__ = ['ESC50']
......
......@@ -17,8 +17,8 @@ import random
from typing import List
from typing import Tuple
from ..utils import DATA_HOME
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from .dataset import AudioClassificationDataset
__all__ = ['GTZAN']
......
......@@ -17,8 +17,8 @@ import random
from typing import List
from typing import Tuple
from ..utils import DATA_HOME
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from .dataset import AudioClassificationDataset
__all__ = ['TESS']
......
......@@ -16,8 +16,8 @@ import os
from typing import List
from typing import Tuple
from ..utils import DATA_HOME
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from .dataset import AudioClassificationDataset
__all__ = ['UrbanSound8K']
......
......@@ -11,13 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...cli.utils import DATA_HOME
from ...cli.utils import MODEL_HOME
from .download import decompress
from .download import download_and_decompress
from .download import load_state_dict_from_url
from .env import DATA_HOME
from .env import MODEL_HOME
from .env import PPAUDIO_HOME
from .env import USER_HOME
from .error import ParameterError
from .log import Logger
from .log import logger
......
......@@ -187,7 +187,7 @@ class ASRExecutor(BaseExecutor):
elif "conformer" in model_type or "transformer" in model_type:
self.config.decode.decoding_method = decode_method
if num_decoding_left_chunks:
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0"
self.config.num_decoding_left_chunks = num_decoding_left_chunks
else:
......
......@@ -21,12 +21,12 @@ from typing import Union
import numpy as np
import paddle
import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
from paddlespeech.audio import load
from paddlespeech.audio.features import LogMelSpectrogram
__all__ = ['CLSExecutor']
......
......@@ -24,11 +24,11 @@ from typing import Any
from typing import Dict
import paddle
import paddleaudio
import requests
import yaml
from paddle.framework import load
import paddlespeech.audio
from . import download
from .entry import commands
try:
......@@ -190,6 +190,7 @@ def _get_sub_home(directory):
PPSPEECH_HOME = _get_paddlespcceh_home()
MODEL_HOME = _get_sub_home('models')
CONF_HOME = _get_sub_home('conf')
DATA_HOME = _get_sub_home('datasets')
def _md5(text: str):
......@@ -281,7 +282,7 @@ def _note_one_stat(cls_name, params={}):
if 'audio_file' in params:
try:
_, sr = paddleaudio.load(params['audio_file'])
_, sr = paddlespeech.audio.load(params['audio_file'])
except Exception:
sr = -1
......
......@@ -22,13 +22,13 @@ from typing import Union
import paddle
import soundfile
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
from paddlespeech.audio.backends import load as load_audio
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
......
......@@ -16,11 +16,12 @@ import os
import numpy as np
from paddle import inference
from paddleaudio.backends import load as load_audio
from paddleaudio.datasets import ESC50
from paddleaudio.features import melspectrogram
from scipy.special import softmax
from paddlespeech.audio.backends import load as load_audio
from paddlespeech.audio.datasets import ESC50
from paddlespeech.audio.features import melspectrogram
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, required=True, default="./export", help="The directory to static model.")
......
......@@ -15,8 +15,8 @@ import argparse
import os
import paddle
from paddleaudio.datasets import ESC50
from paddlespeech.audio.datasets import ESC50
from paddlespeech.cls.models import cnn14
from paddlespeech.cls.models import SoundClassifier
......
......@@ -17,10 +17,10 @@ import os
import paddle
import paddle.nn.functional as F
import yaml
from paddleaudio.backends import load as load_audio
from paddleaudio.features import LogMelSpectrogram
from paddleaudio.utils import logger
from paddlespeech.audio.backends import load as load_audio
from paddlespeech.audio.features import LogMelSpectrogram
from paddlespeech.audio.utils import logger
from paddlespeech.cls.models import SoundClassifier
from paddlespeech.utils.dynamic_import import dynamic_import
......
......@@ -16,10 +16,10 @@ import os
import paddle
import yaml
from paddleaudio.features import LogMelSpectrogram
from paddleaudio.utils import logger
from paddleaudio.utils import Timer
from paddlespeech.audio.features import LogMelSpectrogram
from paddlespeech.audio.utils import logger
from paddlespeech.audio.utils import Timer
from paddlespeech.cls.models import SoundClassifier
from paddlespeech.utils.dynamic_import import dynamic_import
......
......@@ -15,8 +15,9 @@ import os
import paddle.nn as nn
import paddle.nn.functional as F
from paddleaudio.utils.download import load_state_dict_from_url
from paddleaudio.utils.env import MODEL_HOME
from paddlespeech.audio.utils import MODEL_HOME
from paddlespeech.audio.utils.download import load_state_dict_from_url
__all__ = ['CNN14', 'CNN10', 'CNN6', 'cnn14', 'cnn10', 'cnn6']
......
......@@ -14,10 +14,10 @@
import os
import paddle
from paddleaudio.utils import logger
from paddleaudio.utils import Timer
from yacs.config import CfgNode
from paddlespeech.audio.utils import logger
from paddlespeech.audio.utils import Timer
from paddlespeech.kws.exps.mdtc.collate import collate_features
from paddlespeech.kws.models.loss import max_pooling_loss
from paddlespeech.kws.models.mdtc import KWSModel
......
......@@ -15,6 +15,7 @@
__all__ = [
'asr_dynamic_pretrained_models',
'asr_static_pretrained_models',
'asr_onnx_pretrained_models',
'cls_dynamic_pretrained_models',
'cls_static_pretrained_models',
'st_dynamic_pretrained_models',
......@@ -166,23 +167,17 @@ asr_dynamic_pretrained_models = {
},
},
"deepspeech2online_aishell-zh-16k": {
'1.0': {
'1.0.2': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz',
'md5':
'df5ddeac8b679a470176649ac4b78726',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz',
'md5': '4dd42cfce9aaa54db0ec698da6c48ec5',
'cfg_path': 'model.yaml',
'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1',
'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'onnx_model': 'onnx/model.onnx',
'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2offline_librispeech-en-16k": {
......@@ -224,6 +219,55 @@ asr_static_pretrained_models = {
'29e02312deb2e59b3c8686c7966d4fe3'
}
},
"deepspeech2online_aishell-zh-16k": {
'1.0.1': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz',
'md5':
'df5ddeac8b679a470176649ac4b78726',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
'1.0.2': {
'url':
'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz',
'md5': '4dd42cfce9aaa54db0ec698da6c48ec5',
'cfg_path': 'model.yaml',
'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1',
'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'onnx_model': 'onnx/model.onnx',
'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3'
},
},
}
asr_onnx_pretrained_models = {
"deepspeech2online_aishell-zh-16k": {
'1.0.2': {
'url':
'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz',
'md5': '4dd42cfce9aaa54db0ec698da6c48ec5',
'cfg_path': 'model.yaml',
'ckpt_path':'exp/deepspeech2_online/checkpoints/avg_1',
'model':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'onnx_model': 'onnx/model.onnx',
'lm_url':'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':'29e02312deb2e59b3c8686c7966d4fe3'
},
},
}
# ---------------------------------
......
......@@ -164,9 +164,10 @@ class CommonTaskResource:
try:
import_models = '{}_{}_pretrained_models'.format(self.task,
self.model_format)
print(f"from .pretrained_models import {import_models}")
exec('from .pretrained_models import {}'.format(import_models))
models = OrderedDict(locals()[import_models])
except ImportError:
except Exception as e:
models = OrderedDict({}) # no models.
finally:
return models
......
......@@ -14,10 +14,11 @@
"""Contains the audio featurizer class."""
import numpy as np
import paddle
import paddleaudio.compliance.kaldi as kaldi
from python_speech_features import delta
from python_speech_features import mfcc
import paddlespeech.audio.compliance.kaldi as kaldi
class AudioFeaturizer():
"""Audio featurizer, for extracting features from audio contents of
......
......@@ -15,9 +15,10 @@
import librosa
import numpy as np
import paddle
import paddleaudio.compliance.kaldi as kaldi
from python_speech_features import logfbank
import paddlespeech.audio.compliance.kaldi as kaldi
def stft(x,
n_fft,
......
......@@ -7,11 +7,11 @@ host: 0.0.0.0
port: 8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# task choices = ['asr_online-inference', 'asr_online-onnx']
# protocol = ['websocket'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
engine_list: ['asr_online-inference']
#################################################################################
......@@ -19,8 +19,8 @@ engine_list: ['asr_online']
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
################### speech task: asr; engine_type: online-inference #######################
asr_online-inference:
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
......@@ -30,7 +30,7 @@ asr_online:
decode_method:
num_decoding_left_chunks:
force_yes: True
device: # cpu or gpu:id
device: 'cpu' # cpu or gpu:id
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
......@@ -38,6 +38,41 @@ asr_online:
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
frame_duration_ms: 80
shift_ms: 40
sample_rate: 16000
sample_width: 2
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 20 # ms
shift_ms: 10 # ms
################################### ASR #########################################
################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx:
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of onnx am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
num_decoding_left_chunks:
force_yes: True
device: 'cpu' # cpu or gpu:id
# https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
am_predictor_conf:
device: 'cpu' # set 'gpu:id' or 'cpu'
graph_optimization_level: 0
intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
chunk_buffer_conf:
frame_duration_ms: 85
shift_ms: 40
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import ByteString
from typing import Optional
import numpy as np
import paddle
from numpy import float32
from yacs.config import CfgNode
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils import onnx_infer
__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
# ASR server connection process class
class PaddleASRConnectionHanddler:
def __init__(self, asr_engine):
"""Init a Paddle ASR Connection Handler instance
Args:
asr_engine (ASREngine): the global asr engine
"""
super().__init__()
logger.info(
"create an paddle asr connection handler to process the websocket connection"
)
self.config = asr_engine.config # server config
self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
self.model_type = self.asr_engine.executor.model_type
self.sample_rate = self.asr_engine.executor.sample_rate
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
# extract feat, new only fbank in conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, (
self.sample_rate, self.preprocess_conf.process[0]['fs'])
self.frame_shift_in_ms = int(
self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
self.continuous_decoding = self.config.get("continuous_decoding", False)
self.init_decoder()
self.reset()
def init_decoder(self):
if "deepspeech2" in self.model_type:
assert self.continuous_decoding is False, "ds2 model not support endpoint"
self.am_predictor = self.asr_engine.executor.am_predictor
self.decoder = CTCDecoder(
odim=self.model_config.output_dim, # <blank> is in vocab
enc_n_units=self.model_config.rnn_layer_size * 2,
blank_id=self.model_config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.model_config.get('ctc_grad_norm_type',
None))
cfg = self.model_config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
else:
raise ValueError(f"Not supported: {self.model_type}")
def model_reset(self):
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
def output_reset(self):
## outputs
# partial/ending decoding results
self.result_transcripts = ['']
def reset_continuous_decoding(self):
"""
when in continous decoding, reset for next utterance.
"""
self.global_frame_offset = self.num_frames
self.model_reset()
def reset(self):
if "deepspeech2" in self.model_type:
# for deepspeech2
# init state
self.chunk_state_h_box = np.zeros(
(self.model_config.num_rnn_layers, 1,
self.model_config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.model_config.num_rnn_layers, 1,
self.model_config.rnn_layer_size),
dtype=float32)
self.decoder.reset_decoder(batch_size=1)
else:
raise NotImplementedError(f"{self.model_type} not support.")
self.device = None
## common
# global sample and frame step
self.num_samples = 0
self.global_frame_offset = 0
# frame step of cur utterance
self.num_frames = 0
## endpoint
self.endpoint_state = False # True for detect endpoint
## conformer
self.model_reset()
## outputs
self.output_reset()
def extract_feat(self, samples: ByteString):
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
self.num_samples += samples.shape[0]
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
)
if len(self.remained_wav) < self.win_length:
# samples not enough for feature window
return 0
# fbank
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)
# feature cache
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
# cur frame step
num_frames = x_chunk.shape[1]
# global frame step
self.num_frames += num_frames
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
)
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}")
def decode(self, is_finished=False):
"""advance decoding
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Returns:
None:
"""
if "deepspeech2" in self.model_type:
decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit
context = 7 # context=7, in audio frame unit
subsampling = 4 # subsampling=4, in audio frame unit
cached_feature_num = context - subsampling
# decoding window for model, in audio frame unit
decoding_window = (decoding_chunk_size - 1) * subsampling + context
# decoding stride for model, in audio frame unit
stride = subsampling * decoding_chunk_size
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
num_frames = self.cached_feat.shape[1]
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
)
# the cached feat must be larger decoding_window
if num_frames < decoding_window and not is_finished:
logger.info(
f"frame feat num is less than {decoding_window}, please input more pcm data"
)
return None, None
# if is_finished=True, we need at least context frames
if num_frames < context:
logger.info(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return None, None
logger.info("start to do model forward")
# num_frames - context + 1 ensure that current frame can get context window
if is_finished:
# if get the finished chunk, we need process the last context
left_frames = context
else:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
# extract the audio
x_chunk = self.cached_feat[:, cur:end, :].numpy()
x_chunk_lens = np.array([x_chunk.shape[1]])
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
self.result_transcripts = [trans_best]
# update feat cache
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
# return trans_best[0]
else:
raise Exception(f"{self.model_type} not support paddleinference.")
@paddle.no_grad()
def decode_one_chunk(self, x_chunk, x_chunk_lens):
"""forward one chunk frames
Args:
x_chunk (np.ndarray): (B,T,D), audio frames.
x_chunk_lens ([type]): (B,), audio frame lens
Returns:
logprob: poster probability.
"""
logger.info("start to decoce one chunk for deepspeech2")
# state_c, state_h, audio_lens, audio
# 'chunk_state_c_box', 'chunk_state_h_box', 'audio_chunk_lens', 'audio_chunk'
input_names = [n.name for n in self.am_predictor.get_inputs()]
logger.info(f"ort inputs: {input_names}")
# 'softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'
# audio, audio_lens, state_h, state_c
output_names = [n.name for n in self.am_predictor.get_outputs()]
logger.info(f"ort outpus: {output_names}")
assert (len(input_names) == len(output_names))
assert isinstance(input_names[0], str)
input_datas = [
self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens,
x_chunk
]
feeds = dict(zip(input_names, input_datas))
outputs = self.am_predictor.run([*output_names], {**feeds})
output_chunk_probs, output_chunk_lens, self.chunk_state_h_box, self.chunk_state_c_box = outputs
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
return trans_best[0]
def get_result(self):
"""return partial/ending asr result.
Returns:
str: one best result of partial/ending.
"""
if len(self.result_transcripts) > 0:
return self.result_transcripts[0]
else:
return ''
def get_word_time_stamp(self):
return []
@paddle.no_grad()
def rescoring(self):
...
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
self.task_resource = CommonTaskResource(
task='asr', model_format='onnx', inference_mode='online')
def update_config(self) -> None:
if "deepspeech2" in self.model_type:
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
else:
raise NotImplementedError(
f"{self.model_type} not support paddleinference.")
def init_model(self) -> None:
if "deepspeech2" in self.model_type:
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor = onnx_infer.get_sess(
model_path=self.am_model, sess_conf=self.am_predictor_conf)
else:
raise NotImplementedError(
f"{self.model_type} not support paddleinference.")
def _init_from_path(self,
model_type: str=None,
am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None,
lang: str='zh',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
am_predictor_conf: dict=None):
"""
Init model and other resources from a specific path.
"""
if not model_type or not lang or not sample_rate:
logger.error(
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
)
return False
assert am_params is None, "am_params not used in onnx engine"
self.model_type = model_type
self.sample_rate = sample_rate
self.decode_method = decode_method
self.num_decoding_left_chunks = num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag)
if cfg_path is None:
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join(
self.res_path, self.task_resource.res_dict['cfg_path'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[
'onnx_model']) if am_model is None else os.path.abspath(am_model)
# self.am_params = os.path.join(
# self.res_path, self.task_resource.res_dict[
# 'params']) if am_params is None else os.path.abspath(am_params)
logger.info("Load the pretrained model:")
logger.info(f" tag = {tag}")
logger.info(f" res_path: {self.res_path}")
logger.info(f" cfg path: {self.cfg_path}")
logger.info(f" am_model path: {self.am_model}")
# logger.info(f" am_params path: {self.am_params}")
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix)
logger.info(f"spm model path: {self.config.spm_model_prefix}")
self.vocab = self.config.vocab_filepath
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.update_config()
# AM predictor
self.init_model()
logger.info(f"create the {model_type} model success")
return True
class ASREngine(BaseEngine):
"""ASR model resource
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self):
super(ASREngine, self).__init__()
def init_model(self) -> bool:
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
am_predictor_conf=self.config.am_predictor_conf):
return False
return True
def init(self, config: dict) -> bool:
"""init engine resource
Args:
config_file (str): config file
Returns:
bool: init failed or success
"""
self.config = config
self.executor = ASRServerExecutor()
try:
self.device = self.config.get("device", paddle.get_device())
paddle.set_device(self.device)
except BaseException as e:
logger.error(
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
)
logger.error(
"If all GPU or XPU is used, you can set the server to 'cpu'")
sys.exit(-1)
logger.info(f"paddlespeech_server set the device: {self.device}")
if not self.init_model():
logger.error(
"Init the ASR server occurs error, please check the server configuration yaml"
)
return False
logger.info("Initialize ASR server engine successfully.")
return True
def new_handler(self):
"""New handler from model.
Returns:
PaddleASRConnectionHanddler: asr handler instance
"""
return PaddleASRConnectionHanddler(self)
def preprocess(self, *args, **kwargs):
raise NotImplementedError("Online not using this.")
def run(self, *args, **kwargs):
raise NotImplementedError("Online not using this.")
def postprocess(self):
raise NotImplementedError("Online not using this.")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
......@@ -121,13 +121,13 @@ class PaddleASRConnectionHanddler:
raise ValueError(f"Not supported: {self.model_type}")
def model_reset(self):
if "deepspeech2" in self.model_type:
return
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
if "deepspeech2" in self.model_type:
return
## conformer
# cache for conformer online
self.subsampling_cache = None
......@@ -161,7 +161,9 @@ class PaddleASRConnectionHanddler:
self.model_reset()
self.searcher.reset()
self.endpointer.reset()
self.output_reset()
# reset hys will trancate history transcripts.
# self.output_reset()
def reset(self):
if "deepspeech2" in self.model_type:
......@@ -695,6 +697,66 @@ class ASRServerExecutor(ASRExecutor):
self.task_resource = CommonTaskResource(
task='asr', model_format='dynamic', inference_mode='online')
def update_config(self) -> None:
if "deepspeech2" in self.model_type:
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in self.model_type or "transformer" in self.model_type:
with UpdateConfig(self.config):
logger.info("start to create the stream conformer asr engine")
# update the decoding method
if self.decode_method:
self.config.decode.decoding_method = self.decode_method
# update num_decoding_left_chunks
if self.num_decoding_left_chunks:
assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0"
self.config.decode.num_decoding_left_chunks = self.num_decoding_left_chunks
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
# Generally we set the decoding_method to attention_rescoring
if self.config.decode.decoding_method not in [
"ctc_prefix_beam_search", "attention_rescoring"
]:
logger.info(
"we set the decoding_method to attention_rescoring")
self.config.decode.decoding_method = "attention_rescoring"
assert self.config.decode.decoding_method in [
"ctc_prefix_beam_search", "attention_rescoring"
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
else:
raise Exception(f"not support: {self.model_type}")
def init_model(self) -> None:
if "deepspeech2" in self.model_type:
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in self.model_type or "transformer" in self.model_type:
# load model
# model_type: {model_name}_{dataset}
model_name = self.model_type[:self.model_type.rindex('_')]
logger.info(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name)
model = model_class.from_config(self.config)
self.model = model
self.model.set_state_dict(paddle.load(self.am_model))
self.model.eval()
else:
raise Exception(f"not support: {self.model_type}")
def _init_from_path(self,
model_type: str=None,
am_model: Optional[os.PathLike]=None,
......@@ -716,6 +778,10 @@ class ASRServerExecutor(ASRExecutor):
self.model_type = model_type
self.sample_rate = sample_rate
self.decode_method = decode_method
self.num_decoding_left_chunks = num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
......@@ -761,62 +827,10 @@ class ASRServerExecutor(ASRExecutor):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
if "deepspeech2" in model_type:
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
self.update_config()
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in model_type or "transformer" in model_type:
with UpdateConfig(self.config):
logger.info("start to create the stream conformer asr engine")
# update the decoding method
if decode_method:
self.config.decode.decoding_method = decode_method
# update num_decoding_left_chunks
if num_decoding_left_chunks:
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
# Generally we set the decoding_method to attention_rescoring
if self.config.decode.decoding_method not in [
"ctc_prefix_beam_search", "attention_rescoring"
]:
logger.info(
"we set the decoding_method to attention_rescoring")
self.config.decode.decoding_method = "attention_rescoring"
assert self.config.decode.decoding_method in [
"ctc_prefix_beam_search", "attention_rescoring"
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
# load model
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
logger.info(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name)
model = model_class.from_config(self.config)
self.model = model
self.model.set_state_dict(paddle.load(self.am_model))
self.model.eval()
else:
raise Exception(f"not support: {model_type}")
self.init_model()
logger.info(f"create the {model_type} model success")
return True
......@@ -831,7 +845,20 @@ class ASREngine(BaseEngine):
def __init__(self):
super(ASREngine, self).__init__()
logger.info("create the online asr engine resource instance")
def init_model(self) -> bool:
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
am_predictor_conf=self.config.am_predictor_conf):
return False
return True
def init(self, config: dict) -> bool:
"""init engine resource
......@@ -858,16 +885,7 @@ class ASREngine(BaseEngine):
logger.info(f"paddlespeech_server set the device: {self.device}")
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
am_predictor_conf=self.config.am_predictor_conf):
if not self.init_model():
logger.error(
"Init the ASR server occurs error, please check the server configuration yaml"
)
......
......@@ -13,12 +13,16 @@
# limitations under the License.
from typing import Text
from ..utils.log import logger
__all__ = ['EngineFactory']
class EngineFactory(object):
@staticmethod
def get_engine(engine_name: Text, engine_type: Text):
logger.info(f"{engine_name} : {engine_type} engine.")
if engine_name == 'asr' and engine_type == 'inference':
from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine
return ASREngine()
......@@ -26,7 +30,13 @@ class EngineFactory(object):
from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online':
from paddlespeech.server.engine.asr.online.asr_engine import ASREngine
from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online-inference':
from paddlespeech.server.engine.asr.online.paddleinference.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online-onnx':
from paddlespeech.server.engine.asr.online.onnx.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'tts' and engine_type == 'inference':
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
......
......@@ -16,9 +16,9 @@ from collections import OrderedDict
import numpy as np
import paddle
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.audio.backends import load as load_audio
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.cli.log import logger
from paddlespeech.cli.vector.infer import VectorExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
......
......@@ -24,11 +24,11 @@ from typing import Any
from typing import Dict
import paddle
import paddleaudio
import requests
import yaml
from paddle.framework import load
import paddlespeech.audio
from .entry import client_commands
from .entry import server_commands
from paddlespeech.cli import download
......@@ -289,7 +289,7 @@ def _note_one_stat(cls_name, params={}):
if 'audio_file' in params:
try:
_, sr = paddleaudio.load(params['audio_file'])
_, sr = paddlespeech.audio.load(params['audio_file'])
except Exception:
sr = -1
......
......@@ -16,21 +16,34 @@ from typing import Optional
import onnxruntime as ort
from .log import logger
def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
logger.info(f"ort sessconf: {sess_conf}")
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
if sess_conf.get('graph_optimization_level', 99) == 0:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if "gpu" in sess_conf["device"]:
# "gpu:0"
providers = ['CPUExecutionProvider']
if "gpu" in sess_conf.get("device", ""):
providers = ['CUDAExecutionProvider']
# fastspeech2/mb_melgan can't use trt now!
if sess_conf["use_trt"]:
if sess_conf.get("use_trt", 0):
providers = ['TensorrtExecutionProvider']
logger.info(f"ort providers: {providers}")
if 'cpu_threads' in sess_conf:
sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
else:
providers = ['CUDAExecutionProvider']
elif sess_conf["device"] == "cpu":
providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = sess_conf["cpu_threads"]
sess_options.intra_op_num_threads = sess_conf.get(
"intra_op_num_threads", 0)
sess_options.inter_op_num_threads = sess_conf.get("inter_op_num_threads", 0)
sess = ort.InferenceSession(
model_path, providers=providers, sess_options=sess_options)
return sess
......@@ -92,6 +92,7 @@ async def websocket_endpoint(websocket: WebSocket):
else:
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)
elif "bytes" in message:
# bytes for the pcm data
message = message["bytes"]
......
......@@ -16,10 +16,10 @@ import os
import time
import paddle
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
from paddlespeech.audio.backends import load as load_audio
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
......
......@@ -18,10 +18,10 @@ import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddleaudio.metric import compute_eer
from tqdm import tqdm
from yacs.config import CfgNode
from paddlespeech.audio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset import CSVDataset
......
......@@ -20,9 +20,9 @@ import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
......
......@@ -15,9 +15,9 @@ from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.audio import load as load_audio
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
......
......@@ -16,9 +16,10 @@ from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.compliance.librosa import mfcc
from paddlespeech.audio import load as load_audio
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.audio.compliance.librosa import mfcc
@dataclass
......
......@@ -24,6 +24,7 @@ from setuptools import find_packages
from setuptools import setup
from setuptools.command.develop import develop
from setuptools.command.install import install
from setuptools.command.test import test
HERE = Path(os.path.abspath(os.path.dirname(__file__)))
......@@ -31,42 +32,13 @@ VERSION = '0.0.0'
COMMITID = 'none'
base = [
"editdistance",
"g2p_en",
"g2pM",
"h5py",
"inflect",
"jieba",
"jsonlines",
"kaldiio",
"librosa==0.8.1",
"loguru",
"matplotlib",
"nara_wpe",
"onnxruntime",
"pandas",
"paddleaudio",
"paddlenlp",
"paddlespeech_feat",
"praatio==5.0.0",
"pypinyin",
"pypinyin-dict",
"python-dateutil",
"pyworld",
"resampy==0.2.2",
"sacrebleu",
"scipy",
"sentencepiece~=0.1.96",
"soundfile~=0.10",
"textgrid",
"timer",
"tqdm",
"typeguard",
"visualdl",
"webrtcvad",
"yacs~=0.1.8",
"prettytable",
"zhon",
"editdistance", "g2p_en", "g2pM", "h5py", "inflect", "jieba", "jsonlines",
"kaldiio", "librosa==0.8.1", "loguru", "matplotlib", "nara_wpe",
"onnxruntime", "pandas", "paddlenlp", "paddlespeech_feat", "praatio==5.0.0",
"pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2",
"sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10",
"textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad",
"yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8'
]
server = [
......@@ -177,7 +149,19 @@ class InstallCommand(install):
install.run(self)
# cmd: python setup.py upload
class TestCommand(test):
def finalize_options(self):
test.finalize_options(self)
self.test_args = []
self.test_suite = True
def run_tests(self):
# Run nose ensuring that argv simulates running nosetests directly
import nose
nose.run_exit(argv=['nosetests', '-w', 'tests'])
# cmd: python setup.py upload
class UploadCommand(Command):
description = "Build and publish the package."
user_options = []
......@@ -279,11 +263,13 @@ setup_info = dict(
"sphinx", "sphinx-rtd-theme", "numpydoc", "myst_parser",
"recommonmark>=0.5.0", "sphinx-markdown-tables", "sphinx-autobuild"
],
'test': ['nose', 'torchaudio==0.10.2'],
},
cmdclass={
'develop': DevelopCommand,
'install': InstallCommand,
'upload': UploadCommand,
'test': TestCommand,
},
# Package info
......
# DeepSpeech2 ONNX model
1. convert deepspeech2 model to ONNX, using Paddle2ONNX.
2. check paddleinference and onnxruntime output equal.
3. optimize onnx model
4. check paddleinference and optimized onnxruntime output equal.
Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) version is correct.
The example test with these packages installed:
```
paddle2onnx 0.9.8rc0 # develop af4354b4e9a61a93be6490640059a02a4499bc7a
paddleaudio 0.2.1
paddlefsl 1.1.0
paddlenlp 2.2.6
paddlepaddle-gpu 2.2.2
paddlespeech 0.0.0 # develop
paddlespeech-ctcdecoders 0.2.0
paddlespeech-feat 0.1.0
onnx 1.11.0
onnx-simplifier 0.0.0 # https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape
onnxoptimizer 0.2.7
onnxruntime 1.11.0
```
## Using
```
bash run.sh
```
For more details please see `run.sh`.
## Outputs
The optimized onnx model is `exp/model.opt.onnx`.
To show the graph, please using `local/netron.sh`.
#!/usr/bin/env python3
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import pickle
import numpy as np
import onnxruntime
import paddle
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--input_file',
type=str,
default="static_ds2online_inputs.pickle",
help="aishell ds2 input data file. For wenetspeech, we only feed for infer model",
)
parser.add_argument(
'--model_type',
type=str,
default="aishell",
help="aishell(1024) or wenetspeech(2048)", )
parser.add_argument(
'--model_dir', type=str, default=".", help="paddle model dir.")
parser.add_argument(
'--model_prefix',
type=str,
default="avg_1.jit",
help="paddle model prefix.")
parser.add_argument(
'--onnx_model',
type=str,
default='./model.old.onnx',
help="onnx model.")
return parser.parse_args()
if __name__ == '__main__':
FLAGS = parse_args()
# input and output
with open(FLAGS.input_file, 'rb') as f:
iodict = pickle.load(f)
print(iodict.keys())
audio_chunk = iodict['audio_chunk']
audio_chunk_lens = iodict['audio_chunk_lens']
chunk_state_h_box = iodict['chunk_state_h_box']
chunk_state_c_box = iodict['chunk_state_c_bos']
print("raw state shape: ", chunk_state_c_box.shape)
if FLAGS.model_type == 'wenetspeech':
chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1)
chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1)
print("state shape: ", chunk_state_c_box.shape)
# paddle
model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix))
res_chunk, res_lens, chunk_state_h, chunk_state_c = model(
paddle.to_tensor(audio_chunk),
paddle.to_tensor(audio_chunk_lens),
paddle.to_tensor(chunk_state_h_box),
paddle.to_tensor(chunk_state_c_box), )
# onnxruntime
options = onnxruntime.SessionOptions()
options.enable_profiling = True
sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
"audio_chunk": audio_chunk,
"audio_chunk_lens": audio_chunk_lens,
"chunk_state_h_box": chunk_state_h_box,
"chunk_state_c_box": chunk_state_c_box
})
print(sess.end_profiling())
# assert paddle equal ort
print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
print(np.allclose(ort_res_lens, res_lens, atol=1e-6))
if FLAGS.model_type == 'aishell':
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))
#!/bin/bash
# show model
if [ $# != 1 ];then
echo "usage: $0 model_path"
exit 1
fi
file=$1
pip install netron
netron -p 8082 --host $(hostname -i) $file
\ No newline at end of file
#!/bin/bash
# clone onnx repos
git clone https://github.com/onnx/onnx.git
git clone https://github.com/microsoft/onnxruntime.git
git clone https://github.com/PaddlePaddle/Paddle2ONNX.git
\ No newline at end of file
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册