Commit d46cee6a, authored by Yang Zhou

Merge branch 'develop' of github.com:SmileGoat/PaddleSpeech into audio_dev

@@ -39,6 +39,9 @@ tools/env.sh
 tools/openfst-1.8.1/
 tools/libsndfile/
 tools/python-soundfile/
+tools/onnx
+tools/onnxruntime
+tools/Paddle2ONNX
 speechx/fc_patch/
......
@@ -52,7 +52,7 @@ pull_request_rules:
       add: ["T2S"]
   - name: "auto add label=Audio"
     conditions:
-      - files~=^paddleaudio/
+      - files~=^paddlespeech/audio/
     actions:
       label:
         add: ["Audio"]
@@ -100,7 +100,7 @@ pull_request_rules:
       add: ["README"]
   - name: "auto add label=Documentation"
     conditions:
-      - files~=^(docs/|CHANGELOG.md|paddleaudio/CHANGELOG.md)
+      - files~=^(docs/|CHANGELOG.md)
     actions:
       label:
         add: ["Documentation"]
......
# Changelog

Date: 2022-3-15, Author: Xiaojie Chen.
- kaldi and librosa mfcc, fbank, spectrogram.
- unit test and benchmark.

Date: 2022-2-25, Author: Hui Zhang.
- Refactor architecture.
- dtw distance and mcd style dtw.
# PaddleAudio
PaddleAudio is an audio library for PaddlePaddle.
## Install
`pip install .`
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
# Build docs for PaddleAudio
Execute the following steps in the **current directory**.
## 1. Install
`pip install Sphinx sphinx_rtd_theme`
## 2. Generate API docs
Generate API docs from docstrings.
`sphinx-apidoc -fMeT -o source ../paddleaudio ../paddleaudio/utils --templatedir source/_templates`
## 3. Build
`sphinx-build source _html`
## 4. Preview
Open `_html/index.html` for page preview.
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
:end
popd
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
This module is used to store environment variables for PaddleAudio.

PPAUDIO_HOME   --> the root directory for storing PaddleAudio related data.
                   Defaults to ~/.paddleaudio. Users can change the default
                   value through the PPAUDIO_HOME environment variable.
├─ MODEL_HOME  --> Store model files.
└─ DATA_HOME   --> Store automatically downloaded datasets.
'''
import os
__all__ = [
'USER_HOME',
'PPAUDIO_HOME',
'MODEL_HOME',
'DATA_HOME',
]
def _get_user_home():
return os.path.expanduser('~')
def _get_ppaudio_home():
if 'PPAUDIO_HOME' in os.environ:
home_path = os.environ['PPAUDIO_HOME']
if os.path.exists(home_path):
if os.path.isdir(home_path):
return home_path
else:
raise RuntimeError(
'The environment variable PPAUDIO_HOME {} is not a directory.'.
format(home_path))
else:
return home_path
return os.path.join(_get_user_home(), '.paddleaudio')
def _get_sub_home(directory):
home = os.path.join(_get_ppaudio_home(), directory)
if not os.path.exists(home):
os.makedirs(home)
return home
USER_HOME = _get_user_home()
PPAUDIO_HOME = _get_ppaudio_home()
MODEL_HOME = _get_sub_home('models')
DATA_HOME = _get_sub_home('datasets')
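A minimal usage sketch of these constants (the import path is assumed from this diff's file layout, and the subdirectory name is purely illustrative):

```python
import os

# Assumes this module lives at paddleaudio/utils/env.py, as the diff suggests.
from paddleaudio.utils.env import MODEL_HOME, DATA_HOME

# 'panns_cnn14' is a hypothetical model name used only for illustration.
ckpt_dir = os.path.join(MODEL_HOME, 'panns_cnn14')
os.makedirs(ckpt_dir, exist_ok=True)
print(ckpt_dir)   # e.g. /home/user/.paddleaudio/models/panns_cnn14
print(DATA_HOME)  # e.g. /home/user/.paddleaudio/datasets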
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
import setuptools
from setuptools.command.install import install
from setuptools.command.test import test
# set the version here
VERSION = '0.0.0'
# Inspired by the example at https://pytest.org/latest/goodpractises.html
class TestCommand(test):
def finalize_options(self):
test.finalize_options(self)
self.test_args = []
self.test_suite = True
def run(self):
self.run_benchmark()
super(TestCommand, self).run()
def run_tests(self):
# Run nose ensuring that argv simulates running nosetests directly
import nose
nose.run_exit(argv=['nosetests', '-w', 'tests'])
def run_benchmark(self):
for benchmark_item in glob.glob('tests/benchmark/*py'):
os.system(f'pytest {benchmark_item}')
class InstallCommand(install):
def run(self):
install.run(self)
def write_version_py(filename='paddleaudio/__init__.py'):
with open(filename, "a") as f:
f.write(f"__version__ = '{VERSION}'")
def remove_version_py(filename='paddleaudio/__init__.py'):
with open(filename, "r") as f:
lines = f.readlines()
with open(filename, "w") as f:
for line in lines:
if "__version__" not in line:
f.write(line)
remove_version_py()
write_version_py()
setuptools.setup(
name="paddleaudio",
version=VERSION,
author="",
author_email="",
description="PaddleAudio, in development",
long_description="",
long_description_content_type="text/markdown",
url="",
packages=setuptools.find_packages(include=['paddleaudio*']),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires='>=3.6',
install_requires=[
'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2',
'soundfile >= 0.9.0', 'colorlog', 'pathos == 0.2.8'
],
extras_require={
'test': [
'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1',
'torchaudio==0.10.2', 'pytest-benchmark'
],
},
cmdclass={
'install': InstallCommand,
'test': TestCommand,
}, )
remove_version_py()
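Because `write_version_py()` appends a `__version__` attribute to `paddleaudio/__init__.py` before packaging, an installed build exposes it directly (a sketch; the value is whatever `VERSION` was at build time):

```python
import paddleaudio

# The attribute is appended by write_version_py() at build time.
print(paddleaudio.__version__)  # '0.0.0' for this in-development setup.py
```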
@@ -89,7 +89,7 @@ Then start the system server; it provides the HTTP backend services.
 Then start the server with FastAPI.
 ```bash
-export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
+export PYTHONPATH=$PYTHONPATH:./src
 python src/audio_search.py
 ```
......
@@ -91,7 +91,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
 Start the service built with FastAPI.
 ```bash
-export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
+export PYTHONPATH=$PYTHONPATH:./src
 python src/audio_search.py
 ```
......
@@ -33,6 +33,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 ```bash
 # in PaddleSpeech/demos/streaming_asr_server start the service
 paddlespeech_server start --config_file ./conf/ws_conformer_wenetspeech_application.yaml
+# to speed up decoding at some cost in accuracy, use the config file below
+paddlespeech_server start --config_file ./conf/ws_conformer_wenetspeech_application_faster.yaml
 ```
 Usage:
......
@@ -21,7 +21,7 @@
 After downloading `PaddleSpeech`, go to the `PaddleSpeech/demos/streaming_asr_server` directory.
 The config files can be found in `conf/ws_application.yaml` and `conf/ws_conformer_wenetspeech_application.yaml` under that directory.
 The models currently integrated into the service are DeepSpeech2 and conformer; the corresponding config files are:
 * DeepSpeech: `conf/ws_application.yaml`
 * conformer: `conf/ws_conformer_wenetspeech_application.yaml`
@@ -40,6 +40,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 ```bash
 # start the service in the PaddleSpeech/demos/streaming_asr_server directory
 paddlespeech_server start --config_file ./conf/ws_conformer_wenetspeech_application.yaml
+# if you are willing to trade some model accuracy for faster decoding, use the config below
+paddlespeech_server start --config_file ./conf/ws_conformer_wenetspeech_application_faster.yaml
 ```
 Usage:
......
@@ -28,6 +28,7 @@ asr_online:
     sample_rate: 16000
     cfg_path:
     decode_method:
+    num_decoding_left_chunks: -1
     force_yes: True
     device: 'cpu'  # cpu or gpu:id
     decode_method: "attention_rescoring"
......
@@ -32,7 +32,7 @@ asr_online:
     device: 'cpu'  # cpu or gpu:id
     decode_method: "attention_rescoring"
     continuous_decoding: True  # enable continue decoding when endpoint detected
+    num_decoding_left_chunks: -1
 am_predictor_conf:
     device:  # set 'gpu:id' or 'cpu'
     switch_ir_optim: True
......
@@ -7,8 +7,8 @@ host: 0.0.0.0
 port: 8090

 # The task format in the engin_list is: <speech task>_<engine type>
-# task choices = ['asr_online', 'tts_online']
-# protocol = ['websocket', 'http'] (only one can be selected).
+# task choices = ['asr_online']
+# protocol = ['websocket'] (only one can be selected).
 # websocket only support online engine type.
 protocol: 'websocket'
 engine_list: ['asr_online']
@@ -21,7 +21,7 @@ engine_list: ['asr_online']
 ################################### ASR #########################################
 ################### speech task: asr; engine_type: online #######################
 asr_online:
-    model_type: 'deepspeech2online_aishell'
+    model_type: 'conformer_online_wenetspeech'
     am_model:   # the pdmodel file of am static model [optional]
     am_params:  # the pdiparams file of am static model [optional]
     lang: 'zh'
@@ -29,8 +29,10 @@ asr_online:
     cfg_path:
     decode_method:
     force_yes: True
-    device:  # cpu or gpu:id
+    device: 'cpu'  # cpu or gpu:id
+    decode_method: "attention_rescoring"
+    continuous_decoding: True  # enable continue decoding when endpoint detected
+    num_decoding_left_chunks: 16
 am_predictor_conf:
     device:  # set 'gpu:id' or 'cpu'
     switch_ir_optim: True
@@ -38,11 +40,9 @@ asr_online:
     summary: True  # False -> do not show predictor config
 chunk_buffer_conf:
-    frame_duration_ms: 80
-    shift_ms: 40
-    sample_rate: 16000
-    sample_width: 2
     window_n: 7    # frame
     shift_n: 4     # frame
-    window_ms: 20  # ms
+    window_ms: 25  # ms
     shift_ms: 10   # ms
+    sample_rate: 16000
+    sample_width: 2
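One way to read these buffer numbers (an inference from the values in this diff, not something the config states): the chunk duration implied by the frame-level settings matches the `frame_duration_ms` used in the companion config.

```python
# Sketch: derive chunk duration/shift from the frame-level settings above,
# assuming a chunk is window_n frames of window_ms each, hopped by shift_ms.
window_n, shift_n = 7, 4       # frames per chunk, frame shift per chunk
window_ms, shift_ms = 25, 10   # per-frame window and hop, in ms

chunk_duration_ms = (window_n - 1) * shift_ms + window_ms  # 85 ms
chunk_shift_ms = shift_n * shift_ms                        # 40 ms
print(chunk_duration_ms, chunk_shift_ms)
# With the old window_ms of 20 the same formula gives 80 ms, matching the
# frame_duration_ms: 80 setting this commit replaces.
```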
@@ -7,11 +7,11 @@ host: 0.0.0.0
 port: 8090

 # The task format in the engin_list is: <speech task>_<engine type>
-# task choices = ['asr_online']
+# task choices = ['asr_online-inference', 'asr_online-onnx']
 # protocol = ['websocket'] (only one can be selected).
 # websocket only support online engine type.
 protocol: 'websocket'
-engine_list: ['asr_online']
+engine_list: ['asr_online-onnx']
 #################################################################################
@@ -19,15 +19,16 @@ engine_list: ['asr_online']
 #################################################################################
 ################################### ASR #########################################
-################### speech task: asr; engine_type: online #######################
-asr_online:
-    model_type: 'deepspeech2online_aishell'
+############## speech task: asr; engine_type: online-inference ##################
+asr_online-inference:
+    model_type: 'deepspeech2online_wenetspeech'
     am_model:   # the pdmodel file of am static model [optional]
     am_params:  # the pdiparams file of am static model [optional]
     lang: 'zh'
     sample_rate: 16000
     cfg_path:
     decode_method:
+    num_decoding_left_chunks:
     force_yes: True
     device: 'cpu'  # cpu or gpu:id
@@ -37,6 +38,41 @@ asr_online:
     glog_info: False  # True -> print glog
     summary: True     # False -> do not show predictor config

+chunk_buffer_conf:
+    frame_duration_ms: 85
+    shift_ms: 40
+    sample_rate: 16000
+    sample_width: 2
+    window_n: 7    # frame
+    shift_n: 4     # frame
+    window_ms: 25  # ms
+    shift_ms: 10   # ms
+
+################################### ASR #########################################
+################ speech task: asr; engine_type: online-onnx #####################
+asr_online-onnx:
+    model_type: 'deepspeech2online_wenetspeech'
+    am_model:   # the pdmodel file of onnx am static model [optional]
+    am_params:  # the pdiparams file of am static model [optional]
+    lang: 'zh'
+    sample_rate: 16000
+    cfg_path:
+    decode_method:
+    num_decoding_left_chunks:
+    force_yes: True
+    device: 'cpu'  # cpu or gpu:id
+
+    # https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
+    am_predictor_conf:
+        device: 'cpu'  # set 'gpu:id' or 'cpu'
+        graph_optimization_level: 0
+        intra_op_num_threads: 0  # Sets the number of threads used to parallelize the execution within nodes.
+        inter_op_num_threads: 0  # Sets the number of threads used to parallelize the execution of the graph (across nodes).
+        log_severity_level: 2    # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal. Default is 2.
+        log_verbosity_level: 0   # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
+
 chunk_buffer_conf:
     frame_duration_ms: 80
     shift_ms: 40
@@ -44,5 +80,5 @@ asr_online:
     sample_width: 2
     window_n: 7    # frame
     shift_n: 4     # frame
-    window_ms: 20  # ms
+    window_ms: 25  # ms
     shift_ms: 10   # ms
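For context, the `am_predictor_conf` keys above map one-to-one onto `onnxruntime.SessionOptions`. A minimal sketch of that mapping (the model path is a placeholder, not a file from this commit):

```python
import onnxruntime as ort

so = ort.SessionOptions()
# graph_optimization_level: 0 in the config corresponds to ORT_DISABLE_ALL.
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
so.intra_op_num_threads = 0   # 0 -> let onnxruntime choose
so.inter_op_num_threads = 0
so.log_severity_level = 2     # warnings and above
so.log_verbosity_level = 0

# 'am.onnx' is a placeholder for the exported acoustic model.
sess = ort.InferenceSession('am.onnx', sess_options=so,
                            providers=['CPUExecutionProvider'])
```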
#!/usr/bin/env python3
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compute average RTF from a websocket client log.')
    parser.add_argument(
        '--logfile', type=str, required=True, help='ws client log file')
    args = parser.parse_args()

    rtfs = []
    with open(args.logfile, 'r') as f:
        for line in f:
            if 'RTF=' in line:
                # audio duration: 6.126, elapsed time: 3.471978187561035, RTF=0.5667610492264177
                line = line.strip()
                beg = line.index("audio")
                line = line[beg:]
                items = line.split(',')
                vals = []
                for elem in items:
                    if "RTF=" in elem:
                        continue
                    _, val = elem.split(":")
                    vals.append(float(val))
                # T: audio duration (s), P: processing/elapsed time (s)
                keys = ['T', 'P']
                meta = dict(zip(keys, vals))
                rtfs.append(meta)

    T = 0.0
    P = 0.0
    n = 0
    for m in rtfs:
        n += 1
        T += m['T']
        P += m['P']
    print(f"RTF: {P/T}, utts: {n}")
#!/bin/bash
if [ $# != 1 ];then
echo "usage: $0 wav_scp"
exit -1
fi
scp=$1
# calc RTF
# wav_scp can generate from `speechx/examples/ds2_ol/aishell`
exp=exp
mkdir -p $exp
python3 local/websocket_client.py --server_ip 127.0.0.1 --port 8090 --wavscp $scp &> $exp/log.rsl
python3 local/rtf_from_log.py --logfile $exp/log.rsl
+#!/usr/bin/python
 # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,9 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#!/usr/bin/python
-# -*- coding: UTF-8 -*-
-# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
+# calc avg RTF(NOT Accurate): grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
+# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
+# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
 import argparse
 import asyncio
 import codecs
......
@@ -4,6 +4,6 @@ export CUDA_VISIBLE_DEVICE=0,1,2,3
 # nohup python3 punc_server.py --config_file conf/punc_application.yaml > punc.log 2>&1 &
 paddlespeech_server start --config_file conf/punc_application.yaml &> punc.log &
-# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
-paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log &
+# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_wenetspeech_application.yaml > streaming_asr.log 2>&1 &
+paddlespeech_server start --config_file conf/ws_conformer_wenetspeech_application.yaml &> streaming_asr.log &
@@ -3,11 +3,9 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
 # read the wav and pass it to only streaming asr service
 # If `127.0.0.1` is not accessible, you need to use the actual service IP address.
-# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
-paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wav
+paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input ./zh.wav

 # read the wav and call streaming and punc service
 # If `127.0.0.1` is not accessible, you need to use the actual service IP address.
-# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
 paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav
 # Customize Dataset for Audio Classification
-Following this tutorial you can customize your dataset for audio classification task by using `paddlespeech` and `paddleaudio`.
-A base class of classification dataset is `paddleaudio.dataset.AudioClassificationDataset`. To customize your dataset you should write a dataset class derived from `AudioClassificationDataset`.
+Following this tutorial you can customize your dataset for the audio classification task by using `paddlespeech`.
+The base class for classification datasets is `paddlespeech.audio.dataset.AudioClassificationDataset`. To customize your dataset, write a dataset class derived from `AudioClassificationDataset`.
 Assuming you have some wave files stored in your own directory, you should prepare a meta file with the filepath and label of each file. For example, its absolute path is `/PATH/TO/META_FILE.txt`:
@@ -14,7 +14,7 @@ Assuming you have some wave files stored in your own directory...
 Here is an example to build your custom dataset in `custom_dataset.py`:

 ```python
-from paddleaudio.datasets.dataset import AudioClassificationDataset
+from paddlespeech.audio.datasets.dataset import AudioClassificationDataset

 class CustomDataset(AudioClassificationDataset):
     meta_file = '/PATH/TO/META_FILE.txt'
@@ -48,7 +48,7 @@ class CustomDataset(AudioClassificationDataset):
 Then you can build the dataset and data loader from `CustomDataset`:

 ```python
 import paddle
-from paddleaudio.features import LogMelSpectrogram
+from paddlespeech.audio.features import LogMelSpectrogram

 from custom_dataset import CustomDataset
......
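To round the tutorial off, here is a hedged sketch of the data-loader side; the constructor and feature-extractor arguments below are illustrative, not taken from this diff:

```python
import paddle
from paddlespeech.audio.features import LogMelSpectrogram
from custom_dataset import CustomDataset

# Argument values are hypothetical; pick what matches your data.
train_ds = CustomDataset(sample_rate=16000)
feature_extractor = LogMelSpectrogram(sr=16000, n_fft=512, hop_length=160)

train_loader = paddle.io.DataLoader(train_ds, batch_size=16, shuffle=True)

for waveforms, labels in train_loader:
    feats = feature_extractor(waveforms)  # (batch, n_mels, frames)
    # ... forward/backward pass of your classifier goes here ...
```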
@@ -6,15 +6,15 @@
 ### Speech Recognition Model
 Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER | Hours of speech | Example Link
 :-------------:| :------------:| :-----: | -----: | :-----: |:-----:| :-----: | :-----: | :-----:
-[Ds2 Online Wenetspeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz) | Wenetspeech Dataset | Char-based | 1.2 GB | 2 Conv + 5 LSTM layers | 0.152 (test\_net, w/o LM) <br> 0.2417 (test\_meeting, w/o LM) <br> 0.053 (aishell, w/ LM) |-| 10000 h |-
+[Ds2 Online Wenetspeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz) | Wenetspeech Dataset | Char-based | 1.2 GB | 2 Conv + 5 LSTM layers | 0.152 (test\_net, w/o LM) <br> 0.2417 (test\_meeting, w/o LM) <br> 0.053 (aishell, w/ LM) |-| 10000 h |-
 [Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz) | Aishell Dataset | Char-based | 491 MB | 2 Conv + 5 LSTM layers | 0.0666 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0)
-[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.064 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0)
+[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0)
 [Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |-
 [Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1)
 [Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0464 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1)
 [Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1)
-[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz)| Librispeech Dataset | Char-based | 518 MB | 2 Conv + 3 bidirectional LSTM layers| - |0.0725| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0)
+[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0)
-[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0337 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1)
+[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0338 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1)
 [Transformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0381 | 960 h | [Transformer Librispeech ASR1](../../examples/librispeech/asr1)
 [Transformer Librispeech ASR2 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: JoinCTC w/ LM |-| 0.0240 | 960 h | [Transformer Librispeech ASR2](../../examples/librispeech/asr2)
......
@@ -13,6 +13,7 @@
 | Model | Number of Params | Release | Config | Test set | Valid Loss | CER |
 | --- | --- | --- | --- | --- | --- | --- |
+| DeepSpeech2 | 122.3M | r1.0.1 | conf/deepspeech2.yaml + U2 Data pipeline and spec aug + fbank161 | test | 5.780756044387817 | 0.055400 |
 | DeepSpeech2 | 58.4M | v2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.738585948944092 | 0.064000 |
 | DeepSpeech2 | 58.4M | v2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 |
 | DeepSpeech2 | 58.4M | v2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 |
......
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# -lt 2 ] || [ $# -gt 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
     exit -1
 fi
@@ -10,6 +10,13 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+ips=$3
+
+if [ ! $ips ];then
+    ips_config=
+else
+    ips_config="--ips="${ips}
+fi

 mkdir -p exp
@@ -26,7 +33,7 @@ python3 -u ${BIN_DIR}/train.py \
     --output exp/${ckpt_name} \
     --seed ${seed}
 else
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
 --output exp/${ckpt_name} \
......
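The `ips` plumbing added across these train scripts feeds `paddle.distributed.launch --ips`, which takes the comma-separated node list for multi-machine training. A sketch of the same composition from Python (the paths, GPU list, and addresses are placeholders):

```python
import os
import subprocess

# Mirrors the shell logic above: only pass --ips when a node list is given.
ips = os.environ.get('IPS', '')  # e.g. '10.0.0.1,10.0.0.2'
cmd = ['python3', '-m', 'paddle.distributed.launch', '--gpus=0,1,2,3']
if ips:
    cmd.append(f'--ips={ips}')
cmd += ['train.py', '--ngpu', '4',
        '--config', 'conf/deepspeech2.yaml', '--output', 'exp/ckpt']
subprocess.run(cmd, check=True)
```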
@@ -6,6 +6,7 @@ gpus=0,1,2,3
 stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml  # conf/deepspeech2.yaml or conf/deepspeech2_online.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 avg_num=10
 audio_file=data/demo_01_03.wav
@@ -24,7 +25,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......
@@ -17,13 +17,20 @@ if [ ${seed} != 0 ]; then
     echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
 fi

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# -lt 2 ] || [ $# -gt 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
     exit -1
 fi

 config_path=$1
 ckpt_name=$2
+ips=$3
+
+if [ ! $ips ];then
+    ips_config=
+else
+    ips_config="--ips="${ips}
+fi
+
+echo ${ips_config}

 mkdir -p exp
@@ -37,7 +45,7 @@ python3 -u ${BIN_DIR}/train.py \
     --benchmark-batch-size ${benchmark_batch_size} \
     --benchmark-max-step ${benchmark_max_step}
 else
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --seed ${seed} \
 --config ${config_path} \
......
@@ -6,6 +6,7 @@ gpus=0,1,2,3
 stage=0
 stop_stage=50
 conf_path=conf/conformer.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 avg_num=30
 audio_file=data/demo_01_03.wav
@@ -23,7 +24,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......
 #! /usr/bin/env bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# -lt 2 ] || [ $# -gt 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
     exit -1
 fi
@@ -10,6 +10,13 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+ips=$3
+
+if [ ! $ips ];then
+    ips_config=
+else
+    ips_config="--ips="${ips}
+fi

 echo "using ${device}..."
@@ -28,7 +35,7 @@ python3 -u ${BIN_DIR}/train.py \
     --output exp/${ckpt_name} \
     --seed ${seed}
 else
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
 --output exp/${ckpt_name} \
......
@@ -6,6 +6,7 @@ gpus=0,1,2,3
 stage=0
 stop_stage=50
 conf_path=conf/conformer.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 avg_num=20
@@ -22,7 +23,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......
 data:
-  dataset: 'paddleaudio.datasets:ESC50'
+  dataset: 'paddlespeech.audio.datasets:ESC50'
   num_classes: 50
   train:
     mode: 'train'
......
@@ -2,7 +2,7 @@
 ###########################################
 #                  Data                   #
 ###########################################
-dataset: 'paddleaudio.datasets:HeySnips'
+dataset: 'paddlespeech.audio.datasets:HeySnips'
 data_dir: '/PATH/TO/DATA/hey_snips_research_6k_en_train_eval_clean_ter'

 ############################################
......
@@ -3,6 +3,7 @@
 ## Deepspeech2 Non-Streaming

 | Model | Params | release | Config | Test set | Loss | WER |
 | --- | --- | --- | --- | --- | --- | --- |
+| DeepSpeech2 | 113.96M | r1.0.1 | conf/deepspeech2.yaml + U2 Data pipeline and spec aug + fbank161 | test-clean | 10.76069622039795 | 0.046700 |
 | DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | test-clean | 14.49190807 | 0.067283 |
 | DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 |
 | DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 |
......
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# -lt 2 ] || [ $# -gt 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
     exit -1
 fi
@@ -10,6 +10,13 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+ips=$3
+
+if [ ! $ips ];then
+    ips_config=
+else
+    ips_config="--ips="${ips}
+fi

 mkdir -p exp
@@ -26,7 +33,7 @@ python3 -u ${BIN_DIR}/train.py \
     --output exp/${ckpt_name} \
     --seed ${seed}
 else
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
 --output exp/${ckpt_name} \
......
@@ -6,6 +6,7 @@ gpus=0,1,2,3
 stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 avg_num=5
 audio_file=data/demo_002_en.wav
@@ -23,7 +24,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......
@@ -42,6 +42,11 @@ echo "chunk mode ${chunk_mode}"

 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # format the reference test file
+    python3 utils/format_rsl.py \
+        --origin_ref data/manifest.test-clean.raw \
+        --trans_ref data/manifest.test-clean.text
+
     for type in attention; do
         echo "decoding ${type}"
         if [ ${chunk_mode} == true ];then
@@ -63,11 +68,16 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
             echo "Failed in evaluation!"
             exit 1
         fi
+        python3 utils/format_rsl.py \
+            --origin_hyp ${ckpt_prefix}.${type}.rsl \
+            --trans_hyp ${ckpt_prefix}.${type}.rsl.text
+        python3 utils/compute-wer.py --char=1 --v=1 \
+            data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
         echo "decoding ${type} done."
     done
-fi

     for type in ctc_greedy_search; do
         echo "decoding ${type}"
         if [ ${chunk_mode} == true ];then
             # stream decoding only support batchsize=1
@@ -88,12 +98,18 @@ for type in ctc_greedy_search; do
             echo "Failed in evaluation!"
             exit 1
         fi
+        python3 utils/format_rsl.py \
+            --origin_hyp ${ckpt_prefix}.${type}.rsl \
+            --trans_hyp ${ckpt_prefix}.${type}.rsl.text
+        python3 utils/compute-wer.py --char=1 --v=1 \
+            data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
         echo "decoding ${type} done."
     done

     for type in ctc_prefix_beam_search attention_rescoring; do
         echo "decoding ${type}"
         batch_size=1
         python3 -u ${BIN_DIR}/test.py \
@@ -109,8 +125,33 @@ for type in ctc_prefix_beam_search attention_rescoring; do
             echo "Failed in evaluation!"
             exit 1
         fi
+        python3 utils/format_rsl.py \
+            --origin_hyp ${ckpt_prefix}.${type}.rsl \
+            --trans_hyp ${ckpt_prefix}.${type}.rsl.text
+        python3 utils/compute-wer.py --char=1 --v=1 \
+            data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
         echo "decoding ${type} done."
     done
+fi
+
+if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
+    python3 utils/format_rsl.py \
+        --origin_ref data/manifest.test-clean.raw \
+        --trans_ref_sclite data/manifest.test-clean.text.sclite
+
+    output_dir=${ckpt_prefix}
+    for type in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
+        python utils/format_rsl.py \
+            --origin_hyp ${output_dir}/${type}.rsl \
+            --trans_hyp_sclite ${output_dir}/${type}.rsl.text.sclite
+        mkdir -p ${output_dir}/${type}_sclite
+        sclite -i wsj -r data/manifest.test-clean.text.sclite -h ${output_dir}/${type}.rsl.text.sclite -e utf-8 -o all -O ${output_dir}/${type}_sclite -c NOASCII
+    done
+fi

 echo "Finished"
......
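The error rate reported by `utils/compute-wer.py --char=1` is ordinary edit-distance WER/CER. A self-contained sketch of that computation, for readers who have not used the utility (illustrative, not the repo's implementation):

```python
def word_error_rate(ref: str, hyp: str, char_level: bool = True) -> float:
    """Levenshtein distance over tokens, normalized by reference length."""
    r = list(ref) if char_level else ref.split()   # --char=1 -> char level
    h = list(hyp) if char_level else hyp.split()
    # dp[i][j]: edits to turn r[:i] into h[:j]
    dp = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        dp[i][0] = i
    for j in range(len(h) + 1):
        dp[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            cost = 0 if r[i - 1] == h[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # substitution
    return dp[len(r)][len(h)] / max(len(r), 1)

print(word_error_rate("hello world", "hello word", char_level=False))  # 0.5
```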
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# -lt 2 ] || [ $# -gt 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
     exit -1
 fi
@@ -10,6 +10,13 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+ips=$3
+
+if [ ! $ips ];then
+    ips_config=
+else
+    ips_config="--ips="${ips}
+fi

 mkdir -p exp
@@ -29,7 +36,7 @@ python3 -u ${BIN_DIR}/train.py \
     --output exp/${ckpt_name} \
     --seed ${seed}
 else
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
 --output exp/${ckpt_name} \
......
@@ -8,6 +8,7 @@ gpus=0,1,2,3
 stage=0
 stop_stage=50
 conf_path=conf/transformer.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 avg_num=30
 audio_file=data/demo_002_en.wav
@@ -25,7 +26,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# -lt 2 ] || [ $# -gt 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
     exit -1
 fi
@@ -10,6 +10,13 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+ips=$3
+
+if [ ! $ips ];then
+    ips_config=
+else
+    ips_config="--ips="${ips}
+fi

 mkdir -p exp
@@ -27,7 +34,7 @@ python3 -u ${BIN_DIR}/train.py \
     --output exp/${ckpt_name} \
     --seed ${seed}
 else
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --model-name u2_kaldi \
 --config ${config_path} \
......
@@ -9,6 +9,7 @@ gpus=0,1,2,3,4,5,6,7
 stage=0
 stop_stage=50
 conf_path=conf/transformer.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/decode/decode_base.yaml
 dict_path=data/lang_char/train_960_unigram5000_units.txt
 avg_num=10
@@ -26,7 +27,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......
 #!/bin/bash

-if [ $# != 3 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ckpt_path"
+if [ $# -lt 3 ] || [ $# -gt 4 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ckpt_path ips(optional)"
     exit -1
 fi
@@ -11,6 +11,13 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
 ckpt_path=$3
+ips=$4
+
+if [ ! $ips ];then
+    ips_config=
+else
+    ips_config="--ips="${ips}
+fi

 mkdir -p exp
@@ -21,12 +28,21 @@ if [ ${seed} != 0 ]; then
     export FLAGS_cudnn_deterministic=True
 fi

+if [ ${ngpu} == 0 ]; then
 python3 -u ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
 --output exp/${ckpt_name} \
 --checkpoint_path "${ckpt_path}" \
 --seed ${seed}
+else
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
+--ngpu ${ngpu} \
+--config ${config_path} \
+--output exp/${ckpt_name} \
+--checkpoint_path "${ckpt_path}" \
+--seed ${seed}
+fi

 if [ ${seed} != 0 ]; then
     unset FLAGS_cudnn_deterministic
......
@@ -7,6 +7,7 @@ gpus=0,1,2,3
 stage=0
 stop_stage=3
 conf_path=conf/transformer_es.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 must_c_path=
 lang=es
@@ -25,7 +26,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} "${ckpt_path}"
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} "${ckpt_path}" ${ips}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......
 #!/bin/bash

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# -lt 2 ] || [ $# -gt 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
     exit -1
 fi
@@ -10,6 +10,13 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
+ips=$3
+
+if [ ! $ips ];then
+    ips_config=
+else
+    ips_config="--ips="${ips}
+fi

 mkdir -p exp
@@ -26,7 +33,7 @@ python3 -u ${BIN_DIR}/train.py \
     --output exp/${ckpt_name} \
     --seed ${seed}
 else
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
 --output exp/${ckpt_name} \
......
@@ -6,6 +6,7 @@ gpus=0,1,2,3
 stage=0
 stop_stage=50
 conf_path=conf/transformer_mtl_noam.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 avg_num=5
 data_path=./TED_EnZh  # path to unzipped data
@@ -23,7 +24,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......
 #!/bin/bash

-if [ $# != 3 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ckpt_path"
+if [ $# -lt 3 ] || [ $# -gt 4 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ckpt_path ips(optional)"
     exit -1
 fi
@@ -11,6 +11,15 @@ echo "using $ngpu gpus..."
 config_path=$1
 ckpt_name=$2
 ckpt_path=$3
+ips=$4
+
+if [ ! $ips ];then
+    ips_config=
+else
+    ips_config="--ips="${ips}
+fi

 mkdir -p exp
@@ -28,7 +37,7 @@ python3 -u ${BIN_DIR}/train.py \
     --checkpoint_path "${ckpt_path}" \
     --seed ${seed}
 else
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
 --output exp/${ckpt_name} \
......
@@ -7,6 +7,7 @@ gpus=0,1,2,3
 stage=1
 stop_stage=4
 conf_path=conf/transformer_mtl_noam.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 ckpt_path=  # paddle.98 # (finetune from FAT-ST pretrained model)
 avg_num=5
@@ -29,7 +30,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
         echo "Finetune from Pretrained Model" ${ckpt_path}
         ./local/download_pretrain.sh || exit -1
     fi
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} "${ckpt_path}"
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} "${ckpt_path}" ${ips}
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
......
@@ -15,13 +15,20 @@ if [ ${seed} != 0 ]; then
     echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
 fi

-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# -lt 2 ] || [ $# -gt 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
     exit -1
 fi

 config_path=$1
 ckpt_name=$2
+ips=$3
+
+if [ ! $ips ];then
+    ips_config=
+else
+    ips_config="--ips="${ips}
+fi

 mkdir -p exp
@@ -33,7 +40,7 @@ python3 -u ${BIN_DIR}/train.py \
     --profiler-options "${profiler_options}" \
     --seed ${seed}
 else
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --config ${config_path} \
 --output exp/${ckpt_name} \
......
@@ -2,10 +2,11 @@
 set -e
 source path.sh
-gpus=0
+gpus=4
 stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 avg_num=1
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
@@ -21,7 +22,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -17,13 +17,20 @@ if [ ${seed} != 0 ]; then
     echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
 fi
-if [ $# != 2 ];then
-    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+if [ $# -lt 2 ] || [ $# -gt 3 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
     exit -1
 fi
 config_path=$1
 ckpt_name=$2
+ips=$3
+
+if [ -z "$ips" ];then
+    ips_config=
+else
+    ips_config="--ips=${ips}"
+fi
 mkdir -p exp
@@ -37,7 +44,7 @@ python3 -u ${BIN_DIR}/train.py \
 --benchmark-batch-size ${benchmark_batch_size} \
 --benchmark-max-step ${benchmark_max_step}
 else
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${BIN_DIR}/train.py \
+python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --seed ${seed} \
 --config ${config_path} \
...
@@ -2,10 +2,11 @@
 set -e
 source path.sh
-gpus=0
+gpus=4
 stage=0
 stop_stage=50
 conf_path=conf/transformer.yaml
+ips= #xx.xx.xx.xx,xx.xx.xx.xx
 decode_conf_path=conf/tuning/decode.yaml
 avg_num=1
@@ -22,7 +23,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     # train model, all `ckpt` under `exp` dir
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
 fi
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
...
@@ -14,9 +14,9 @@
 import argparse
 import paddle
-from paddleaudio.datasets.voxceleb import VoxCeleb
 from yacs.config import CfgNode
+from paddlespeech.audio.datasets.voxceleb import VoxCeleb
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.io.augment import build_augment_pipeline
 from paddlespeech.vector.training.seeding import seed_everything
...
@@ -21,9 +21,9 @@ import os
 from typing import List
 import tqdm
-from paddleaudio import load as load_audio
 from yacs.config import CfgNode
+from paddlespeech.audio import load as load_audio
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.utils.vector_utils import get_chunks
...
@@ -22,9 +22,9 @@ import os
 import random
 import tqdm
-from paddleaudio import load as load_audio
 from yacs.config import CfgNode
+from paddlespeech.audio import load as load_audio
 from paddlespeech.s2t.utils.log import Log
 from paddlespeech.vector.utils.vector_utils import get_chunks
...
@@ -16,8 +16,8 @@ import os
 from typing import List
 from typing import Tuple
+from ..utils import DATA_HOME
 from ..utils.download import download_and_decompress
-from ..utils.env import DATA_HOME
 from .dataset import AudioClassificationDataset
 __all__ = ['ESC50']
...
@@ -17,8 +17,8 @@ import random
 from typing import List
 from typing import Tuple
+from ..utils import DATA_HOME
 from ..utils.download import download_and_decompress
-from ..utils.env import DATA_HOME
 from .dataset import AudioClassificationDataset
 __all__ = ['GTZAN']
...
@@ -17,8 +17,8 @@ import random
 from typing import List
 from typing import Tuple
+from ..utils import DATA_HOME
 from ..utils.download import download_and_decompress
-from ..utils.env import DATA_HOME
 from .dataset import AudioClassificationDataset
 __all__ = ['TESS']
...
@@ -16,8 +16,8 @@ import os
 from typing import List
 from typing import Tuple
+from ..utils import DATA_HOME
 from ..utils.download import download_and_decompress
-from ..utils.env import DATA_HOME
 from .dataset import AudioClassificationDataset
 __all__ = ['UrbanSound8K']
...
@@ -11,13 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from ...cli.utils import DATA_HOME
+from ...cli.utils import MODEL_HOME
 from .download import decompress
 from .download import download_and_decompress
 from .download import load_state_dict_from_url
-from .env import DATA_HOME
-from .env import MODEL_HOME
-from .env import PPAUDIO_HOME
-from .env import USER_HOME
 from .error import ParameterError
 from .log import Logger
 from .log import logger
...
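All of the import hunks above are the same mechanical migration: the standalone paddleaudio package now lives under paddlespeech.audio, and DATA_HOME/MODEL_HOME are re-exported from paddlespeech.audio.utils (backed by paddlespeech.cli.utils) instead of a separate env module. A minimal sketch of what updated caller code looks like (the wav path is a placeholder):

from paddlespeech.audio import load as load_audio
from paddlespeech.audio.utils import MODEL_HOME

waveform, sr = load_audio('sample.wav')  # placeholder path
print(MODEL_HOME, sr, waveform.shape)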
@@ -83,6 +83,12 @@ class ASRExecutor(BaseExecutor):
                 'attention_rescoring'
             ],
             help='only support transformer and conformer model')
+        self.parser.add_argument(
+            '--num_decoding_left_chunks',
+            '-num_left',
+            type=int,
+            default=-1,
+            help='only support transformer and conformer online model')
         self.parser.add_argument(
             '--ckpt_path',
             type=str,
@@ -122,6 +128,7 @@ class ASRExecutor(BaseExecutor):
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
                        decode_method: str='attention_rescoring',
+                       num_decoding_left_chunks: int=-1,
                        ckpt_path: Optional[os.PathLike]=None):
         """
         Init model and other resources from a specific path.
@@ -179,6 +186,9 @@ class ASRExecutor(BaseExecutor):
         elif "conformer" in model_type or "transformer" in model_type:
             self.config.decode.decoding_method = decode_method
+            if num_decoding_left_chunks is not None:
+                assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0"
+                self.config.num_decoding_left_chunks = num_decoding_left_chunks
         else:
             raise Exception("wrong type")
@@ -451,6 +461,7 @@ class ASRExecutor(BaseExecutor):
                  config: os.PathLike=None,
                  ckpt_path: os.PathLike=None,
                  decode_method: str='attention_rescoring',
+                 num_decoding_left_chunks: int=-1,
                  force_yes: bool=False,
                  rtf: bool=False,
                  device=paddle.get_device()):
@@ -460,7 +471,7 @@ class ASRExecutor(BaseExecutor):
         audio_file = os.path.abspath(audio_file)
         paddle.set_device(device)
         self._init_from_path(model, lang, sample_rate, config, decode_method,
-                             ckpt_path)
+                             num_decoding_left_chunks, ckpt_path)
         if not self._check(audio_file, sample_rate, force_yes):
             sys.exit(-1)
         if rtf:
...
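Because num_decoding_left_chunks is now threaded from __call__ into _init_from_path, callers can set it per inference; the same option is exposed on the command line as --num_decoding_left_chunks / -num_left. A minimal sketch (the audio file name is a placeholder; -1 keeps the model default, while values >= 0 cap the left context of streaming transformer/conformer models):

from paddlespeech.cli.asr.infer import ASRExecutor

asr = ASRExecutor()
text = asr(
    audio_file='zh.wav',                    # placeholder input
    model='conformer_online_wenetspeech',
    num_decoding_left_chunks=-1)
print(text)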
@@ -21,12 +21,12 @@ from typing import Union
 import numpy as np
 import paddle
 import yaml
-from paddleaudio import load
-from paddleaudio.features import LogMelSpectrogram
 from ..executor import BaseExecutor
 from ..log import logger
 from ..utils import stats_wrapper
+from paddlespeech.audio import load
+from paddlespeech.audio.features import LogMelSpectrogram
 __all__ = ['CLSExecutor']
...
@@ -24,8 +24,8 @@ from typing import Any
 from typing import Dict
 import paddle
-import paddleaudio
 import requests
+import soundfile as sf
 import yaml
 from paddle.framework import load
@@ -190,6 +190,7 @@ def _get_sub_home(directory):
 PPSPEECH_HOME = _get_paddlespcceh_home()
 MODEL_HOME = _get_sub_home('models')
 CONF_HOME = _get_sub_home('conf')
+DATA_HOME = _get_sub_home('datasets')
 def _md5(text: str):
@@ -281,7 +282,8 @@ def _note_one_stat(cls_name, params={}):
     if 'audio_file' in params:
         try:
-            _, sr = paddleaudio.load(params['audio_file'])
+            # avoid a recursive import caused by: utils.DATA_HOME
+            _, sr = sf.read(params['audio_file'])
         except Exception:
             sr = -1
...
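Replacing paddleaudio.load with soundfile here breaks the circular dependency while keeping the only piece of information the stats hook needs, the sample rate. For reference, soundfile.read returns exactly the two values used above (the path is a placeholder):

import soundfile as sf

data, sr = sf.read('sample.wav')  # returns (samples as ndarray, sample rate)
print(sr, data.shape)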
@@ -22,13 +22,13 @@ from typing import Union
 import paddle
 import soundfile
-from paddleaudio.backends import load as load_audio
-from paddleaudio.compliance.librosa import melspectrogram
 from yacs.config import CfgNode
 from ..executor import BaseExecutor
 from ..log import logger
 from ..utils import stats_wrapper
+from paddlespeech.audio.backends import load as load_audio
+from paddlespeech.audio.compliance.librosa import melspectrogram
 from paddlespeech.vector.io.batch import feature_normalize
 from paddlespeech.vector.modules.sid_model import SpeakerIdetification
...
@@ -16,11 +16,12 @@ import os
 import numpy as np
 from paddle import inference
-from paddleaudio.backends import load as load_audio
-from paddleaudio.datasets import ESC50
-from paddleaudio.features import melspectrogram
 from scipy.special import softmax
+from paddlespeech.audio.backends import load as load_audio
+from paddlespeech.audio.datasets import ESC50
+from paddlespeech.audio.features import melspectrogram
 # yapf: disable
 parser = argparse.ArgumentParser()
 parser.add_argument("--model_dir", type=str, required=True, default="./export", help="The directory to static model.")
...
@@ -15,8 +15,8 @@ import argparse
 import os
 import paddle
-from paddleaudio.datasets import ESC50
+from paddlespeech.audio.datasets import ESC50
 from paddlespeech.cls.models import cnn14
 from paddlespeech.cls.models import SoundClassifier
...
@@ -17,10 +17,10 @@ import os
 import paddle
 import paddle.nn.functional as F
 import yaml
-from paddleaudio.backends import load as load_audio
-from paddleaudio.features import LogMelSpectrogram
-from paddleaudio.utils import logger
+from paddlespeech.audio.backends import load as load_audio
+from paddlespeech.audio.features import LogMelSpectrogram
+from paddlespeech.audio.utils import logger
 from paddlespeech.cls.models import SoundClassifier
 from paddlespeech.utils.dynamic_import import dynamic_import
...
@@ -16,10 +16,10 @@ import os
 import paddle
 import yaml
-from paddleaudio.features import LogMelSpectrogram
-from paddleaudio.utils import logger
-from paddleaudio.utils import Timer
+from paddlespeech.audio.features import LogMelSpectrogram
+from paddlespeech.audio.utils import logger
+from paddlespeech.audio.utils import Timer
 from paddlespeech.cls.models import SoundClassifier
 from paddlespeech.utils.dynamic_import import dynamic_import
...
@@ -15,8 +15,9 @@ import os
 import paddle.nn as nn
 import paddle.nn.functional as F
-from paddleaudio.utils.download import load_state_dict_from_url
-from paddleaudio.utils.env import MODEL_HOME
+from paddlespeech.audio.utils import MODEL_HOME
+from paddlespeech.audio.utils.download import load_state_dict_from_url
 __all__ = ['CNN14', 'CNN10', 'CNN6', 'cnn14', 'cnn10', 'cnn6']
...
@@ -14,10 +14,10 @@
 import os
 import paddle
-from paddleaudio.utils import logger
-from paddleaudio.utils import Timer
 from yacs.config import CfgNode
+from paddlespeech.audio.utils import logger
+from paddlespeech.audio.utils import Timer
 from paddlespeech.kws.exps.mdtc.collate import collate_features
 from paddlespeech.kws.models.loss import max_pooling_loss
 from paddlespeech.kws.models.mdtc import KWSModel
...
@@ -15,6 +15,7 @@
 __all__ = [
     'asr_dynamic_pretrained_models',
     'asr_static_pretrained_models',
+    'asr_onnx_pretrained_models',
     'cls_dynamic_pretrained_models',
     'cls_static_pretrained_models',
     'st_dynamic_pretrained_models',
@@ -134,15 +135,21 @@ asr_dynamic_pretrained_models = {
         },
     },
     "deepspeech2online_wenetspeech-zh-16k": {
-        '1.0': {
+        '1.0.3': {
             'url':
-            'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.1.model.tar.gz',
+            'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.3.model.tar.gz',
             'md5':
-            'd1be86a3e786042ab64f05161b5fae62',
+            'cfe273793e68f790f742b411c98bc75e',
             'cfg_path':
             'model.yaml',
             'ckpt_path':
             'exp/deepspeech2_online/checkpoints/avg_10',
+            'model':
+            'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel',
+            'params':
+            'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams',
+            'onnx_model':
+            'onnx/model.onnx',
             'lm_url':
             'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
             'lm_md5':
@@ -166,11 +173,11 @@ asr_dynamic_pretrained_models = {
         },
     },
     "deepspeech2online_aishell-zh-16k": {
-        '1.0': {
+        '1.0.2': {
             'url':
-            'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz',
+            'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz',
             'md5':
-            'df5ddeac8b679a470176649ac4b78726',
+            '4dd42cfce9aaa54db0ec698da6c48ec5',
             'cfg_path':
             'model.yaml',
             'ckpt_path':
@@ -179,6 +186,8 @@ asr_dynamic_pretrained_models = {
             'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
             'params':
             'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
+            'onnx_model':
+            'onnx/model.onnx',
             'lm_url':
             'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
             'lm_md5':
@@ -224,6 +233,115 @@ asr_static_pretrained_models = {
             '29e02312deb2e59b3c8686c7966d4fe3'
         }
     },
+    "deepspeech2online_aishell-zh-16k": {
+        '1.0.1': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz',
+            'md5':
+            'df5ddeac8b679a470176649ac4b78726',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/deepspeech2_online/checkpoints/avg_1',
+            'model':
+            'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
+            'params':
+            'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
+            'lm_url':
+            'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+            'lm_md5':
+            '29e02312deb2e59b3c8686c7966d4fe3'
+        },
+        '1.0.2': {
+            'url':
+            'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz',
+            'md5':
+            '4dd42cfce9aaa54db0ec698da6c48ec5',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/deepspeech2_online/checkpoints/avg_1',
+            'model':
+            'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
+            'params':
+            'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
+            'onnx_model':
+            'onnx/model.onnx',
+            'lm_url':
+            'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+            'lm_md5':
+            '29e02312deb2e59b3c8686c7966d4fe3'
+        },
+    },
+    "deepspeech2online_wenetspeech-zh-16k": {
+        '1.0.3': {
+            'url':
+            'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.3.model.tar.gz',
+            'md5':
+            'cfe273793e68f790f742b411c98bc75e',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/deepspeech2_online/checkpoints/avg_10',
+            'model':
+            'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel',
+            'params':
+            'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams',
+            'onnx_model':
+            'onnx/model.onnx',
+            'lm_url':
+            'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+            'lm_md5':
+            '29e02312deb2e59b3c8686c7966d4fe3'
+        },
+    },
+}
+
+asr_onnx_pretrained_models = {
+    "deepspeech2online_aishell-zh-16k": {
+        '1.0.2': {
+            'url':
+            'http://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.2.model.tar.gz',
+            'md5':
+            '4dd42cfce9aaa54db0ec698da6c48ec5',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/deepspeech2_online/checkpoints/avg_1',
+            'model':
+            'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
+            'params':
+            'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
+            'onnx_model':
+            'onnx/model.onnx',
+            'lm_url':
+            'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+            'lm_md5':
+            '29e02312deb2e59b3c8686c7966d4fe3'
+        },
+    },
+    "deepspeech2online_wenetspeech-zh-16k": {
+        '1.0.3': {
+            'url':
+            'http://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.3.model.tar.gz',
+            'md5':
+            'cfe273793e68f790f742b411c98bc75e',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/deepspeech2_online/checkpoints/avg_10',
+            'model':
+            'exp/deepspeech2_online/checkpoints/avg_10.jit.pdmodel',
+            'params':
+            'exp/deepspeech2_online/checkpoints/avg_10.jit.pdiparams',
+            'onnx_model':
+            'onnx/model.onnx',
+            'lm_url':
+            'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+            'lm_md5':
+            '29e02312deb2e59b3c8686c7966d4fe3'
+        },
+    },
 }
 # ---------------------------------
...
@@ -164,9 +164,10 @@ class CommonTaskResource:
         try:
             import_models = '{}_{}_pretrained_models'.format(self.task,
                                                              self.model_format)
+            print(f"from .pretrained_models import {import_models}")
             exec('from .pretrained_models import {}'.format(import_models))
             models = OrderedDict(locals()[import_models])
-        except ImportError:
+        except Exception as e:
             models = OrderedDict({})  # no models.
         finally:
             return models
...
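The exec-based lookup above resolves names such as asr_onnx_pretrained_models from the task ('asr') and the model format ('onnx'). An equivalent, easier-to-debug sketch using importlib instead of exec (the module path is assumed from the relative import in the snippet above):

import importlib
from collections import OrderedDict

def load_pretrained_models(task: str, model_format: str) -> OrderedDict:
    name = f'{task}_{model_format}_pretrained_models'  # e.g. 'asr_onnx_pretrained_models'
    try:
        mod = importlib.import_module('paddlespeech.resource.pretrained_models')
        return OrderedDict(getattr(mod, name))
    except (ImportError, AttributeError):
        return OrderedDict()  # no models registered for this combination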
@@ -14,10 +14,11 @@
 """Contains the audio featurizer class."""
 import numpy as np
 import paddle
-import paddleaudio.compliance.kaldi as kaldi
 from python_speech_features import delta
 from python_speech_features import mfcc
+import paddlespeech.audio.compliance.kaldi as kaldi
 class AudioFeaturizer():
     """Audio featurizer, for extracting features from audio contents of
...
@@ -15,9 +15,10 @@
 import librosa
 import numpy as np
 import paddle
-import paddleaudio.compliance.kaldi as kaldi
 from python_speech_features import logfbank
+import paddlespeech.audio.compliance.kaldi as kaldi
 def stft(x,
          n_fft,
...
@@ -28,6 +28,7 @@ asr_online:
     sample_rate: 16000
     cfg_path:
     decode_method:
+    num_decoding_left_chunks: -1
     force_yes: True
     device: # cpu or gpu:id
     continuous_decoding: True  # enable continue decoding when endpoint detected
...
# This is the parameter configuration file for PaddleSpeech Serving.

#################################################################################
#                             SERVER SETTING                                    #
#################################################################################
host: 0.0.0.0
port: 8090

# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_online']
# protocol = ['websocket'] (only one can be selected).
# websocket only supports the online engine type.
protocol: 'websocket'
engine_list: ['asr_online']

#################################################################################
#                                ENGINE CONFIG                                  #
#################################################################################

################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
    model_type: 'conformer_online_wenetspeech'
    am_model:  # the pdmodel file of am static model [optional]
    am_params:  # the pdiparams file of am static model [optional]
    lang: 'zh'
    sample_rate: 16000
    cfg_path:
    decode_method: "attention_rescoring"
    force_yes: True
    device: 'cpu'  # cpu or gpu:id
    continuous_decoding: True  # enable continued decoding when an endpoint is detected
    num_decoding_left_chunks: 16

    am_predictor_conf:
        device:  # set 'gpu:id' or 'cpu'
        switch_ir_optim: True
        glog_info: False  # True -> print glog
        summary: True  # False -> do not show predictor config

    chunk_buffer_conf:
        window_n: 7  # frame
        shift_n: 4  # frame
        window_ms: 25  # ms
        shift_ms: 10  # ms
        sample_rate: 16000
        sample_width: 2
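The chunk_buffer_conf numbers are related: with 25 ms frames hopped every 10 ms, a buffer of window_n=7 frames spans 25 + 6 * 10 = 85 ms and advances by shift_n * 10 = 40 ms per step, which appears to be where the frame_duration_ms / shift_ms values of the online-onnx engine below come from. A quick check of that arithmetic (my reading of the fields, not something the config file states):

window_ms, shift_ms = 25, 10  # per-frame window and hop
window_n, shift_n = 7, 4      # frames per decoding window / frames per step
print(window_ms + (window_n - 1) * shift_ms)  # 85 ms spanned by one window
print(shift_n * shift_ms)                     # 40 ms advanced per step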
# This is the parameter configuration file for PaddleSpeech Serving.

#################################################################################
#                             SERVER SETTING                                    #
#################################################################################
host: 0.0.0.0
port: 8090

# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_online-inference', 'asr_online-onnx']
# protocol = ['websocket'] (only one can be selected).
# websocket only supports the online engine type.
protocol: 'websocket'
engine_list: ['asr_online-onnx']

#################################################################################
#                                ENGINE CONFIG                                  #
#################################################################################

################################### ASR #########################################
################ speech task: asr; engine_type: online-inference ################
asr_online-inference:
    model_type: 'deepspeech2online_wenetspeech'
    am_model:  # the pdmodel file of am static model [optional]
    am_params:  # the pdiparams file of am static model [optional]
    lang: 'zh'
    sample_rate: 16000
    cfg_path:
    decode_method:
    num_decoding_left_chunks:
    force_yes: True
    device: 'cpu'  # cpu or gpu:id

    am_predictor_conf:
        device:  # set 'gpu:id' or 'cpu'
        switch_ir_optim: True
        glog_info: False  # True -> print glog
        summary: True  # False -> do not show predictor config

    chunk_buffer_conf:
        frame_duration_ms: 80
        sample_rate: 16000
        sample_width: 2
        window_n: 7  # frame
        shift_n: 4  # frame
        window_ms: 25  # ms
        shift_ms: 10  # ms

################################### ASR #########################################
################## speech task: asr; engine_type: online-onnx ###################
asr_online-onnx:
    model_type: 'deepspeech2online_wenetspeech'
    am_model:  # the onnx file of the am model [optional]
    am_params:  # not used by the onnx engine
    lang: 'zh'
    sample_rate: 16000
    cfg_path:
    decode_method:
    num_decoding_left_chunks:
    force_yes: True
    device: 'cpu'  # cpu or gpu:id

    # https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
    am_predictor_conf:
        device: 'cpu'  # set 'gpu:id' or 'cpu'
        graph_optimization_level: 0
        intra_op_num_threads: 0  # sets the number of threads used to parallelize the execution within nodes
        inter_op_num_threads: 0  # sets the number of threads used to parallelize the execution of the graph (across nodes)
        log_severity_level: 2  # log severity level; applies to session load, initialization, etc. 0: Verbose, 1: Info, 2: Warning, 3: Error, 4: Fatal. Default is 2.
        log_verbosity_level: 0  # VLOG level if DEBUG build and session_log_severity_level is 0; applies to session load, initialization, etc. Default is 0.

    chunk_buffer_conf:
        frame_duration_ms: 85
        sample_rate: 16000
        sample_width: 2
        window_n: 7  # frame
        shift_n: 4  # frame
        window_ms: 25  # ms
        shift_ms: 10  # ms
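Both engines are driven by this one file; only the entry in engine_list decides which one runs. Assuming the file is saved as conf/application.yaml, the server is started with the standard serving entry point, `paddlespeech_server start --config_file ./conf/application.yaml`, after which a websocket client can stream 16 kHz, 16-bit pcm chunks to the configured host and port.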
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import ByteString
from typing import Optional
import numpy as np
import paddle
from numpy import float32
from yacs.config import CfgNode
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils import onnx_infer
__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
# ASR server connection process class
class PaddleASRConnectionHanddler:
def __init__(self, asr_engine):
"""Init a Paddle ASR Connection Handler instance
Args:
asr_engine (ASREngine): the global asr engine
"""
super().__init__()
logger.info(
"create a paddle asr connection handler to process the websocket connection"
)
self.config = asr_engine.config # server config
self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine
# model_type, sample_rate and text_feature are shared by deepspeech2 and conformer
self.model_type = self.asr_engine.executor.model_type
self.sample_rate = self.asr_engine.executor.sample_rate
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
# extract feat; currently only fbank is used by the conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, (
self.sample_rate, self.preprocess_conf.process[0]['fs'])
self.frame_shift_in_ms = int(
self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
self.continuous_decoding = self.config.get("continuous_decoding", False)
self.init_decoder()
self.reset()
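# A worked example of frame_shift_in_ms above (values are assumptions for a
# typical 16 kHz fbank front end, not read from a shipped config): with
# fs=16000, win_length=400 samples and n_shift=160 samples, we get
# int(160 / 16000 * 1000) = 10 ms shift for a 25 ms window.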
def init_decoder(self):
if "deepspeech2" in self.model_type:
assert self.continuous_decoding is False, "ds2 model does not support endpoint detection"
self.am_predictor = self.asr_engine.executor.am_predictor
self.decoder = CTCDecoder(
odim=self.model_config.output_dim, # <blank> is in vocab
enc_n_units=self.model_config.rnn_layer_size * 2,
blank_id=self.model_config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.model_config.get('ctc_grad_norm_type',
None))
cfg = self.model_config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
else:
raise ValueError(f"Not supported: {self.model_type}")
def model_reset(self):
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
def output_reset(self):
## outputs
# partial/ending decoding results
self.result_transcripts = ['']
def reset_continuous_decoding(self):
"""
when in continuous decoding, reset for the next utterance.
"""
self.global_frame_offset = self.num_frames
self.model_reset()
def reset(self):
if "deepspeech2" in self.model_type:
# for deepspeech2
# init state
self.chunk_state_h_box = np.zeros(
(self.model_config.num_rnn_layers, 1,
self.model_config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.model_config.num_rnn_layers, 1,
self.model_config.rnn_layer_size),
dtype=float32)
self.decoder.reset_decoder(batch_size=1)
else:
raise NotImplementedError(f"{self.model_type} is not supported.")
self.device = None
## common
# global sample and frame step
self.num_samples = 0
self.global_frame_offset = 0
# frame step of cur utterance
self.num_frames = 0
## endpoint
self.endpoint_state = False # True when an endpoint is detected
## conformer
self.model_reset()
## outputs
self.output_reset()
def extract_feat(self, samples: ByteString):
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
self.num_samples += samples.shape[0]
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
)
if len(self.remained_wav) < self.win_length:
# samples not enough for feature window
return 0
# fbank
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)
# feature cache
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
# cur frame step
num_frames = x_chunk.shape[1]
# global frame step
self.num_frames += num_frames
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"processed the audio features successfully; the cached feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After feature extraction, the remaining cached audio samples: {self.remained_wav.shape}"
)
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}")
def decode(self, is_finished=False):
"""advance decoding
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Returns:
None:
"""
if "deepspeech2" in self.model_type:
decoding_chunk_size = 1 # decoding chunk size = 1, in decoding frame unit
context = 7 # context=7, in audio frame unit
subsampling = 4 # subsampling=4, in audio frame unit
cached_feature_num = context - subsampling
# decoding window for model, in audio frame unit
decoding_window = (decoding_chunk_size - 1) * subsampling + context
# decoding stride for model, in audio frame unit
stride = subsampling * decoding_chunk_size
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
num_frames = self.cached_feat.shape[1]
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
)
# the cached feat must be larger than the decoding_window
if num_frames < decoding_window and not is_finished:
logger.info(
f"frame feat num is less than {decoding_window}, please input more pcm data"
)
return None, None
# if is_finished=True, we need at least context frames
if num_frames < context:
logger.info(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return None, None
logger.info("start to do model forward")
# num_frames - context + 1 ensures that the current frame can get a full context window
if is_finished:
# if we get the final chunk, we need to process the last context
left_frames = context
else:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
# extract the audio
x_chunk = self.cached_feat[:, cur:end, :].numpy()
x_chunk_lens = np.array([x_chunk.shape[1]])
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
self.result_transcripts = [trans_best]
# update feat cache
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
# return trans_best[0]
else:
raise Exception(f"{self.model_type} is not supported by the onnx engine.")
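# Worked example of the decoding-window arithmetic in decode(), using the
# constants hard-coded above: decoding_chunk_size=1, subsampling=4 and
# context=7 give decoding_window = (1 - 1) * 4 + 7 = 7 frames,
# stride = 4 * 1 = 4 frames, and cached_feature_num = 7 - 4 = 3 frames
# carried over so the next chunk still sees a full context window.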
@paddle.no_grad()
def decode_one_chunk(self, x_chunk, x_chunk_lens):
"""forward one chunk frames
Args:
x_chunk (np.ndarray): (B,T,D), audio frames.
x_chunk_lens ([type]): (B,), audio frame lens
Returns:
str: the best transcription result of this chunk.
"""
logger.info("start to decoce one chunk for deepspeech2")
# state_c, state_h, audio_lens, audio
# 'chunk_state_c_box', 'chunk_state_h_box', 'audio_chunk_lens', 'audio_chunk'
input_names = [n.name for n in self.am_predictor.get_inputs()]
logger.info(f"ort inputs: {input_names}")
# 'softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'
# audio, audio_lens, state_h, state_c
output_names = [n.name for n in self.am_predictor.get_outputs()]
logger.info(f"ort outpus: {output_names}")
assert (len(input_names) == len(output_names))
assert isinstance(input_names[0], str)
input_datas = [
self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens,
x_chunk
]
feeds = dict(zip(input_names, input_datas))
outputs = self.am_predictor.run([*output_names], {**feeds})
output_chunk_probs, output_chunk_lens, self.chunk_state_h_box, self.chunk_state_c_box = outputs
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
return trans_best[0]
def get_result(self):
"""return partial/ending asr result.
Returns:
str: one best result of partial/ending.
"""
if len(self.result_transcripts) > 0:
return self.result_transcripts[0]
else:
return ''
def get_word_time_stamp(self):
return []
@paddle.no_grad()
def rescoring(self):
...
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
self.task_resource = CommonTaskResource(
task='asr', model_format='onnx', inference_mode='online')
def update_config(self) -> None:
if "deepspeech2" in self.model_type:
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
else:
raise NotImplementedError(
f"{self.model_type} not support paddleinference.")
def init_model(self) -> None:
if "deepspeech2" in self.model_type:
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor = onnx_infer.get_sess(
model_path=self.am_model, sess_conf=self.am_predictor_conf)
else:
raise NotImplementedError(
f"{self.model_type} not support paddleinference.")
def _init_from_path(self,
model_type: str=None,
am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None,
lang: str='zh',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
am_predictor_conf: dict=None):
"""
Init model and other resources from a specific path.
"""
if not model_type or not lang or not sample_rate:
logger.error(
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
)
return False
assert am_params is None, "am_params not used in onnx engine"
self.model_type = model_type
self.sample_rate = sample_rate
self.decode_method = decode_method
self.num_decoding_left_chunks = num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag)
if cfg_path is None:
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join(
self.res_path, self.task_resource.res_dict['cfg_path'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[
'onnx_model']) if am_model is None else os.path.abspath(am_model)
# self.am_params = os.path.join(
# self.res_path, self.task_resource.res_dict[
# 'params']) if am_params is None else os.path.abspath(am_params)
logger.info("Load the pretrained model:")
logger.info(f" tag = {tag}")
logger.info(f" res_path: {self.res_path}")
logger.info(f" cfg path: {self.cfg_path}")
logger.info(f" am_model path: {self.am_model}")
# logger.info(f" am_params path: {self.am_params}")
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix)
logger.info(f"spm model path: {self.config.spm_model_prefix}")
self.vocab = self.config.vocab_filepath
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.update_config()
# AM predictor
self.init_model()
logger.info(f"create the {model_type} model success")
return True
class ASREngine(BaseEngine):
"""ASR model resource
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self):
super(ASREngine, self).__init__()
def init_model(self) -> bool:
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
am_predictor_conf=self.config.am_predictor_conf):
return False
return True
def init(self, config: dict) -> bool:
"""init engine resource
Args:
config_file (str): config file
Returns:
bool: init failed or success
"""
self.config = config
self.executor = ASRServerExecutor()
try:
self.device = self.config.get("device", paddle.get_device())
paddle.set_device(self.device)
except BaseException as e:
logger.error(
f"Set device failed; please check whether device '{self.device}' is already in use and whether the 'device' parameter in the yaml file is correct"
)
logger.error(
"If all GPUs or XPUs are in use, you can set the server device to 'cpu'")
sys.exit(-1)
logger.info(f"paddlespeech_server set the device: {self.device}")
if not self.init_model():
logger.error(
"Init the ASR server occurs error, please check the server configuration yaml"
)
return False
logger.info("Initialize ASR server engine successfully.")
return True
def new_handler(self):
"""New handler from model.
Returns:
PaddleASRConnectionHanddler: asr handler instance
"""
return PaddleASRConnectionHanddler(self)
def preprocess(self, *args, **kwargs):
raise NotImplementedError("Online not using this.")
def run(self, *args, **kwargs):
raise NotImplementedError("Online not using this.")
def postprocess(self):
raise NotImplementedError("Online not using this.")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import ByteString
from typing import Optional
import numpy as np
import paddle
from numpy import float32
from yacs.config import CfgNode
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.paddle_predictor import init_predictor
__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
# ASR server connection process class
class PaddleASRConnectionHanddler:
def __init__(self, asr_engine):
"""Init a Paddle ASR Connection Handler instance
Args:
asr_engine (ASREngine): the global asr engine
"""
super().__init__()
logger.info(
"create a paddle asr connection handler to process the websocket connection"
)
self.config = asr_engine.config # server config
self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine
# model_type, sample_rate and text_feature are shared by deepspeech2 and conformer
self.model_type = self.asr_engine.executor.model_type
self.sample_rate = self.asr_engine.executor.sample_rate
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
# extract feat; currently only fbank is used by the conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, (
self.sample_rate, self.preprocess_conf.process[0]['fs'])
self.frame_shift_in_ms = int(
self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
self.continuous_decoding = self.config.get("continuous_decoding", False)
self.init_decoder()
self.reset()
def init_decoder(self):
if "deepspeech2" in self.model_type:
assert self.continuous_decoding is False, "ds2 model does not support endpoint detection"
self.am_predictor = self.asr_engine.executor.am_predictor
self.decoder = CTCDecoder(
odim=self.model_config.output_dim, # <blank> is in vocab
enc_n_units=self.model_config.rnn_layer_size * 2,
blank_id=self.model_config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.model_config.get('ctc_grad_norm_type',
None))
cfg = self.model_config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
else:
raise ValueError(f"Not supported: {self.model_type}")
def model_reset(self):
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
def output_reset(self):
## outputs
# partial/ending decoding results
self.result_transcripts = ['']
def reset_continuous_decoding(self):
"""
when in continuous decoding, reset for the next utterance.
"""
self.global_frame_offset = self.num_frames
self.model_reset()
def reset(self):
if "deepspeech2" in self.model_type:
# for deepspeech2
# init state
self.chunk_state_h_box = np.zeros(
(self.model_config.num_rnn_layers, 1,
self.model_config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.model_config.num_rnn_layers, 1,
self.model_config.rnn_layer_size),
dtype=float32)
self.decoder.reset_decoder(batch_size=1)
else:
raise NotImplementedError(f"{self.model_type} is not supported.")
self.device = None
## common
# global sample and frame step
self.num_samples = 0
self.global_frame_offset = 0
# frame step of cur utterance
self.num_frames = 0
## endpoint
self.endpoint_state = False # True when an endpoint is detected
## conformer
self.model_reset()
## outputs
self.output_reset()
def extract_feat(self, samples: ByteString):
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
self.num_samples += samples.shape[0]
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
)
if len(self.remained_wav) < self.win_length:
# samples not enough for feature window
return 0
# fbank
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)
# feature cache
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
# cur frame step
num_frames = x_chunk.shape[1]
# global frame step
self.num_frames += num_frames
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"processed the audio features successfully; the cached feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After feature extraction, the remaining cached audio samples: {self.remained_wav.shape}"
)
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}")
def decode(self, is_finished=False):
"""advance decoding
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Returns:
None:
"""
if "deepspeech2" in self.model_type:
decoding_chunk_size = 1 # decoding chunk size = 1, in decoding frame unit
context = 7 # context=7, in audio frame unit
subsampling = 4 # subsampling=4, in audio frame unit
cached_feature_num = context - subsampling
# decoding window for model, in audio frame unit
decoding_window = (decoding_chunk_size - 1) * subsampling + context
# decoding stride for model, in audio frame unit
stride = subsampling * decoding_chunk_size
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
num_frames = self.cached_feat.shape[1]
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
)
# the cached feat must be larger than the decoding_window
if num_frames < decoding_window and not is_finished:
logger.info(
f"frame feat num is less than {decoding_window}, please input more pcm data"
)
return None, None
# if is_finished=True, we need at least context frames
if num_frames < context:
logger.info(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return None, None
logger.info("start to do model forward")
# num_frames - context + 1 ensures that the current frame can get a full context window
if is_finished:
# if we get the final chunk, we need to process the last context
left_frames = context
else:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
# extract the audio
x_chunk = self.cached_feat[:, cur:end, :].numpy()
x_chunk_lens = np.array([x_chunk.shape[1]])
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
self.result_transcripts = [trans_best]
# update feat cache
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
# return trans_best[0]
else:
raise Exception(f"{self.model_type} is not supported by paddle inference.")
@paddle.no_grad()
def decode_one_chunk(self, x_chunk, x_chunk_lens):
"""forward one chunk frames
Args:
x_chunk (np.ndarray): (B,T,D), audio frames.
x_chunk_lens ([type]): (B,), audio frame lens
Returns:
str: the best transcription result of this chunk.
"""
logger.info("start to decoce one chunk for deepspeech2")
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
audio_len_handle.reshape(x_chunk_lens.shape)
audio_len_handle.copy_from_cpu(x_chunk_lens)
h_box_handle.reshape(self.chunk_state_h_box.shape)
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
c_box_handle.reshape(self.chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
output_names[3])
self.am_predictor.run()
output_chunk_probs = output_handle.copy_to_cpu()
output_chunk_lens = output_lens_handle.copy_to_cpu()
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
return trans_best[0]
def get_result(self):
"""return partial/ending asr result.
Returns:
str: one best result of partial/ending.
"""
if len(self.result_transcripts) > 0:
return self.result_transcripts[0]
else:
return ''
def get_word_time_stamp(self):
return []
@paddle.no_grad()
def rescoring(self):
...
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
self.task_resource = CommonTaskResource(
task='asr', model_format='static', inference_mode='online')
def update_config(self) -> None:
if "deepspeech2" in self.model_type:
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
else:
raise NotImplementedError(
f"{self.model_type} not support paddleinference.")
def init_model(self) -> None:
if "deepspeech2" in self.model_type:
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
else:
raise NotImplementedError(
f"{self.model_type} not support paddleinference.")
def _init_from_path(self,
model_type: str=None,
am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None,
lang: str='zh',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
am_predictor_conf: dict=None):
"""
Init model and other resources from a specific path.
"""
if not model_type or not lang or not sample_rate:
logger.error(
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
)
return False
self.model_type = model_type
self.sample_rate = sample_rate
self.decode_method = decode_method
self.num_decoding_left_chunks = num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag)
if cfg_path is None or am_model is None or am_params is None:
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join(
self.res_path, self.task_resource.res_dict['cfg_path'])
self.am_model = os.path.join(self.res_path,
self.task_resource.res_dict['model'])
self.am_params = os.path.join(self.res_path,
self.task_resource.res_dict['params'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model)
self.am_params = os.path.abspath(am_params)
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info("Load the pretrained model:")
logger.info(f" tag = {tag}")
logger.info(f" res_path: {self.res_path}")
logger.info(f" cfg path: {self.cfg_path}")
logger.info(f" am_model path: {self.am_model}")
logger.info(f" am_params path: {self.am_params}")
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix)
logger.info(f"spm model path: {self.config.spm_model_prefix}")
self.vocab = self.config.vocab_filepath
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.update_config()
# AM predictor
self.init_model()
logger.info(f"create the {model_type} model success")
return True
class ASREngine(BaseEngine):
"""ASR model resource
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self):
super(ASREngine, self).__init__()
def init_model(self) -> bool:
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
am_predictor_conf=self.config.am_predictor_conf):
return False
return True
def init(self, config: dict) -> bool:
"""init engine resource
Args:
config_file (str): config file
Returns:
bool: init failed or success
"""
self.config = config
self.executor = ASRServerExecutor()
try:
self.device = self.config.get("device", paddle.get_device())
paddle.set_device(self.device)
except BaseException as e:
logger.error(
f"Setting the device failed, please check whether the device '{self.device}' is already in use and whether the 'device' parameter in the yaml file is valid"
)
logger.error(
"If all GPUs or XPUs are in use, you can set the server device to 'cpu'")
sys.exit(-1)
logger.info(f"paddlespeech_server set the device: {self.device}")
if not self.init_model():
logger.error(
"Failed to init the ASR server, please check the server configuration yaml"
)
return False
logger.info("Initialize ASR server engine successfully.")
return True
def new_handler(self):
"""New handler from model.
Returns:
PaddleASRConnectionHanddler: asr handler instance
"""
return PaddleASRConnectionHanddler(self)
def preprocess(self, *args, **kwargs):
raise NotImplementedError("Online not using this.")
def run(self, *args, **kwargs):
raise NotImplementedError("Online not using this.")
def postprocess(self):
raise NotImplementedError("Online not using this.")
...@@ -121,13 +121,13 @@ class PaddleASRConnectionHanddler:
raise ValueError(f"Not supported: {self.model_type}")
def model_reset(self):
if "deepspeech2" in self.model_type:
return
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
if "deepspeech2" in self.model_type:
return
## conformer
# cache for conformer online
self.subsampling_cache = None
...@@ -161,7 +161,9 @@ class PaddleASRConnectionHanddler:
self.model_reset()
self.searcher.reset()
self.endpointer.reset()
self.output_reset()
# resetting hyps would truncate history transcripts.
# self.output_reset()
def reset(self):
if "deepspeech2" in self.model_type:
...@@ -695,6 +697,66 @@ class ASRServerExecutor(ASRExecutor):
self.task_resource = CommonTaskResource(
task='asr', model_format='dynamic', inference_mode='online')
def update_config(self) -> None:
if "deepspeech2" in self.model_type:
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in self.model_type or "transformer" in self.model_type:
with UpdateConfig(self.config):
logger.info("start to create the stream conformer asr engine")
# update the decoding method
if self.decode_method:
self.config.decode.decoding_method = self.decode_method
# update num_decoding_left_chunks
if self.num_decoding_left_chunks:
assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0"
self.config.decode.num_decoding_left_chunks = self.num_decoding_left_chunks
# we only support the ctc_prefix_beam_search and attention_rescoring decoding methods
# Generally we set the decoding_method to attention_rescoring
if self.config.decode.decoding_method not in [
"ctc_prefix_beam_search", "attention_rescoring"
]:
logger.info(
"we set the decoding_method to attention_rescoring")
self.config.decode.decoding_method = "attention_rescoring"
assert self.config.decode.decoding_method in [
"ctc_prefix_beam_search", "attention_rescoring"
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
else:
raise Exception(f"not support: {self.model_type}")
def init_model(self) -> None:
if "deepspeech2" in self.model_type:
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in self.model_type or "transformer" in self.model_type:
# load model
# model_type: {model_name}_{dataset}
model_name = self.model_type[:self.model_type.rindex('_')]
logger.info(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name)
model = model_class.from_config(self.config)
self.model = model
self.model.set_state_dict(paddle.load(self.am_model))
self.model.eval()
else:
raise Exception(f"not support: {self.model_type}")
def _init_from_path(self,
model_type: str=None,
am_model: Optional[os.PathLike]=None,
...@@ -703,6 +765,7 @@ class ASRServerExecutor(ASRExecutor):
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
am_predictor_conf: dict=None):
"""
Init model and other resources from a specific path.
...@@ -715,6 +778,10 @@ class ASRServerExecutor(ASRExecutor):
self.model_type = model_type
self.sample_rate = sample_rate
self.decode_method = decode_method
self.num_decoding_left_chunks = num_decoding_left_chunks
# conf for paddleinference predictor or onnx
self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
...@@ -760,59 +827,10 @@ class ASRServerExecutor(ASRExecutor):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
if "deepspeech2" in model_type: self.update_config()
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor") self.init_model()
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in model_type or "transformer" in model_type:
with UpdateConfig(self.config):
logger.info("start to create the stream conformer asr engine")
# update the decoding method
if decode_method:
self.config.decode.decoding_method = decode_method
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
# Generally we set the decoding_method to attention_rescoring
if self.config.decode.decoding_method not in [
"ctc_prefix_beam_search", "attention_rescoring"
]:
logger.info(
"we set the decoding_method to attention_rescoring")
self.config.decode.decoding_method = "attention_rescoring"
assert self.config.decode.decoding_method in [
"ctc_prefix_beam_search", "attention_rescoring"
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
# load model
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
logger.info(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name)
model = model_class.from_config(self.config)
self.model = model
self.model.set_state_dict(paddle.load(self.am_model))
self.model.eval()
else:
raise Exception(f"not support: {model_type}")
logger.info(f"create the {model_type} model success") logger.info(f"create the {model_type} model success")
return True return True
...@@ -827,7 +845,20 @@ class ASREngine(BaseEngine):
def __init__(self):
super(ASREngine, self).__init__()
logger.info("create the online asr engine resource instance")
def init_model(self) -> bool:
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
am_predictor_conf=self.config.am_predictor_conf):
return False
return True
def init(self, config: dict) -> bool:
"""init engine resource
...@@ -854,15 +885,7 @@ class ASREngine(BaseEngine):
logger.info(f"paddlespeech_server set the device: {self.device}")
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf):
if not self.init_model():
logger.error(
"Failed to init the ASR server, please check the server configuration yaml"
)
...
...@@ -13,12 +13,16 @@
# limitations under the License.
from typing import Text
from ..utils.log import logger
__all__ = ['EngineFactory']
class EngineFactory(object):
@staticmethod
def get_engine(engine_name: Text, engine_type: Text):
logger.info(f"{engine_name} : {engine_type} engine.")
if engine_name == 'asr' and engine_type == 'inference':
from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine
return ASREngine()
...@@ -26,7 +30,13 @@ class EngineFactory(object):
from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online':
from paddlespeech.server.engine.asr.online.asr_engine import ASREngine
from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online-inference':
from paddlespeech.server.engine.asr.online.paddleinference.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online-onnx':
from paddlespeech.server.engine.asr.online.onnx.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'tts' and engine_type == 'inference':
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
...
...@@ -16,9 +16,9 @@ from collections import OrderedDict
import numpy as np
import paddle
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.audio.backends import load as load_audio
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.cli.log import logger
from paddlespeech.cli.vector.infer import VectorExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
...
...@@ -24,11 +24,11 @@ from typing import Any
from typing import Dict
import paddle
import paddleaudio
import requests
import yaml
from paddle.framework import load
import paddlespeech.audio
from .entry import client_commands
from .entry import server_commands
from paddlespeech.cli import download
...@@ -289,7 +289,7 @@ def _note_one_stat(cls_name, params={}):
if 'audio_file' in params:
try:
_, sr = paddleaudio.load(params['audio_file'])
_, sr = paddlespeech.audio.load(params['audio_file'])
except Exception:
sr = -1
...
...@@ -16,21 +16,34 @@ from typing import Optional
import onnxruntime as ort
from .log import logger
def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
logger.info(f"ort sessconf: {sess_conf}")
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
if sess_conf.get('graph_optimization_level', 99) == 0:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if "gpu" in sess_conf["device"]: # "gpu:0"
providers = ['CPUExecutionProvider']
if "gpu" in sess_conf.get("device", ""):
providers = ['CUDAExecutionProvider']
# fastspeech2/mb_melgan can't use trt now!
if sess_conf["use_trt"]:
if sess_conf.get("use_trt", 0):
providers = ['TensorrtExecutionProvider']
logger.info(f"ort providers: {providers}")
if 'cpu_threads' in sess_conf:
sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
else:
providers = ['CUDAExecutionProvider']
elif sess_conf["device"] == "cpu":
providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = sess_conf["cpu_threads"]
sess_options.intra_op_num_threads = sess_conf.get(
"intra_op_num_threads", 0)
sess_options.inter_op_num_threads = sess_conf.get("inter_op_num_threads", 0)
sess = ort.InferenceSession(
model_path, providers=providers, sess_options=sess_options)
return sess
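# A hypothetical sess_conf for get_sess() above, using only keys the code reads:
# sess_conf = {"device": "cpu", "graph_optimization_level": 99,
# "use_trt": 0, "cpu_threads": 4, "inter_op_num_threads": 0}
# sess = get_sess(model_path="exp/model.opt.onnx", sess_conf=sess_conf)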
...@@ -92,6 +92,7 @@ async def websocket_endpoint(websocket: WebSocket):
else:
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)
elif "bytes" in message:
# bytes for the pcm data
message = message["bytes"]
...
...@@ -16,10 +16,10 @@ import os
import time
import paddle
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
from paddlespeech.audio.backends import load as load_audio
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
...
...@@ -18,10 +18,10 @@ import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddleaudio.metric import compute_eer
from tqdm import tqdm
from yacs.config import CfgNode
from paddlespeech.audio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset import CSVDataset
...
...@@ -20,9 +20,9 @@ import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
...
...@@ -15,9 +15,9 @@ from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.audio import load as load_audio
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
...
...@@ -16,9 +16,10 @@ from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.compliance.librosa import mfcc
from paddlespeech.audio import load as load_audio
from paddlespeech.audio.compliance.librosa import melspectrogram
from paddlespeech.audio.compliance.librosa import mfcc
@dataclass
...
...@@ -24,6 +24,7 @@ from setuptools import find_packages
from setuptools import setup
from setuptools.command.develop import develop
from setuptools.command.install import install
from setuptools.command.test import test
HERE = Path(os.path.abspath(os.path.dirname(__file__)))
...@@ -31,42 +32,13 @@ VERSION = '0.0.0'
COMMITID = 'none'
base = [
"editdistance",
"g2p_en",
"g2pM",
"h5py",
"inflect",
"jieba",
"jsonlines",
"kaldiio",
"librosa==0.8.1",
"loguru",
"matplotlib",
"nara_wpe",
"onnxruntime",
"pandas",
"paddleaudio",
"paddlenlp",
"paddlespeech_feat",
"praatio==5.0.0",
"pypinyin",
"pypinyin-dict",
"python-dateutil",
"pyworld",
"resampy==0.2.2",
"sacrebleu",
"scipy",
"sentencepiece~=0.1.96",
"soundfile~=0.10",
"textgrid",
"timer",
"tqdm",
"typeguard",
"visualdl",
"webrtcvad",
"yacs~=0.1.8",
"prettytable",
"zhon",
"editdistance", "g2p_en", "g2pM", "h5py", "inflect", "jieba", "jsonlines",
"kaldiio", "librosa==0.8.1", "loguru", "matplotlib", "nara_wpe",
"onnxruntime", "pandas", "paddlenlp", "paddlespeech_feat", "praatio==5.0.0",
"pypinyin", "pypinyin-dict", "python-dateutil", "pyworld", "resampy==0.2.2",
"sacrebleu", "scipy", "sentencepiece~=0.1.96", "soundfile~=0.10",
"textgrid", "timer", "tqdm", "typeguard", "visualdl", "webrtcvad",
"yacs~=0.1.8", "prettytable", "zhon", 'colorlog', 'pathos == 0.2.8'
]
server = [
...@@ -177,7 +149,19 @@ class InstallCommand(install):
install.run(self)
class TestCommand(test):
def finalize_options(self):
test.finalize_options(self)
self.test_args = []
self.test_suite = True
def run_tests(self):
# Run nose ensuring that argv simulates running nosetests directly
import nose
nose.run_exit(argv=['nosetests', '-w', 'tests'])
# cmd: python setup.py upload
class UploadCommand(Command):
description = "Build and publish the package."
user_options = []
...@@ -279,11 +263,13 @@ setup_info = dict(
"sphinx", "sphinx-rtd-theme", "numpydoc", "myst_parser",
"recommonmark>=0.5.0", "sphinx-markdown-tables", "sphinx-autobuild"
],
'test': ['nose', 'torchaudio==0.10.2'],
},
cmdclass={
'develop': DevelopCommand,
'install': InstallCommand,
'upload': UploadCommand,
'test': TestCommand,
},
# Package info
...
# DeepSpeech2 ONNX model
1. Convert the DeepSpeech2 model to ONNX, using Paddle2ONNX.
2. Check that the Paddle Inference and onnxruntime outputs are equal.
3. Optimize the ONNX model (see the sketch below).
4. Check that the Paddle Inference and optimized onnxruntime outputs are equal.
Please make sure the [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) versions are correct.
The example was tested with these packages installed:
```
paddle2onnx 0.9.8 # develop 62c5424e22cd93968dc831216fc9e0f0fce3d819
paddleaudio 0.2.1
paddlefsl 1.1.0
paddlenlp 2.2.6
paddlepaddle-gpu 2.2.2
paddlespeech 0.0.0 # develop
paddlespeech-ctcdecoders 0.2.0
paddlespeech-feat 0.1.0
onnx 1.11.0
onnx-simplifier 0.0.0 # https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape
onnxoptimizer 0.2.7
onnxruntime 1.11.0
```
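A minimal sketch of step 3 (optimizing the ONNX model) with the onnx-simplifier Python API; the file names are placeholders, and the pinned fork above may behave differently from upstream:
```
import onnx
from onnxsim import simplify

# load the exported model, simplify its graph, then save the optimized copy
model = onnx.load("model.old.onnx")
model_opt, ok = simplify(model)
assert ok, "the simplified model failed the output check"
onnx.save(model_opt, "exp/model.opt.onnx")
```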
## Usage
```
bash run.sh
```
For more details please see `run.sh`.
## Outputs
The optimized onnx model is `exp/model.opt.onnx`.
To visualize the graph, please use `local/netron.sh`.
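The graph can also be opened from Python with the `netron` package (a sketch; `local/netron.sh` serves the same file on port 8082):
```
import netron

# open the optimized graph in the browser
netron.start("exp/model.opt.onnx")
```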
#!/usr/bin/env python3
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import pickle
import numpy as np
import onnxruntime
import paddle
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--input_file',
type=str,
default="static_ds2online_inputs.pickle",
help="aishell ds2 input data file. For wenetspeech, we only feed for infer model",
)
parser.add_argument(
'--model_type',
type=str,
default="aishell",
help="aishell(1024) or wenetspeech(2048)", )
parser.add_argument(
'--model_dir', type=str, default=".", help="paddle model dir.")
parser.add_argument(
'--model_prefix',
type=str,
default="avg_1.jit",
help="paddle model prefix.")
parser.add_argument(
'--onnx_model',
type=str,
default='./model.old.onnx',
help="onnx model.")
return parser.parse_args()
if __name__ == '__main__':
FLAGS = parse_args()
# input and output
with open(FLAGS.input_file, 'rb') as f:
iodict = pickle.load(f)
print(iodict.keys())
audio_chunk = iodict['audio_chunk']
audio_chunk_lens = iodict['audio_chunk_lens']
chunk_state_h_box = iodict['chunk_state_h_box']
chunk_state_c_box = iodict['chunk_state_c_bos']
print("raw state shape: ", chunk_state_c_box.shape)
if FLAGS.model_type == 'wenetspeech':
chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1)
chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1)
print("state shape: ", chunk_state_c_box.shape)
# paddle
model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix))
res_chunk, res_lens, chunk_state_h, chunk_state_c = model(
paddle.to_tensor(audio_chunk),
paddle.to_tensor(audio_chunk_lens),
paddle.to_tensor(chunk_state_h_box),
paddle.to_tensor(chunk_state_c_box), )
# onnxruntime
options = onnxruntime.SessionOptions()
options.enable_profiling = True
sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
"audio_chunk": audio_chunk,
"audio_chunk_lens": audio_chunk_lens,
"chunk_state_h_box": chunk_state_h_box,
"chunk_state_c_box": chunk_state_c_box
})
print(sess.end_profiling())
# assert paddle equal ort
print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
print(np.allclose(ort_res_lens, res_lens, atol=1e-6))
if FLAGS.model_type == 'aishell':
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))
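# Hypothetical invocation mirroring the defaults above (the script name is a
# placeholder):
# python infer_check.py --input_file static_ds2online_inputs.pickle \
# --model_dir . --model_prefix avg_1.jit --onnx_model ./model.old.onnx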
#!/bin/bash
# show model
if [ $# != 1 ];then
echo "usage: $0 model_path"
exit 1
fi
file=$1
pip install netron
netron -p 8082 --host $(hostname -i) $file
#!/bin/bash
# clone onnx repos
git clone https://github.com/onnx/onnx.git
git clone https://github.com/microsoft/onnxruntime.git
git clone https://github.com/PaddlePaddle/Paddle2ONNX.git
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# flake8: noqa
import argparse
import logging
import numpy as np
import onnx
import sympy
from onnx import helper
from onnx import numpy_helper
from onnx import shape_inference
from packaging import version
assert version.parse(onnx.__version__) >= version.parse("1.8.0")
logger = logging.getLogger(__name__)
def get_attribute(node, attr_name, default_value=None):
found = [attr for attr in node.attribute if attr.name == attr_name]
if found:
return helper.get_attribute_value(found[0])
return default_value
def get_dim_from_proto(dim):
return getattr(dim, dim.WhichOneof('value')) if type(
dim.WhichOneof('value')) == str else None
def is_sequence(type_proto):
cls_type = type_proto.WhichOneof('value')
assert cls_type in ['tensor_type', 'sequence_type']
return cls_type == 'sequence_type'
def get_shape_from_type_proto(type_proto):
assert not is_sequence(type_proto)
if type_proto.tensor_type.HasField('shape'):
return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim]
else:
return None # note no shape is different from shape without dim (scalar)
def get_shape_from_value_info(vi):
cls_type = vi.type.WhichOneof('value')
if cls_type is None:
return None
if is_sequence(vi.type):
if 'tensor_type' == vi.type.sequence_type.elem_type.WhichOneof('value'):
return get_shape_from_type_proto(vi.type.sequence_type.elem_type)
else:
return None
else:
return get_shape_from_type_proto(vi.type)
def make_named_value_info(name):
vi = onnx.ValueInfoProto()
vi.name = name
return vi
def get_shape_from_sympy_shape(sympy_shape):
return [
None if i is None else (int(i) if is_literal(i) else str(i))
for i in sympy_shape
]
def is_literal(dim):
return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(
dim, 'is_number') and dim.is_number)
def handle_negative_axis(axis, rank):
assert axis < rank and axis >= -rank
return axis if axis >= 0 else rank + axis
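# e.g. handle_negative_axis(-1, 4) == 3 and handle_negative_axis(2, 4) == 2.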
def get_opset(mp, domain=None):
domain = domain or ['', 'onnx', 'ai.onnx']
if type(domain) != list:
domain = [domain]
for opset in mp.opset_import:
if opset.domain in domain:
return opset.version
return None
def as_scalar(x):
if type(x) == list:
assert len(x) == 1
return x[0]
elif type(x) == np.ndarray:
return x.item()
else:
return x
def as_list(x, keep_none):
if type(x) == list:
return x
elif type(x) == np.ndarray:
return list(x)
elif keep_none and x is None:
return None
else:
return [x]
def sympy_reduce_product(x):
if type(x) == list:
value = sympy.Integer(1)
for v in x:
value = value * v
else:
value = x
return value
class SymbolicShapeInference:
def __init__(self,
int_max,
auto_merge,
guess_output_rank,
verbose,
prefix=''):
self.dispatcher_ = {
'Add':
self._infer_symbolic_compute_ops,
'ArrayFeatureExtractor':
self._infer_ArrayFeatureExtractor,
'AveragePool':
self._infer_Pool,
'BatchNormalization':
self._infer_BatchNormalization,
'Cast':
self._infer_Cast,
'CategoryMapper':
self._infer_CategoryMapper,
'Compress':
self._infer_Compress,
'Concat':
self._infer_Concat,
'ConcatFromSequence':
self._infer_ConcatFromSequence,
'Constant':
self._infer_Constant,
'ConstantOfShape':
self._infer_ConstantOfShape,
'Conv':
self._infer_Conv,
'CumSum':
self._pass_on_shape_and_type,
'Div':
self._infer_symbolic_compute_ops,
'Einsum':
self._infer_Einsum,
'Expand':
self._infer_Expand,
'Equal':
self._infer_symbolic_compute_ops,
'Floor':
self._infer_symbolic_compute_ops,
'Gather':
self._infer_Gather,
'GatherElements':
self._infer_GatherElements,
'GatherND':
self._infer_GatherND,
'Gelu':
self._pass_on_shape_and_type,
'If':
self._infer_If,
'Loop':
self._infer_Loop,
'MatMul':
self._infer_MatMul,
'MatMulInteger16':
self._infer_MatMulInteger,
'MaxPool':
self._infer_Pool,
'Max':
self._infer_symbolic_compute_ops,
'Min':
self._infer_symbolic_compute_ops,
'Mul':
self._infer_symbolic_compute_ops,
'NonMaxSuppression':
self._infer_NonMaxSuppression,
'NonZero':
self._infer_NonZero,
'OneHot':
self._infer_OneHot,
'Pad':
self._infer_Pad,
'Range':
self._infer_Range,
'Reciprocal':
self._pass_on_shape_and_type,
'ReduceSum':
self._infer_ReduceSum,
'ReduceProd':
self._infer_ReduceProd,
'Reshape':
self._infer_Reshape,
'Resize':
self._infer_Resize,
'Round':
self._pass_on_shape_and_type,
'Scan':
self._infer_Scan,
'ScatterElements':
self._infer_ScatterElements,
'SequenceAt':
self._infer_SequenceAt,
'SequenceInsert':
self._infer_SequenceInsert,
'Shape':
self._infer_Shape,
'Size':
self._infer_Size,
'Slice':
self._infer_Slice,
'SoftmaxCrossEntropyLoss':
self._infer_SoftmaxCrossEntropyLoss,
'SoftmaxCrossEntropyLossInternal':
self._infer_SoftmaxCrossEntropyLoss,
'NegativeLogLikelihoodLossInternal':
self._infer_SoftmaxCrossEntropyLoss,
'Split':
self._infer_Split,
'SplitToSequence':
self._infer_SplitToSequence,
'Squeeze':
self._infer_Squeeze,
'Sub':
self._infer_symbolic_compute_ops,
'Tile':
self._infer_Tile,
'TopK':
self._infer_TopK,
'Transpose':
self._infer_Transpose,
'Unsqueeze':
self._infer_Unsqueeze,
'Where':
self._infer_symbolic_compute_ops,
'ZipMap':
self._infer_ZipMap,
'Neg':
self._infer_symbolic_compute_ops,
# contrib ops:
'Attention':
self._infer_Attention,
'BiasGelu':
self._infer_BiasGelu,
'EmbedLayerNormalization':
self._infer_EmbedLayerNormalization,
'FastGelu':
self._infer_FastGelu,
'Gelu':
self._infer_Gelu,
'LayerNormalization':
self._infer_LayerNormalization,
'LongformerAttention':
self._infer_LongformerAttention,
'PythonOp':
self._infer_PythonOp,
'SkipLayerNormalization':
self._infer_SkipLayerNormalization
}
self.aten_op_dispatcher_ = {
'aten::embedding': self._infer_Gather,
'aten::bitwise_or': self._infer_aten_bitwise_or,
'aten::diagonal': self._infer_aten_diagonal,
'aten::max_pool2d_with_indices': self._infer_aten_pool2d,
'aten::multinomial': self._infer_aten_multinomial,
'aten::unfold': self._infer_aten_unfold,
'aten::argmax': self._infer_aten_argmax,
'aten::avg_pool2d': self._infer_aten_pool2d,
'aten::_adaptive_avg_pool2d': self._infer_aten_pool2d,
'aten::binary_cross_entropy_with_logits': self._infer_aten_bce,
'aten::numpy_T': self._infer_Transpose,
}
self.run_ = True
self.suggested_merge_ = {}
self.symbolic_dims_ = {}
self.input_symbols_ = {}
self.auto_merge_ = auto_merge
self.guess_output_rank_ = guess_output_rank
self.verbose_ = verbose
self.int_max_ = int_max
self.subgraph_id_ = 0
self.prefix_ = prefix
def _add_suggested_merge(self, symbols, apply=False):
assert all([(type(s) == str and s in self.symbolic_dims_) or
is_literal(s) for s in symbols])
symbols = set(symbols)
for k, v in self.suggested_merge_.items():
if k in symbols:
symbols.remove(k)
symbols.add(v)
map_to = None
# if there is literal, map to it first
for s in symbols:
if is_literal(s):
map_to = s
break
# when no literals, map to input symbolic dims, then existing symbolic dims
if map_to is None:
for s in symbols:
if s in self.input_symbols_:
map_to = s
break
if map_to is None:
for s in symbols:
if type(self.symbolic_dims_[s]) == sympy.Symbol:
map_to = s
break
# when nothing to map to, use the shorter one
if map_to is None:
if self.verbose_ > 0:
logger.warning(
'Potential unsafe merge between symbolic expressions: ({})'.
format(','.join(symbols)))
symbols_list = list(symbols)
lens = [len(s) for s in symbols_list]
map_to = symbols_list[lens.index(min(lens))]
symbols.remove(map_to)
for s in symbols:
if s == map_to:
continue
if is_literal(map_to) and is_literal(s):
assert int(map_to) == int(s)
self.suggested_merge_[s] = int(map_to) if is_literal(
map_to) else map_to
for k, v in self.suggested_merge_.items():
if v == s:
self.suggested_merge_[k] = map_to
if apply and self.auto_merge_:
self._apply_suggested_merge()
def _apply_suggested_merge(self, graph_input_only=False):
if not self.suggested_merge_:
return
for i in list(self.out_mp_.graph.input) + (
[] if graph_input_only else list(self.out_mp_.graph.value_info)):
for d in i.type.tensor_type.shape.dim:
if d.dim_param in self.suggested_merge_:
v = self.suggested_merge_[d.dim_param]
if is_literal(v):
d.dim_value = int(v)
else:
d.dim_param = v
def _preprocess(self, in_mp):
self.out_mp_ = onnx.ModelProto()
self.out_mp_.CopyFrom(in_mp)
self.graph_inputs_ = dict(
[(i.name, i) for i in list(self.out_mp_.graph.input)])
self.initializers_ = dict(
[(i.name, i) for i in self.out_mp_.graph.initializer])
self.known_vi_ = dict(
[(i.name, i) for i in list(self.out_mp_.graph.input)])
self.known_vi_.update(
dict([(i.name, helper.make_tensor_value_info(i.name, i.data_type,
list(i.dims)))
for i in self.out_mp_.graph.initializer]))
def _merge_symbols(self, dims):
if not all([type(d) == str for d in dims]):
if self.auto_merge_:
unique_dims = list(set(dims))
is_int = [is_literal(d) for d in unique_dims]
assert sum(
is_int
) <= 1 # if there are more than 1 unique ints, something is wrong
if sum(is_int) == 1:
int_dim = is_int.index(1)
if self.verbose_ > 0:
logger.debug('dim {} has been merged with value {}'.
format(unique_dims[:int_dim] + unique_dims[
int_dim + 1:], unique_dims[int_dim]))
self._check_merged_dims(unique_dims, allow_broadcast=False)
return unique_dims[int_dim]
else:
if self.verbose_ > 0:
logger.debug('dim {} has been merged with dim {}'.format(
unique_dims[1:], unique_dims[0]))
return dims[0]
else:
return None
if all([d == dims[0] for d in dims]):
return dims[0]
merged = [
self.suggested_merge_[d] if d in self.suggested_merge_ else d
for d in dims
]
if all([d == merged[0] for d in merged]):
assert merged[0] in self.symbolic_dims_
return merged[0]
else:
return None
# broadcast from right to left, and merge symbolic dims if needed
def _broadcast_shapes(self, shape1, shape2):
new_shape = []
rank1 = len(shape1)
rank2 = len(shape2)
new_rank = max(rank1, rank2)
for i in range(new_rank):
dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
if dim1 == 1 or dim1 == dim2:
new_dim = dim2
elif dim2 == 1:
new_dim = dim1
else:
new_dim = self._merge_symbols([dim1, dim2])
if not new_dim:
# warning about unsupported broadcast when not auto merge
# note that auto merge has the risk of incorrectly merge symbols while one of them being 1
# for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
if self.auto_merge_:
self._add_suggested_merge([dim1, dim2], apply=True)
else:
logger.warning('unsupported broadcast between ' + str(
dim1) + ' ' + str(dim2))
new_shape = [new_dim] + new_shape
return new_shape
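# Example for the method above: broadcasting ['N', 1, 4] with [3, 1] walks
# right to left: 4 vs 1 -> 4, 1 vs 3 -> 3, 'N' vs (absent) -> 'N', giving
# ['N', 3, 4]; symbolic dims that disagree go through _merge_symbols().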
def _get_shape(self, node, idx):
name = node.input[idx]
if name in self.known_vi_:
vi = self.known_vi_[name]
return get_shape_from_value_info(vi)
else:
assert name in self.initializers_
return list(self.initializers_[name].dims)
def _get_shape_rank(self, node, idx):
return len(self._get_shape(node, idx))
def _get_sympy_shape(self, node, idx):
sympy_shape = []
for d in self._get_shape(node, idx):
if type(d) == str:
sympy_shape.append(self.symbolic_dims_[d] if d in
self.symbolic_dims_ else sympy.Symbol(
d, integer=True, nonnegative=True))
else:
assert d is not None
sympy_shape.append(d)
return sympy_shape
def _get_value(self, node, idx):
name = node.input[idx]
assert name in self.sympy_data_ or name in self.initializers_
return self.sympy_data_[
name] if name in self.sympy_data_ else numpy_helper.to_array(
self.initializers_[name])
def _try_get_value(self, node, idx):
if idx >= len(node.input):
return None
name = node.input[idx]
if name in self.sympy_data_ or name in self.initializers_:
return self._get_value(node, idx)
return None
def _update_computed_dims(self, new_sympy_shape):
for i, new_dim in enumerate(new_sympy_shape):
if not is_literal(new_dim) and not type(new_dim) == str:
str_dim = str(new_dim)
if str_dim in self.suggested_merge_:
if is_literal(self.suggested_merge_[str_dim]):
continue # no need to create dim for literals
new_sympy_shape[i] = self.symbolic_dims_[
self.suggested_merge_[str_dim]]
else:
# add new_dim if it's a computational expression
if not str(new_dim) in self.symbolic_dims_:
self.symbolic_dims_[str(new_dim)] = new_dim
def _onnx_infer_single_node(self, node):
# skip onnx shape inference for some ops, as they are handled in _infer_*
skip_infer = node.op_type in [
'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \
# contrib ops
'Attention', 'BiasGelu', \
'EmbedLayerNormalization', \
'FastGelu', 'Gelu', 'LayerNormalization', \
'LongformerAttention', \
'SkipLayerNormalization', \
'PythonOp'
]
if not skip_infer:
# Only pass initializers that satisfy the following condition:
# (1) Operator need value of some input for shape inference.
# For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output.
# (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec.
# (3) The initializer is not in graph input. The means the node input is "constant" in inference.
initializers = []
if (get_opset(self.out_mp_) >= 9) and node.op_type in ['Unsqueeze']:
initializers = [
self.initializers_[name] for name in node.input
if (name in self.initializers_ and
name not in self.graph_inputs_)
]
# run single node inference with self.known_vi_ shapes
tmp_graph = helper.make_graph(
[node], 'tmp', [self.known_vi_[i] for i in node.input if i],
[make_named_value_info(i) for i in node.output], initializers)
self.tmp_mp_.graph.CopyFrom(tmp_graph)
self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
for i_o in range(len(node.output)):
o = node.output[i_o]
vi = self.out_mp_.graph.value_info.add()
if not skip_infer:
vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
else:
vi.name = o
self.known_vi_[o] = vi
def _onnx_infer_subgraph(self,
node,
subgraph,
use_node_input=True,
inc_subgraph_id=True):
if self.verbose_ > 2:
logger.debug(
'Inferencing subgraph of node {} with output({}...): {}'.format(
node.name, node.output[0], node.op_type))
# node inputs are not passed directly to the subgraph
# it's up to the node dispatcher to prepare subgraph input
# for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape
# besides, inputs in subgraph could shadow implicit inputs
subgraph_inputs = set(
[i.name for i in list(subgraph.initializer) + list(subgraph.input)])
subgraph_implicit_input = set([
name for name in self.known_vi_.keys()
if not name in subgraph_inputs
])
tmp_graph = helper.make_graph(
list(subgraph.node), 'tmp',
list(subgraph.input) +
[self.known_vi_[i] for i in subgraph_implicit_input],
[make_named_value_info(i.name) for i in subgraph.output])
tmp_graph.initializer.extend([
i for i in self.out_mp_.graph.initializer
if i.name in subgraph_implicit_input
])
tmp_graph.initializer.extend(subgraph.initializer)
self.tmp_mp_.graph.CopyFrom(tmp_graph)
symbolic_shape_inference = SymbolicShapeInference(
self.int_max_,
self.auto_merge_,
self.guess_output_rank_,
self.verbose_,
prefix=self.prefix_ + '_' + str(self.subgraph_id_))
if inc_subgraph_id:
self.subgraph_id_ += 1
all_shapes_inferred = False
symbolic_shape_inference._preprocess(self.tmp_mp_)
symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl(
self.sympy_data_.copy())
symbolic_shape_inference._update_output_from_vi()
if use_node_input:
# if subgraph uses node input, it needs to update to merged dims
subgraph.ClearField('input')
subgraph.input.extend(
symbolic_shape_inference.out_mp_.graph.input[:len(node.input)])
subgraph.ClearField('output')
subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output)
subgraph.ClearField('value_info')
subgraph.value_info.extend(
symbolic_shape_inference.out_mp_.graph.value_info)
subgraph.ClearField('node')
subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node)
# for new symbolic dims from subgraph output, add to main graph symbolic dims
subgraph_shapes = [
get_shape_from_value_info(o)
for o in symbolic_shape_inference.out_mp_.graph.output
]
subgraph_new_symbolic_dims = set([
d for s in subgraph_shapes if s for d in s
if type(d) == str and not d in self.symbolic_dims_
])
new_dims = {}
for d in subgraph_new_symbolic_dims:
assert d in symbolic_shape_inference.symbolic_dims_
new_dims[d] = symbolic_shape_inference.symbolic_dims_[d]
self.symbolic_dims_.update(new_dims)
return symbolic_shape_inference
def _get_int_values(self, node, broadcast=False):
values = [self._try_get_value(node, i) for i in range(len(node.input))]
if all([v is not None for v in values]):
# some shape compute is in floating point, cast to int for sympy
for i, v in enumerate(values):
if type(v) != np.ndarray:
continue
if len(v.shape) > 1:
new_v = None # ignore value for rank > 1
elif len(v.shape) == 0:
new_v = int(v.item())
else:
assert len(v.shape) == 1
new_v = [int(vv) for vv in v]
values[i] = new_v
values_len = [len(v) if type(v) == list else 0 for v in values]
max_len = max(values_len)
if max_len >= 1 and broadcast:
# broadcast
for i, v in enumerate(values):
if v is None:
continue # don't broadcast if value is unknown
if type(v) == list:
if len(v) < max_len:
values[i] = v * max_len
else:
assert len(v) == max_len
else:
values[i] = [v] * max_len
return values
def _compute_on_sympy_data(self, node, op_func):
assert len(node.output) == 1
values = self._get_int_values(node, broadcast=True)
if all([v is not None for v in values]):
is_list = [type(v) == list for v in values]
as_list = any(is_list)
if as_list:
self.sympy_data_[node.output[
0]] = [op_func(vs) for vs in zip(*values)]
else:
self.sympy_data_[node.output[0]] = op_func(values)
def _pass_on_sympy_data(self, node):
assert len(
node.
input) == 1 or node.op_type in ['Reshape', 'Unsqueeze', 'Squeeze']
self._compute_on_sympy_data(node, lambda x: x[0])
def _pass_on_shape_and_type(self, node):
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
self._get_shape(node, 0)))
def _new_symbolic_dim(self, prefix, dim):
new_dim = '{}_d{}'.format(prefix, dim)
if new_dim in self.suggested_merge_:
v = self.suggested_merge_[new_dim]
new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v
else:
new_symbolic_dim = sympy.Symbol(
new_dim, integer=True, nonnegative=True)
self.symbolic_dims_[new_dim] = new_symbolic_dim
return new_symbolic_dim
def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
return self._new_symbolic_dim('{}{}_{}_o{}_'.format(
node.op_type, self.prefix_,
list(self.out_mp_.graph.node).index(node), out_idx), dim)
def _new_symbolic_shape(self, rank, node, out_idx=0):
return [
self._new_symbolic_dim_from_output(node, out_idx, i)
for i in range(rank)
]
def _compute_conv_pool_shape(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
if len(node.input) > 1:
W_shape = self._get_sympy_shape(node, 1)
rank = len(W_shape) - 2 # number of spatial axes
kernel_shape = W_shape[-rank:]
sympy_shape[1] = W_shape[0]
else:
W_shape = None
kernel_shape = get_attribute(node, 'kernel_shape')
rank = len(kernel_shape)
assert len(sympy_shape) == rank + 2
# only need to symbolic shape inference if input has symbolic dims in spatial axes
is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]]
if not any(is_symbolic_dims):
shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
if len(shape) > 0:
assert len(sympy_shape) == len(shape)
sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
return sympy_shape
dilations = get_attribute(node, 'dilations', [1] * rank)
strides = get_attribute(node, 'strides', [1] * rank)
effective_kernel_shape = [(k - 1) * d + 1
for k, d in zip(kernel_shape, dilations)]
pads = get_attribute(node, 'pads')
if pads is None:
pads = [0] * (2 * rank)
auto_pad = get_attribute(node, 'auto_pad',
b'NOTSET').decode('utf-8')
if auto_pad != 'VALID' and auto_pad != 'NOTSET':
try:
residual = [
sympy.Mod(d, s)
for d, s in zip(sympy_shape[-rank:], strides)
]
total_pads = [
max(0, (k - s) if r == 0 else (k - r)) for k, s, r in
zip(effective_kernel_shape, strides, residual)
]
except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational
total_pads = [
max(0, (k - s))
for k, s in zip(effective_kernel_shape, strides)
] # assuming no residual if sympy throws error
elif auto_pad == 'VALID':
total_pads = []
else:
total_pads = [0] * rank
else:
assert len(pads) == 2 * rank
total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])]
ceil_mode = get_attribute(node, 'ceil_mode', 0)
for i in range(rank):
effective_input_size = sympy_shape[-rank + i]
if len(total_pads) > 0:
effective_input_size = effective_input_size + total_pads[i]
if ceil_mode:
strided_kernel_positions = sympy.ceiling(
(effective_input_size - effective_kernel_shape[i]) /
strides[i])
else:
strided_kernel_positions = (
effective_input_size - effective_kernel_shape[i]
) // strides[i]
sympy_shape[-rank + i] = strided_kernel_positions + 1
return sympy_shape
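# The method above effectively computes, per spatial axis i:
# out_i = (in_i + total_pads_i - ((kernel_i - 1) * dilation_i + 1)) // stride_i + 1
# with sympy.ceiling() instead of floor division when ceil_mode is set.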
def _check_merged_dims(self, dims, allow_broadcast=True):
if allow_broadcast:
dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
if not all([d == dims[0] for d in dims]):
self._add_suggested_merge(dims, apply=True)
def _compute_matmul_shape(self, node, output_dtype=None):
lhs_shape = self._get_shape(node, 0)
rhs_shape = self._get_shape(node, 1)
lhs_rank = len(lhs_shape)
rhs_rank = len(rhs_shape)
lhs_reduce_dim = 0
rhs_reduce_dim = 0
assert lhs_rank > 0 and rhs_rank > 0
if lhs_rank == 1 and rhs_rank == 1:
new_shape = []
elif lhs_rank == 1:
rhs_reduce_dim = -2
new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]]
elif rhs_rank == 1:
lhs_reduce_dim = -1
new_shape = lhs_shape[:lhs_reduce_dim]
else:
lhs_reduce_dim = -1
rhs_reduce_dim = -2
new_shape = self._broadcast_shapes(
lhs_shape[:-2],
rhs_shape[:-2]) + [lhs_shape[-2]] + [rhs_shape[-1]]
# merge reduce dim
self._check_merged_dims(
[lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
allow_broadcast=False)
if output_dtype is None:
# infer output_dtype from input type when not specified
output_dtype = self.known_vi_[node.input[
0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], output_dtype,
new_shape))
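# Example for the method above: lhs ['B', 'M', 'K'] matmul rhs ['K', 'N']
# yields _broadcast_shapes(['B'], []) + ['M'] + ['N'] = ['B', 'M', 'N'],
# after checking that the two reduce dims (both 'K') agree.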
def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
'''
update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches
'''
dst_tensor_type = dst_type.sequence_type.elem_type.tensor_type if is_sequence(
dst_type) else dst_type.tensor_type
src_tensor_type = src_type.sequence_type.elem_type.tensor_type if is_sequence(
src_type) else src_type.tensor_type
if dst_tensor_type.elem_type != src_tensor_type.elem_type:
node_id = node.name if node.name else node.op_type
raise ValueError(
f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
)
if dst_tensor_type.HasField('shape'):
for di, ds in enumerate(
zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)):
if ds[0] != ds[1]:
# create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type
# for sequence_type, clear the dimension
new_dim = onnx.TensorShapeProto.Dimension()
if not is_sequence(dst_type):
new_dim.dim_param = str(
self._new_symbolic_dim_from_output(node, out_idx,
di))
dst_tensor_type.shape.dim[di].CopyFrom(new_dim)
else:
dst_tensor_type.CopyFrom(src_tensor_type)
def _infer_ArrayFeatureExtractor(self, node):
data_shape = self._get_shape(node, 0)
indices_shape = self._get_shape(node, 1)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, data_shape[:-1] +
indices_shape))
def _infer_symbolic_compute_ops(self, node):
funcs = {
'Add':
lambda l: l[0] + l[1],
'Div':
lambda l: l[0] // l[1], # integer div in sympy
'Equal':
lambda l: l[0] == l[1],
'Floor':
lambda l: sympy.floor(l[0]),
'Max':
lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
'Min':
lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])),
'Mul':
lambda l: l[0] * l[1],
'Sub':
lambda l: l[0] - l[1],
'Where':
lambda l: l[1] if l[0] else l[2],
'Neg':
lambda l: -l[0]
}
assert node.op_type in funcs
self._compute_on_sympy_data(node, funcs[node.op_type])
def _infer_Cast(self, node):
self._pass_on_sympy_data(node)
def _infer_CategoryMapper(self, node):
input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
if input_type == onnx.TensorProto.STRING:
output_type = onnx.TensorProto.INT64
else:
output_type = onnx.TensorProto.STRING
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], output_type,
self._get_shape(node, 0)))
def _infer_Compress(self, node):
input_shape = self._get_shape(node, 0)
# create a new symbolic dimension for Compress output
compress_len = str(self._new_symbolic_dim_from_output(node))
axis = get_attribute(node, 'axis')
if axis is None:
# when axis is not specified, input is flattened before compress so output is 1D
output_shape = [compress_len]
else:
output_shape = input_shape
output_shape[handle_negative_axis(axis, len(
input_shape))] = compress_len
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
def _infer_Concat(self, node):
if any([
i in self.sympy_data_ or i in self.initializers_
for i in node.input
]):
values = self._get_int_values(node)
print("=======", values, node.name, get_attribute(node, 'axis'))
if all([v is not None for v in values]):
axis = get_attribute(node, 'axis')
if axis < 0:
axis = axis + len(values[0])
assert 0 == axis
self.sympy_data_[node.output[0]] = []
for i in range(len(node.input)):
value = values[i]
if type(value) == list:
self.sympy_data_[node.output[0]].extend(value)
else:
self.sympy_data_[node.output[0]].append(value)
sympy_shape = self._get_sympy_shape(node, 0)
axis = handle_negative_axis(
get_attribute(node, 'axis'), len(sympy_shape))
for i_idx in range(1, len(node.input)):
input_shape = self._get_sympy_shape(node, i_idx)
if input_shape:
sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
self._update_computed_dims(sympy_shape)
# merge symbolic dims for non-concat axes
for d in range(len(sympy_shape)):
if d == axis:
continue
dims = [
self._get_shape(node, i_idx)[d]
for i_idx in range(len(node.input))
if self._get_shape(node, i_idx)
]
if all([d == dims[0] for d in dims]):
continue
merged = self._merge_symbols(dims)
if type(merged) == str:
sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
else:
sympy_shape[d] = merged
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[0]].type.tensor_type.
elem_type, get_shape_from_sympy_shape(sympy_shape)))
def _infer_ConcatFromSequence(self, node):
seq_shape = self._get_shape(node, 0)
new_axis = 1 if get_attribute(node, 'new_axis') else 0
axis = handle_negative_axis(
get_attribute(node, 'axis'), len(seq_shape) + new_axis)
concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis))
new_shape = seq_shape
if new_axis:
new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:]
else:
new_shape[axis] = concat_dim
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[0]]
.type.sequence_type.elem_type.tensor_type.elem_type, new_shape))
def _infer_Constant(self, node):
t = get_attribute(node, 'value')
self.sympy_data_[node.output[0]] = numpy_helper.to_array(t)
def _infer_ConstantOfShape(self, node):
sympy_shape = self._get_int_values(node)[0]
vi = self.known_vi_[node.output[0]]
if sympy_shape is not None:
if type(sympy_shape) != list:
sympy_shape = [sympy_shape]
self._update_computed_dims(sympy_shape)
# update sympy data if output type is int, and shape is known
if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all(
[is_literal(x) for x in sympy_shape]):
self.sympy_data_[node.output[0]] = np.ones(
[int(x) for x in sympy_shape],
dtype=np.int64) * numpy_helper.to_array(
get_attribute(node, 'value', 0))
else:
# create new dynamic shape
# note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length
sympy_shape = self._new_symbolic_shape(
self._get_shape(node, 0)[0], node)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(sympy_shape)))
def _infer_Conv(self, node):
sympy_shape = self._compute_conv_pool_shape(node)
self._update_computed_dims(sympy_shape)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(sympy_shape)))
def _infer_Einsum(self, node):
# ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
equation = get_attribute(node, 'equation')
equation = equation.replace(b' ', b'')
mid_index = equation.find(b'->')
left_equation = equation[:mid_index] if mid_index != -1 else equation
num_operands = 0
num_ellipsis = 0
num_ellipsis_indices = 0
letter_to_dim = {}
terms = left_equation.split(b',')
for term in terms:
ellipsis_index = term.find(b'...')
shape = self._get_shape(node, num_operands)
rank = len(shape)
if ellipsis_index != -1:
if num_ellipsis == 0:
num_ellipsis_indices = rank - len(term) + 3
num_ellipsis = num_ellipsis + 1
for i in range(1, rank + 1):
letter = term[-i]
if letter != 46: # letter != b'.'
dim = shape[-i]
if letter not in letter_to_dim.keys():
letter_to_dim[letter] = dim
elif type(dim) != sympy.Symbol:
letter_to_dim[letter] = dim
num_operands = num_operands + 1
new_sympy_shape = []
from collections import OrderedDict
num_letter_occurrences = OrderedDict()
if mid_index != -1:
right_equation = equation[mid_index + 2:]
right_ellipsis_index = right_equation.find(b'...')
if right_ellipsis_index != -1:
for i in range(num_ellipsis_indices):
new_sympy_shape.append(shape[i])
for c in right_equation:
if c != 46: # c != b'.'
new_sympy_shape.append(letter_to_dim[c])
else:
for i in range(num_ellipsis_indices):
new_sympy_shape.append(shape[i])
for c in left_equation:
if c != 44 and c != 46: # c != b',' and c != b'.':
if c in num_letter_occurrences:
num_letter_occurrences[c] = num_letter_occurrences[
c] + 1
else:
num_letter_occurrences[c] = 1
for key, value in num_letter_occurrences.items():
if value == 1:
new_sympy_shape.append(letter_to_dim[key])
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], output_dtype,
new_sympy_shape))
def _infer_Expand(self, node):
expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True)
if expand_to_shape is not None:
# new_shape's dim can come from shape value
self._update_computed_dims(expand_to_shape)
shape = self._get_shape(node, 0)
new_shape = self._broadcast_shapes(
shape, get_shape_from_sympy_shape(expand_to_shape))
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, new_shape))
def _infer_Gather(self, node):
data_shape = self._get_shape(node, 0)
axis = handle_negative_axis(
get_attribute(node, 'axis', 0), len(data_shape))
indices_shape = self._get_shape(node, 1)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, data_shape[:axis] +
indices_shape + data_shape[axis +
1:]))
# for 1D input, do some sympy compute
if node.input[0] in self.sympy_data_ and len(
data_shape) == 1 and 0 == get_attribute(node, 'axis', 0):
idx = self._try_get_value(node, 1)
if idx is not None:
data = self.sympy_data_[node.input[0]]
if type(data) == list:
if type(idx) == np.ndarray and len(idx.shape) == 1:
self.sympy_data_[node.output[
0]] = [data[int(i)] for i in idx]
else:
self.sympy_data_[node.output[0]] = data[int(idx)]
else:
assert idx == 0 or idx == -1
self.sympy_data_[node.output[0]] = data
def _infer_GatherElements(self, node):
indices_shape = self._get_shape(node, 1)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, indices_shape))
def _infer_GatherND(self, node):
data_shape = self._get_shape(node, 0)
data_rank = len(data_shape)
indices_shape = self._get_shape(node, 1)
indices_rank = len(indices_shape)
last_index_dimension = indices_shape[-1]
assert is_literal(
last_index_dimension) and last_index_dimension <= data_rank
new_shape = indices_shape[:-1] + data_shape[last_index_dimension:]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, new_shape))
def _infer_If(self, node):
# special case for constant condition, in case there are mismatching shape from the non-executed branch
subgraphs = [
get_attribute(node, 'then_branch'), get_attribute(node,
'else_branch')
]
cond = self._try_get_value(node, 0)
if cond is not None:
if as_scalar(cond) > 0:
subgraphs[1].CopyFrom(subgraphs[0])
else:
subgraphs[0].CopyFrom(subgraphs[1])
for i_sub, subgraph in enumerate(subgraphs):
subgraph_infer = self._onnx_infer_subgraph(
node, subgraph, use_node_input=False)
for i_out in range(len(node.output)):
vi = self.known_vi_[node.output[i_out]]
if i_sub == 0:
vi.CopyFrom(subgraph.output[i_out])
vi.name = node.output[i_out]
else:
self._fuse_tensor_type(node, i_out, vi.type,
subgraph.output[i_out].type)
# pass on sympy data from subgraph, if cond is constant
if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else
1):
if subgraph.output[
i_out].name in subgraph_infer.sympy_data_:
self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[
subgraph.output[i_out].name]
def _infer_Loop(self, node):
subgraph = get_attribute(node, 'body')
assert len(subgraph.input) == len(node.input)
num_loop_carried = len(
node.input) - 2 # minus the length and initial loop condition
        # when sequence_type is used as a loop-carried input,
        # the subgraph inference needs to run twice if a tensor shape in the sequence contains None
for i, si in enumerate(subgraph.input):
si_name = si.name
si.CopyFrom(self.known_vi_[node.input[i]])
si.name = si_name
self._onnx_infer_subgraph(node, subgraph)
# check subgraph input/output for shape changes in loop carried variables
# for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a)
# for sequence_type, propagate from output to input
need_second_infer = False
for i_out in range(1, num_loop_carried + 1):
so = subgraph.output[i_out]
so_shape = get_shape_from_value_info(so)
if is_sequence(so.type):
if so_shape and None in so_shape:
# copy shape from output to input
# note that loop input is [loop_len, cond, input_0, input_1, ...]
# while loop output is [cond, output_0, output_1, ...]
subgraph.input[i_out +
1].type.sequence_type.elem_type.CopyFrom(
so.type.sequence_type.elem_type)
need_second_infer = True
else:
si = subgraph.input[i_out + 1]
si_shape = get_shape_from_value_info(si)
for di, dims in enumerate(zip(si_shape, so_shape)):
if dims[0] != dims[1]:
new_dim = onnx.TensorShapeProto.Dimension()
new_dim.dim_param = str(
self._new_symbolic_dim_from_output(node, i_out, di))
si.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
so.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
need_second_infer = True
if need_second_infer:
if self.verbose_ > 2:
logger.debug(
"Rerun Loop: {}({}...), because of sequence in loop carried variables".
format(node.name, node.output[0]))
self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False)
# create a new symbolic dimension for iteration dependent dimension
loop_iter_dim = str(self._new_symbolic_dim_from_output(node))
for i in range(len(node.output)):
vi = self.known_vi_[node.output[i]]
vi.CopyFrom(subgraph.output[
i +
1]) # first subgraph output is condition, not in node output
if i >= num_loop_carried:
assert not is_sequence(
vi.type) # TODO: handle loop accumulation in sequence_type
subgraph_vi_dim = subgraph.output[i +
1].type.tensor_type.shape.dim
vi.type.tensor_type.shape.ClearField('dim')
vi_dim = vi.type.tensor_type.shape.dim
vi_dim.add().dim_param = loop_iter_dim
vi_dim.extend(list(subgraph_vi_dim))
vi.name = node.output[i]
def _infer_MatMul(self, node):
self._compute_matmul_shape(node)
def _infer_MatMulInteger(self, node):
self._compute_matmul_shape(node, onnx.TensorProto.INT32)
def _infer_NonMaxSuppression(self, node):
selected = str(self._new_symbolic_dim_from_output(node))
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], onnx.TensorProto.INT64, [selected, 3]))
def _infer_NonZero(self, node):
input_rank = self._get_shape_rank(node, 0)
# create a new symbolic dimension for NonZero output
nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1))
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], vi.type.tensor_type.elem_type, [input_rank, nz_len]))
def _infer_OneHot(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
depth = self._try_get_value(node, 1)
axis = get_attribute(node, 'axis', -1)
axis = handle_negative_axis(axis, len(sympy_shape) + 1)
new_shape = get_shape_from_sympy_shape(sympy_shape[:axis] + [
self._new_symbolic_dim_from_output(node)
if not is_literal(depth) else depth
] + sympy_shape[axis:])
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[2]].type.tensor_type.elem_type, new_shape))
def _infer_Pad(self, node):
if get_opset(self.out_mp_) <= 10:
pads = get_attribute(node, 'pads')
else:
pads = self._try_get_value(node, 1)
sympy_shape = self._get_sympy_shape(node, 0)
rank = len(sympy_shape)
if pads is not None:
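            # ONNX 'pads' is laid out as [x1_begin, x2_begin, ..., x1_end, x2_end, ...]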
assert len(pads) == 2 * rank
new_sympy_shape = [
d + pad_up + pad_down for d, pad_up, pad_down in
zip(sympy_shape, pads[:rank], pads[rank:])
]
self._update_computed_dims(new_sympy_shape)
else:
# dynamic pads, create new symbolic dimensions
new_sympy_shape = self._new_symbolic_shape(rank, node)
output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], output_tp, get_shape_from_sympy_shape(new_sympy_shape)))
def _infer_Pool(self, node):
sympy_shape = self._compute_conv_pool_shape(node)
self._update_computed_dims(sympy_shape)
for o in node.output:
if not o:
continue
vi = self.known_vi_[o]
vi.CopyFrom(
helper.make_tensor_value_info(o, vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(
sympy_shape)))
def _infer_aten_bitwise_or(self, node):
shape0 = self._get_shape(node, 0)
shape1 = self._get_shape(node, 1)
new_shape = self._broadcast_shapes(shape0, shape1)
t0 = self.known_vi_[node.input[0]]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], t0.type.tensor_type.elem_type, new_shape))
def _infer_aten_diagonal(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
rank = len(sympy_shape)
offset = self._try_get_value(node, 1)
dim1 = self._try_get_value(node, 2)
dim2 = self._try_get_value(node, 3)
assert offset is not None and dim1 is not None and dim2 is not None
dim1 = handle_negative_axis(dim1, rank)
dim2 = handle_negative_axis(dim2, rank)
new_shape = []
for dim, val in enumerate(sympy_shape):
if dim not in [dim1, dim2]:
new_shape.append(val)
shape1 = sympy_shape[dim1]
shape2 = sympy_shape[dim2]
if offset >= 0:
diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset))
else:
diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2))
new_shape.append(diag_shape)
if node.output[0]:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
new_shape)))
def _infer_aten_multinomial(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
rank = len(sympy_shape)
assert rank in [1, 2]
num_samples = self._try_get_value(node, 1)
di = rank - 1
last_dim = num_samples if num_samples else str(
self._new_symbolic_dim_from_output(node, 0, di))
output_shape = sympy_shape[:-1] + [last_dim]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], onnx.TensorProto.INT64,
get_shape_from_sympy_shape(output_shape)))
def _infer_aten_pool2d(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
assert len(sympy_shape) == 4
sympy_shape[-2:] = [
self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]
]
self._update_computed_dims(sympy_shape)
for i, o in enumerate(node.output):
if not o:
continue
vi = self.known_vi_[o]
elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[
node.input[0]].type.tensor_type.elem_type
vi.CopyFrom(
helper.make_tensor_value_info(
o, elem_type, get_shape_from_sympy_shape(sympy_shape)))
def _infer_aten_unfold(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
dimension = self._try_get_value(node, 1)
size = self._try_get_value(node, 2)
step = self._try_get_value(node, 3)
if dimension is not None and size is not None and step is not None:
assert dimension < len(sympy_shape)
sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1
sympy_shape.append(size)
else:
rank = len(sympy_shape)
sympy_shape = self._new_symbolic_shape(rank + 1, node)
self._update_computed_dims(sympy_shape)
if node.output[0]:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
sympy_shape)))
def _infer_aten_argmax(self, node):
new_shape = None
if node.input[1] == '':
# The argmax of the flattened input is returned.
new_shape = []
else:
dim = self._try_get_value(node, 1)
keepdim = self._try_get_value(node, 2)
if keepdim is not None:
sympy_shape = self._get_sympy_shape(node, 0)
if dim is not None:
dim = handle_negative_axis(dim, len(sympy_shape))
if keepdim:
sympy_shape[dim] = 1
else:
del sympy_shape[dim]
else:
rank = len(sympy_shape)
sympy_shape = self._new_symbolic_shape(rank if keepdim else
rank - 1, node)
self._update_computed_dims(sympy_shape)
new_shape = get_shape_from_sympy_shape(sympy_shape)
if node.output[0] and new_shape is not None:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], onnx.TensorProto.INT64, new_shape))
def _infer_aten_bce(self, node):
reduction = self._try_get_value(node, 4)
if reduction is None:
reduction = 1
elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
if reduction == 0:
vi.type.tensor_type.elem_type = elem_type
vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
else:
vi.CopyFrom(
helper.make_tensor_value_info(vi.name, elem_type,
self._get_shape(node, 0)))
def _infer_BatchNormalization(self, node):
self._propagate_shape_and_type(node)
        # this works both for opset < 14 and opset >= 14, since we check i < len(node.output) in the loop
for i in [1, 2, 3, 4]:
if i < len(node.output) and node.output[i] != "":
                # all of these outputs have the same shape as input[1] (scale)
self._propagate_shape_and_type(
node, input_index=1, output_index=i)
def _infer_Range(self, node):
vi = self.known_vi_[node.output[0]]
input_data = self._get_int_values(node)
if all([i is not None for i in input_data]):
start = as_scalar(input_data[0])
limit = as_scalar(input_data[1])
delta = as_scalar(input_data[2])
new_sympy_shape = [
sympy.Max(sympy.ceiling((limit - start) / delta), 0)
]
else:
new_sympy_shape = [self._new_symbolic_dim_from_output(node)]
self._update_computed_dims(new_sympy_shape)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[0]].type.tensor_type.
elem_type, get_shape_from_sympy_shape(new_sympy_shape)))
def _infer_ReduceSum(self, node):
keep_dims = get_attribute(node, 'keepdims', 1)
if get_opset(self.out_mp_) >= 13 and len(node.input) > 1:
# ReduceSum changes axes to input[1] in opset 13
axes = self._try_get_value(node, 1)
vi = self.known_vi_[node.output[0]]
if axes is None:
assert keep_dims # can only handle keep_dims==True when axes is unknown, by generating new ranks
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[
0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
self._new_symbolic_shape(
self._get_shape_rank(node, 0), node))))
else:
shape = self._get_shape(node, 0)
output_shape = []
axes = [handle_negative_axis(a, len(shape)) for a in axes]
                for i, d in enumerate(shape):
                    if i in axes:
                        if keep_dims:
                            output_shape.append(1)
                    else:
                        output_shape.append(d)
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], self.known_vi_[node.input[
0]].type.tensor_type.elem_type, output_shape))
def _infer_ReduceProd(self, node):
axes = get_attribute(node, 'axes')
keep_dims = get_attribute(node, 'keepdims', 1)
if keep_dims == 0 and axes == [0]:
data = self._get_int_values(node)[0]
if data is not None:
self.sympy_data_[node.output[0]] = sympy_reduce_product(data)
def _infer_Reshape(self, node):
shape_value = self._try_get_value(node, 1)
vi = self.known_vi_[node.output[0]]
if shape_value is None:
shape_shape = self._get_shape(node, 1)
assert len(shape_shape) == 1
shape_rank = shape_shape[0]
assert is_literal(shape_rank)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(
self._new_symbolic_shape(shape_rank, node))))
else:
input_sympy_shape = self._get_sympy_shape(node, 0)
total = int(1)
for d in input_sympy_shape:
total = total * d
new_sympy_shape = []
deferred_dim_idx = -1
non_deferred_size = int(1)
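            # Reshape semantics: a 0 in the target shape keeps the input dim,
            # a single -1 is deferred and inferred as total // non_deferred_size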
for i, d in enumerate(shape_value):
if type(d) == sympy.Symbol:
new_sympy_shape.append(d)
elif d == 0:
new_sympy_shape.append(input_sympy_shape[i])
non_deferred_size = non_deferred_size * input_sympy_shape[i]
else:
new_sympy_shape.append(d)
if d == -1:
deferred_dim_idx = i
elif d != 0:
non_deferred_size = non_deferred_size * d
assert new_sympy_shape.count(-1) < 2
if -1 in new_sympy_shape:
new_dim = total // non_deferred_size
new_sympy_shape[deferred_dim_idx] = new_dim
self._update_computed_dims(new_sympy_shape)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
self._pass_on_sympy_data(node)
def _infer_Resize(self, node):
vi = self.known_vi_[node.output[0]]
input_sympy_shape = self._get_sympy_shape(node, 0)
if get_opset(self.out_mp_) <= 10:
scales = self._try_get_value(node, 1)
if scales is not None:
new_sympy_shape = [
sympy.simplify(sympy.floor(d * s))
for d, s in zip(input_sympy_shape, scales)
]
self._update_computed_dims(new_sympy_shape)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[
0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
else:
roi = self._try_get_value(node, 1)
scales = self._try_get_value(node, 2)
sizes = self._try_get_value(node, 3)
if sizes is not None:
new_sympy_shape = [
sympy.simplify(sympy.floor(s)) for s in sizes
]
self._update_computed_dims(new_sympy_shape)
elif scales is not None:
rank = len(scales)
if get_attribute(node, 'coordinate_transformation_mode'
) == 'tf_crop_and_resize':
assert len(roi) == 2 * rank
roi_start = list(roi)[:rank]
roi_end = list(roi)[rank:]
else:
roi_start = [0] * rank
roi_end = [1] * rank
scales = list(scales)
new_sympy_shape = [
sympy.simplify(sympy.floor(d * (end - start) * scale))
for d, start, end, scale in
zip(input_sympy_shape, roi_start, roi_end, scales)
]
self._update_computed_dims(new_sympy_shape)
else:
new_sympy_shape = self._new_symbolic_shape(
self._get_shape_rank(node, 0), node)
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
new_sympy_shape)))
def _infer_Scan(self, node):
subgraph = get_attribute(node, 'body')
num_scan_inputs = get_attribute(node, 'num_scan_inputs')
scan_input_axes = get_attribute(node, 'scan_input_axes',
[0] * num_scan_inputs)
num_scan_states = len(node.input) - num_scan_inputs
scan_input_axes = [
handle_negative_axis(
ax, self._get_shape_rank(node, i + num_scan_states))
for i, ax in enumerate(scan_input_axes)
]
        # We may have cases where the subgraph has optional inputs that appear in both the subgraph's inputs and initializers,
        # but not in the node's inputs. In such cases, the input model might be invalid, but let's skip those optional inputs.
assert len(subgraph.input) >= len(node.input)
subgraph_inputs = subgraph.input[:len(node.input)]
for i, si in enumerate(subgraph_inputs):
subgraph_name = si.name
si.CopyFrom(self.known_vi_[node.input[i]])
if i >= num_scan_states:
scan_input_dim = si.type.tensor_type.shape.dim
scan_input_dim.remove(
scan_input_dim[scan_input_axes[i - num_scan_states]])
si.name = subgraph_name
self._onnx_infer_subgraph(node, subgraph)
num_scan_outputs = len(node.output) - num_scan_states
scan_output_axes = get_attribute(node, 'scan_output_axes',
[0] * num_scan_outputs)
scan_input_dim = get_shape_from_type_proto(
self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
for i, o in enumerate(node.output):
vi = self.known_vi_[o]
if i >= num_scan_states:
shape = get_shape_from_type_proto(subgraph.output[i].type)
new_dim = handle_negative_axis(
scan_output_axes[i - num_scan_states], len(shape) + 1)
shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:]
vi.CopyFrom(
helper.make_tensor_value_info(o, subgraph.output[
i].type.tensor_type.elem_type, shape))
else:
vi.CopyFrom(subgraph.output[i])
vi.name = o
def _infer_ScatterElements(self, node):
data_shape = self._get_shape(node, 0)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, data_shape))
def _infer_SequenceAt(self, node):
# need to create new symbolic dimension if sequence shape has None:
seq_shape = self._get_shape(node, 0)
vi = self.known_vi_[node.output[0]]
if seq_shape is not None:
for di, d in enumerate(seq_shape):
if d is not None:
continue
new_dim = onnx.TensorShapeProto.Dimension()
new_dim.dim_param = str(
self._new_symbolic_dim_from_output(node, 0, di))
vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
def _infer_SequenceInsert(self, node):
# workaround bug in onnx's shape inference
vi_seq = self.known_vi_[node.input[0]]
vi_tensor = self.known_vi_[node.input[1]]
vi_out_seq = self.known_vi_[node.output[0]]
vi_out_seq.CopyFrom(vi_seq)
vi_out_seq.name = node.output[0]
self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type)
def _infer_Shape(self, node):
self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0)
def _infer_Size(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape)
self.known_vi_[node.output[0]].CopyFrom(
helper.make_tensor_value_info(node.output[0],
onnx.TensorProto.INT64, []))
def _infer_Slice(self, node):
def less_equal(x, y):
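            # comparisons on sympy expressions raise TypeError when the relation
            # is undecidable; try several equivalent forms before the last resort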
try:
return bool(x <= y)
except TypeError:
pass
try:
return bool(y >= x)
except TypeError:
pass
try:
return bool(-x >= -y)
except TypeError:
pass
try:
return bool(-y <= -x)
except TypeError:
# the last attempt; this may raise TypeError
return bool(y - x >= 0)
def handle_negative_index(index, bound):
""" normalizes a negative index to be in [0, bound) """
try:
if not less_equal(0, index):
if is_literal(index) and index <= -self.int_max_:
# this case is handled separately
return index
return bound + index
except TypeError:
logger.warning("Cannot determine if {} < 0".format(index))
return index
if get_opset(self.out_mp_) <= 9:
axes = get_attribute(node, 'axes')
starts = get_attribute(node, 'starts')
ends = get_attribute(node, 'ends')
if not axes:
axes = list(range(len(starts)))
steps = [1] * len(axes)
else:
starts = as_list(self._try_get_value(node, 1), keep_none=True)
ends = as_list(self._try_get_value(node, 2), keep_none=True)
axes = self._try_get_value(node, 3)
steps = self._try_get_value(node, 4)
if axes is None and not (starts is None and ends is None):
axes = list(
range(0, len(starts if starts is not None else ends)))
if steps is None and not (starts is None and ends is None):
steps = [1] * len(starts if starts is not None else ends)
axes = as_list(axes, keep_none=True)
steps = as_list(steps, keep_none=True)
new_sympy_shape = self._get_sympy_shape(node, 0)
if starts is None or ends is None:
if axes is None:
for i in range(len(new_sympy_shape)):
new_sympy_shape[i] = self._new_symbolic_dim_from_output(
node, 0, i)
else:
new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
for i in axes:
new_sympy_shape[i] = self._new_symbolic_dim_from_output(
node, 0, i)
else:
for i, s, e, t in zip(axes, starts, ends, steps):
e = handle_negative_index(e, new_sympy_shape[i])
if is_literal(e):
if e >= self.int_max_:
e = new_sympy_shape[i]
elif e <= -self.int_max_:
e = 0 if s > 0 else -1
elif is_literal(new_sympy_shape[i]):
if e < 0:
e = max(0, e + new_sympy_shape[i])
e = min(e, new_sympy_shape[i])
else:
if e > 0:
e = sympy.Min(
e, new_sympy_shape[i]
                        ) if e > 1 else e  # special case for slicing first to make computation easier
else:
if is_literal(new_sympy_shape[i]):
e = sympy.Min(e, new_sympy_shape[i])
else:
try:
if not less_equal(e, new_sympy_shape[i]):
e = new_sympy_shape[i]
except Exception:
logger.warning(
'Unable to determine if {} <= {}, treat as equal'.
format(e, new_sympy_shape[i]))
e = new_sympy_shape[i]
s = handle_negative_index(s, new_sympy_shape[i])
if is_literal(new_sympy_shape[i]) and is_literal(s):
s = max(0, min(s, new_sympy_shape[i]))
new_sympy_shape[i] = sympy.simplify(
(e - s + t + (-1 if t > 0 else 1)) // t)
self._update_computed_dims(new_sympy_shape)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
# handle sympy_data if needed, for slice in shape computation
if (node.input[0] in self.sympy_data_ and [0] == axes and
len(starts) == 1 and len(ends) == 1 and len(steps) == 1):
input_sympy_data = self.sympy_data_[node.input[0]]
            if type(input_sympy_data) == list or (
                    type(input_sympy_data) == np.ndarray and
                    len(input_sympy_data.shape) == 1):
self.sympy_data_[node.output[0]] = input_sympy_data[starts[
0]:ends[0]:steps[0]]
def _infer_SoftmaxCrossEntropyLoss(self, node):
vi = self.known_vi_[node.output[0]]
elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi.type.tensor_type.elem_type = elem_type
vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
if len(node.output) > 1:
data_shape = self._get_shape(node, 0)
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(
helper.make_tensor_value_info(vi.name, elem_type, data_shape))
def _infer_Split_Common(self, node, make_value_info_func):
input_sympy_shape = self._get_sympy_shape(node, 0)
axis = handle_negative_axis(
get_attribute(node, 'axis', 0), len(input_sympy_shape))
split = get_attribute(node, 'split')
if not split:
num_outputs = len(node.output)
split = [input_sympy_shape[axis] /
sympy.Integer(num_outputs)] * num_outputs
self._update_computed_dims(split)
else:
split = [sympy.Integer(s) for s in split]
for i_o in range(len(split)):
vi = self.known_vi_[node.output[i_o]]
vi.CopyFrom(
make_value_info_func(node.output[i_o], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
input_sympy_shape[:axis] + [
split[i_o]
] + input_sympy_shape[axis + 1:])))
self.known_vi_[vi.name] = vi
def _infer_Split(self, node):
self._infer_Split_Common(node, helper.make_tensor_value_info)
def _infer_SplitToSequence(self, node):
self._infer_Split_Common(node, helper.make_sequence_value_info)
def _infer_Squeeze(self, node):
input_shape = self._get_shape(node, 0)
op_set = get_opset(self.out_mp_)
# Depending on op-version 'axes' are provided as attribute or via 2nd input
if op_set < 13:
axes = get_attribute(node, 'axes')
assert self._try_get_value(node, 1) is None
else:
axes = self._try_get_value(node, 1)
assert get_attribute(node, 'axes') is None
if axes is None:
            # No axes have been provided (neither via attribute nor via input).
            # In this case the 'Squeeze' op should remove all axes with dimension 1.
            # For symbolic dimensions we guess they are != 1.
output_shape = [s for s in input_shape if s != 1]
if self.verbose_ > 0:
symbolic_dimensions = [s for s in input_shape if type(s) != int]
if len(symbolic_dimensions) > 0:
logger.debug(
f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+
f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}"
)
else:
axes = [handle_negative_axis(a, len(input_shape)) for a in axes]
output_shape = []
for i in range(len(input_shape)):
if i not in axes:
output_shape.append(input_shape[i])
else:
assert input_shape[i] == 1 or type(input_shape[i]) != int
if self.verbose_ > 0 and type(input_shape[i]) != int:
logger.debug(
f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+
f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1."
)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
self._pass_on_sympy_data(node)
def _infer_Tile(self, node):
repeats_value = self._try_get_value(node, 1)
new_sympy_shape = []
if repeats_value is not None:
input_sympy_shape = self._get_sympy_shape(node, 0)
for i, d in enumerate(input_sympy_shape):
new_dim = d * repeats_value[i]
new_sympy_shape.append(new_dim)
self._update_computed_dims(new_sympy_shape)
else:
new_sympy_shape = self._new_symbolic_shape(
self._get_shape_rank(node, 0), node)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
def _infer_TopK(self, node):
rank = self._get_shape_rank(node, 0)
axis = handle_negative_axis(get_attribute(node, 'axis', -1), rank)
new_shape = self._get_shape(node, 0)
if get_opset(self.out_mp_) <= 9:
k = get_attribute(node, 'k')
else:
k = self._get_int_values(node)[1]
            if k is None:
k = self._new_symbolic_dim_from_output(node)
else:
k = as_scalar(k)
if type(k) in [int, str]:
new_shape[axis] = k
else:
new_sympy_shape = self._get_sympy_shape(node, 0)
new_sympy_shape[axis] = k
self._update_computed_dims(
new_sympy_shape
) # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape
new_shape = get_shape_from_sympy_shape(new_sympy_shape)
for i_o in range(len(node.output)):
vi = self.known_vi_[node.output[i_o]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
i_o], vi.type.tensor_type.elem_type, new_shape))
def _infer_Transpose(self, node):
if node.input[0] in self.sympy_data_:
data_shape = self._get_shape(node, 0)
perm = get_attribute(node, 'perm',
reversed(list(range(len(data_shape)))))
input_data = self.sympy_data_[node.input[0]]
self.sympy_data_[node.output[0]] = np.transpose(
np.array(input_data).reshape(*data_shape),
axes=tuple(perm)).flatten().tolist()
def _infer_Unsqueeze(self, node):
input_shape = self._get_shape(node, 0)
op_set = get_opset(self.out_mp_)
# Depending on op-version 'axes' are provided as attribute or via 2nd input
if op_set < 13:
axes = get_attribute(node, 'axes')
assert self._try_get_value(node, 1) is None
else:
axes = self._try_get_value(node, 1)
assert get_attribute(node, 'axes') is None
output_rank = len(input_shape) + len(axes)
axes = [handle_negative_axis(a, output_rank) for a in axes]
input_axis = 0
output_shape = []
for i in range(output_rank):
if i in axes:
output_shape.append(1)
else:
output_shape.append(input_shape[input_axis])
input_axis += 1
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
self._pass_on_sympy_data(node)
def _infer_ZipMap(self, node):
map_key_type = None
if get_attribute(node, 'classlabels_int64s') is not None:
map_key_type = onnx.TensorProto.INT64
elif get_attribute(node, 'classlabels_strings') is not None:
map_key_type = onnx.TensorProto.STRING
assert map_key_type is not None
new_vi = onnx.ValueInfoProto()
new_vi.name = node.output[0]
new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(new_vi)
def _infer_Attention(self, node):
shape = self._get_shape(node, 0)
shape_bias = self._get_shape(node, 2)
assert len(shape) == 3 and len(shape_bias) == 1
qkv_hidden_sizes_attr = get_attribute(node, 'qkv_hidden_sizes')
if qkv_hidden_sizes_attr is not None:
assert len(qkv_hidden_sizes_attr) == 3
shape[2] = int(qkv_hidden_sizes_attr[2])
else:
shape[2] = int(shape_bias[0] / 3)
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], output_dtype, shape))
if len(node.output) > 1:
# input shape: (batch_size, sequence_length, hidden_size)
# past shape: (2, batch_size, num_heads, past_sequence_length, head_size)
# mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len)
# present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length
input_shape = self._get_shape(node, 0)
past_shape = self._get_shape(node, 4)
mask_shape = self._get_shape(node, 3)
if len(past_shape) == 5:
if len(mask_shape) in [2, 3]:
past_shape[3] = mask_shape[-1]
elif isinstance(input_shape[1], int) and isinstance(
past_shape[3], int):
past_shape[3] = input_shape[1] + past_shape[3]
else:
past_shape[3] = f"{past_shape[3]}+{input_shape[1]}"
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(
helper.make_tensor_value_info(vi.name, output_dtype,
past_shape))
def _infer_BiasGelu(self, node):
self._propagate_shape_and_type(node)
def _infer_FastGelu(self, node):
self._propagate_shape_and_type(node)
def _infer_Gelu(self, node):
self._propagate_shape_and_type(node)
def _infer_LayerNormalization(self, node):
self._propagate_shape_and_type(node)
def _infer_LongformerAttention(self, node):
self._propagate_shape_and_type(node)
def _infer_EmbedLayerNormalization(self, node):
input_ids_shape = self._get_shape(node, 0)
word_embedding_shape = self._get_shape(node, 2)
assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2
output_shape = input_ids_shape + [word_embedding_shape[1]]
word_embedding_dtype = self.known_vi_[node.input[
2]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], word_embedding_dtype,
output_shape))
mask_index_shape = [input_ids_shape[0]]
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
1], onnx.TensorProto.INT32, mask_index_shape))
if len(node.output) > 2:
            # optional output of the Add performed before layer normalization;
            # its shape is the same as the main output
vi = self.known_vi_[node.output[2]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
2], word_embedding_dtype, output_shape))
def _infer_SkipLayerNormalization(self, node):
self._propagate_shape_and_type(node)
def _infer_PythonOp(self, node):
output_tensor_types = get_attribute(node, 'output_tensor_types')
assert output_tensor_types
output_tensor_ranks = get_attribute(node, 'output_tensor_ranks')
assert output_tensor_ranks
        # set the context output separately.
# The first output is autograd's context.
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0],
onnx.TensorProto.INT64, []))
# Outputs after autograd's context are tensors.
# We assume their ranks are fixed for different model inputs.
for i in range(len(node.output) - 1):
# Process the i-th tensor outputs.
vi = self.known_vi_[node.output[i + 1]]
sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node)
shape = get_shape_from_sympy_shape(sympy_shape)
value_info = helper.make_tensor_value_info(
node.output[i + 1], output_tensor_types[i], shape)
vi.CopyFrom(value_info)
def _propagate_shape_and_type(self, node, input_index=0, output_index=0):
shape = self._get_shape(node, input_index)
output_dtype = self.known_vi_[node.input[
input_index]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[output_index]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[output_index],
output_dtype, shape))
def _is_none_dim(self, dim_value):
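        # onnx >= 1.11 marks unresolved dims as 'unk__<idx>'; those not tracked
        # in symbolic_dims_ are treated like None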
if type(dim_value) != str:
return False
if "unk__" not in dim_value:
return False
if dim_value in self.symbolic_dims_.keys():
return False
return True
def _is_shape_contains_none_dim(self, out_shape):
for out in out_shape:
if self._is_none_dim(out):
return out
return None
def _infer_impl(self, start_sympy_data=None):
self.sympy_data_ = start_sympy_data or {}
self.out_mp_.graph.ClearField('value_info')
self._apply_suggested_merge(graph_input_only=True)
self.input_symbols_ = set()
for i in self.out_mp_.graph.input:
input_shape = get_shape_from_value_info(i)
if input_shape is None:
continue
if is_sequence(i.type):
input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
else:
input_dims = i.type.tensor_type.shape.dim
for i_dim, dim in enumerate(input_shape):
if dim is None:
# some models use None for symbolic dim in input, replace it with a string
input_dims[i_dim].dim_param = str(
self._new_symbolic_dim(i.name, i_dim))
self.input_symbols_.update(
[d for d in input_shape if type(d) == str])
for s in self.input_symbols_:
if s in self.suggested_merge_:
s_merge = self.suggested_merge_[s]
assert s_merge in self.symbolic_dims_
self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
else:
# Since inputs are not produced by other ops, we can assume positivity
self.symbolic_dims_[s] = sympy.Symbol(
s, integer=True, positive=True)
# create a temporary ModelProto for single node inference
# note that we remove initializer to have faster inference
# for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways
self.tmp_mp_ = onnx.ModelProto()
self.tmp_mp_.CopyFrom(self.out_mp_)
self.tmp_mp_.graph.ClearField('initializer')
        # compute prerequisites for each node for topological sort
        # nodes with subgraphs may depend on implicit inputs, which affects topological sort
prereq_for_node = {
} # map from node to all its inputs, including implicit ones in subgraph
def get_prereq(node):
names = set(i for i in node.input if i)
subgraphs = []
if 'If' == node.op_type:
subgraphs = [
get_attribute(node, 'then_branch'),
get_attribute(node, 'else_branch')
]
elif node.op_type in ['Loop', 'Scan']:
subgraphs = [get_attribute(node, 'body')]
for g in subgraphs:
g_outputs_and_initializers = {i.name for i in g.initializer}
g_prereq = set()
for n in g.node:
g_outputs_and_initializers.update(n.output)
for n in g.node:
g_prereq.update([
i for i in get_prereq(n)
if i not in g_outputs_and_initializers
])
names.update(g_prereq)
# remove subgraph inputs from g_prereq since those are local-only
for i in g.input:
if i.name in names:
names.remove(i.name)
return names
for n in self.tmp_mp_.graph.node:
prereq_for_node[n.output[0]] = get_prereq(n)
# topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate
sorted_nodes = []
sorted_known_vi = set([
i.name for i in list(self.out_mp_.graph.input) +
list(self.out_mp_.graph.initializer)
])
if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
# Loop/Scan will have some graph output in graph inputs, so don't do topological sort
sorted_nodes = self.out_mp_.graph.node
else:
while not all(
[o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
old_sorted_nodes_len = len(sorted_nodes)
for node in self.out_mp_.graph.node:
if (node.output[0] not in sorted_known_vi) and all([
i in sorted_known_vi
for i in prereq_for_node[node.output[0]] if i
]):
sorted_known_vi.update(node.output)
sorted_nodes.append(node)
if old_sorted_nodes_len == len(sorted_nodes) and not all([
o.name in sorted_known_vi
for o in self.out_mp_.graph.output
]):
raise Exception('Invalid model with cyclic graph')
for node in sorted_nodes:
assert all([i in self.known_vi_ for i in node.input if i])
self._onnx_infer_single_node(node)
known_aten_op = False
if node.op_type in self.dispatcher_:
self.dispatcher_[node.op_type](node)
elif node.op_type in ['ConvTranspose']:
# onnx shape inference ops like ConvTranspose may have empty shape for symbolic input
# before adding symbolic compute for them
# mark the output type as UNDEFINED to allow guessing of rank
vi = self.known_vi_[node.output[0]]
if len(vi.type.tensor_type.shape.dim) == 0:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
elif node.op_type == 'ATen' and node.domain == 'org.pytorch.aten':
for attr in node.attribute:
# TODO: Is overload_name needed?
if attr.name == 'operator':
aten_op_name = attr.s.decode('utf-8') if isinstance(
attr.s, bytes) else attr.s
if aten_op_name in self.aten_op_dispatcher_:
known_aten_op = True
self.aten_op_dispatcher_[aten_op_name](node)
break
if self.verbose_ > 2:
logger.debug(node.op_type + ': ' + node.name)
for i, name in enumerate(node.input):
logger.debug(' Input {}: {} {}'.format(
i, name, 'initializer'
if name in self.initializers_ else ''))
            # onnx automatically merges dims with values, i.e. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb']
            # symbolic shape inference needs to apply the merge 'aaa' -> 1000 in this case
if node.op_type in [
'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger',
'MatMulInteger16', 'Where', 'Sum'
]:
vi = self.known_vi_[node.output[0]]
out_rank = len(get_shape_from_type_proto(vi.type))
in_shapes = [
self._get_shape(node, i) for i in range(len(node.input))
]
for d in range(out_rank - (2 if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16'
] else 0)):
in_dims = [
s[len(s) - out_rank + d] for s in in_shapes
if len(s) + d >= out_rank
]
if len(in_dims) > 1:
self._check_merged_dims(in_dims, allow_broadcast=True)
for i_o in range(len(node.output)):
vi = self.known_vi_[node.output[i_o]]
out_type = vi.type
out_type_kind = out_type.WhichOneof('value')
# do not process shape for non-tensors
if out_type_kind not in [
'tensor_type', 'sparse_tensor_type', None
]:
if self.verbose_ > 2:
if out_type_kind == 'sequence_type':
seq_cls_type = out_type.sequence_type.elem_type.WhichOneof(
'value')
if 'tensor_type' == seq_cls_type:
logger.debug(' {}: sequence of {} {}'.format(
node.output[i_o],
str(get_shape_from_value_info(vi)),
onnx.TensorProto.DataType.Name(
vi.type.sequence_type.elem_type.
tensor_type.elem_type)))
else:
logger.debug(' {}: sequence of {}'.format(
node.output[i_o], seq_cls_type))
else:
logger.debug(' {}: {}'.format(node.output[i_o],
out_type_kind))
continue
out_shape = get_shape_from_value_info(vi)
out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
if self.verbose_ > 2:
logger.debug(' {}: {} {}'.format(
node.output[i_o],
str(out_shape),
onnx.TensorProto.DataType.Name(
vi.type.tensor_type.elem_type)))
if node.output[i_o] in self.sympy_data_:
logger.debug(' Sympy Data: ' + str(self.sympy_data_[
node.output[i_o]]))
# onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain
if (out_shape is not None and
(None in out_shape or
self._is_shape_contains_none_dim(out_shape))
) or out_type_undefined:
if self.auto_merge_:
if node.op_type in [
'Add', 'Sub', 'Mul', 'Div', 'MatMul',
'MatMulInteger', 'MatMulInteger16', 'Concat',
'Where', 'Sum', 'Equal', 'Less', 'Greater',
'LessOrEqual', 'GreaterOrEqual'
]:
shapes = [
self._get_shape(node, i)
for i in range(len(node.input))
]
if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16'
]:
if None in out_shape or self._is_shape_contains_none_dim(
out_shape):
if None in out_shape:
idx = out_shape.index(None)
else:
idx = out_shape.index(
self._is_shape_contains_none_dim(
out_shape))
dim_idx = [
len(s) - len(out_shape) + idx
for s in shapes
]
# only support auto merge for MatMul for dim < rank-2 when rank > 2
assert len(
shapes[0]) > 2 and dim_idx[0] < len(
shapes[0]) - 2
assert len(
shapes[1]) > 2 and dim_idx[1] < len(
shapes[1]) - 2
elif node.op_type == 'Expand':
# auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq])
shapes = [
self._get_shape(node, 0), self._get_value(node,
1)
]
else:
shapes = []
if shapes:
for idx in range(len(out_shape)):
if out_shape[
idx] is not None and not self._is_none_dim(
out_shape[idx]):
continue
# note that the broadcasting rule aligns from right to left
# if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge
dim_idx = [
len(s) - len(out_shape) + idx
for s in shapes
]
if len(dim_idx) > 0:
self._add_suggested_merge([
s[i] if is_literal(s[i]) else str(s[i])
for s, i in zip(shapes, dim_idx)
if i >= 0
])
self.run_ = True
else:
self.run_ = False
else:
self.run_ = False
# create new dynamic dims for ops not handled by symbolic shape inference
                if not self.run_ and node.op_type not in self.dispatcher_ and not known_aten_op:
is_unknown_op = out_type_undefined and (
out_shape is None or len(out_shape) == 0)
if is_unknown_op:
# unknown op to ONNX, maybe from higher opset or other domain
# only guess the output rank from input 0 when using guess_output_rank option
out_rank = self._get_shape_rank(
node, 0) if self.guess_output_rank_ else -1
else:
# valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape
out_rank = len(out_shape)
if out_rank >= 0:
new_shape = self._new_symbolic_shape(out_rank, node,
i_o)
if out_type_undefined:
# guess output data type from input vi if not defined
out_dtype = self.known_vi_[node.input[
0]].type.tensor_type.elem_type
else:
# otherwise, use original data type
out_dtype = vi.type.tensor_type.elem_type
vi.CopyFrom(
helper.make_tensor_value_info(
vi.name, out_dtype,
get_shape_from_sympy_shape(new_shape)))
if self.verbose_ > 0:
if is_unknown_op:
logger.debug(
"Possible unknown op: {} node: {}, guessing {} shape".
format(node.op_type, node.name,
vi.name))
if self.verbose_ > 2:
logger.debug(' {}: {} {}'.format(
node.output[i_o],
str(new_shape),
vi.type.tensor_type.elem_type))
self.run_ = True
continue # continue the inference after guess, no need to stop as no merge is needed
if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
logger.debug(
'Stopping at incomplete shape inference at ' +
node.op_type + ': ' + node.name)
logger.debug('node inputs:')
for i in node.input:
logger.debug(self.known_vi_[i])
logger.debug('node outputs:')
for o in node.output:
logger.debug(self.known_vi_[o])
if self.auto_merge_ and not out_type_undefined:
logger.debug('Merging: ' + str(
self.suggested_merge_))
return False
self.run_ = False
return True
def _update_output_from_vi(self):
for output in self.out_mp_.graph.output:
if output.name in self.known_vi_:
output.CopyFrom(self.known_vi_[output.name])
@staticmethod
def infer_shapes(in_mp,
int_max=2**31 - 1,
auto_merge=False,
guess_output_rank=False,
verbose=0):
onnx_opset = get_opset(in_mp)
if (not onnx_opset) or onnx_opset < 7:
logger.warning('Only support models of onnx opset 7 and above.')
return None
symbolic_shape_inference = SymbolicShapeInference(
int_max, auto_merge, guess_output_rank, verbose)
all_shapes_inferred = False
symbolic_shape_inference._preprocess(in_mp)
while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl()
symbolic_shape_inference._update_output_from_vi()
if not all_shapes_inferred:
raise Exception("Incomplete symbolic shape inference")
return symbolic_shape_inference.out_mp_
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True, help='The input model file')
parser.add_argument('--output', help='The output model file')
parser.add_argument(
'--auto_merge',
        help='Automatically merge symbolic dims when conflicts happen',
action='store_true',
default=False)
parser.add_argument(
'--int_max',
help='maximum value for integer to be treated as boundless for ops like slice',
type=int,
default=2**31 - 1)
parser.add_argument(
'--guess_output_rank',
help='guess output rank to be the same as input 0 for unknown ops',
action='store_true',
default=False)
parser.add_argument(
'--verbose',
help='Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed',
type=int,
default=0)
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
logger.info('input model: ' + args.input)
if args.output:
logger.info('output model ' + args.output)
logger.info('Doing symbolic shape inference...')
out_mp = SymbolicShapeInference.infer_shapes(
onnx.load(args.input), args.int_max, args.auto_merge,
args.guess_output_rank, args.verbose)
if args.output and out_mp:
onnx.save(out_mp, args.output)
logger.info('Done!')
#!/bin/bash
set -e
if [ $# != 3 ];then
# ./local/onnx_opt.sh model.old.onnx model.opt.onnx "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
echo "usage: $0 onnx.model.in onnx.model.out input_shape "
exit 1
fi
# onnx optimizer
pip install onnx-simplifier
in=$1
out=$2
input_shape=$3
check_n=3
onnxsim $in $out $check_n --dynamic-input-shape --input-shape $input_shape
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# prune model by output names
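# example invocation (the script filename here is illustrative):
#   python3 prune_onnx_model.py --model in.onnx --output_names tensor_a tensor_b --save_file out.onnx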
import argparse
import copy
import sys
import onnx
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model',
required=True,
help='Path of directory saved the input model.')
parser.add_argument(
'--output_names',
required=True,
nargs='+',
help='The outputs of pruned model.')
parser.add_argument(
'--save_file', required=True, help='Path to save the new onnx model.')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
if len(set(args.output_names)) < len(args.output_names):
print(
"[ERROR] There's dumplicate name in --output_names, which is not allowed."
)
sys.exit(-1)
model = onnx.load(args.model)
# collect all node outputs and graph output
output_tensor_names = set()
for node in model.graph.node:
for out in node.output:
# may contain model output
output_tensor_names.add(out)
# for out in model.graph.output:
# output_tensor_names.add(out.name)
for output_name in args.output_names:
if output_name not in output_tensor_names:
print(
"[ERROR] Cannot find output tensor name '{}' in onnx model graph.".
format(output_name))
sys.exit(-1)
output_node_indices = set() # has output names
output_to_node = dict() # all node outputs
for i, node in enumerate(model.graph.node):
for out in node.output:
output_to_node[out] = i
if out in args.output_names:
output_node_indices.add(i)
# from outputs find all the ancestors
reserved_node_indices = copy.deepcopy(
output_node_indices) # nodes need to keep
reserved_inputs = set() # model input to keep
new_output_node_indices = copy.deepcopy(output_node_indices)
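# walk backwards from the requested outputs level by level, collecting every
# ancestor node and any graph inputs they consume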
while len(new_output_node_indices) > 0:
output_node_indices = copy.deepcopy(new_output_node_indices)
new_output_node_indices = set()
for out_node_idx in output_node_indices:
            # backtrack to the parent nodes
for ipt in model.graph.node[out_node_idx].input:
if ipt in output_to_node:
reserved_node_indices.add(output_to_node[ipt])
new_output_node_indices.add(output_to_node[ipt])
else:
reserved_inputs.add(ipt)
num_inputs = len(model.graph.input)
num_outputs = len(model.graph.output)
num_nodes = len(model.graph.node)
print(
f"old graph has {num_inputs} inputs, {num_outputs} outpus, {num_nodes} nodes"
)
print(f"{len(reserved_node_indices)} node to keep.")
# del node not to keep
for idx in range(num_nodes - 1, -1, -1):
if idx not in reserved_node_indices:
del model.graph.node[idx]
# del graph input not to keep
for idx in range(num_inputs - 1, -1, -1):
if model.graph.input[idx].name not in reserved_inputs:
del model.graph.input[idx]
# del old graph outputs
for i in range(num_outputs):
del model.graph.output[0]
# new graph outputs as requested by the user
for out in args.output_names:
model.graph.output.extend([onnx.ValueInfoProto(name=out)])
# infer shape
try:
from onnx_infer_shape import SymbolicShapeInference
model = SymbolicShapeInference.infer_shapes(
model,
int_max=2**31 - 1,
auto_merge=True,
guess_output_rank=False,
verbose=1)
except Exception as e:
print(f"skip infer shape step: {e}")
# check onnx model
onnx.checker.check_model(model)
# save onnx model
onnx.save(model, args.save_file)
print("[Finished] The new model saved in {}.".format(args.save_file))
print("[DEBUG INFO] The inputs of new model: {}".format(
[x.name for x in model.graph.input]))
print("[DEBUG INFO] The outputs of new model: {}".format(
[x.name for x in model.graph.output]))
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# rename node to new names
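# example invocation (the script filename here is illustrative):
#   python3 rename_onnx_model.py --model in.onnx --origin_names x0 x1 --new_names audio audio_len --save_file out.onnx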
import argparse
import sys
import onnx
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model',
required=True,
help='Path of directory saved the input model.')
parser.add_argument(
'--origin_names',
required=True,
nargs='+',
help='The original name you want to modify.')
parser.add_argument(
'--new_names',
required=True,
nargs='+',
        help='The new names to change to; the number of new_names must match the number of origin_names.'
)
parser.add_argument(
'--save_file', required=True, help='Path to save the new onnx model.')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
if len(set(args.origin_names)) < len(args.origin_names):
print(
"[ERROR] There's dumplicate name in --origin_names, which is not allowed."
)
sys.exit(-1)
if len(set(args.new_names)) < len(args.new_names):
print(
"[ERROR] There's dumplicate name in --new_names, which is not allowed."
)
sys.exit(-1)
if len(args.new_names) != len(args.origin_names):
print(
"[ERROR] Number of --new_names must be same with the number of --origin_names."
)
sys.exit(-1)
model = onnx.load(args.model)
# collect input and all node output
output_tensor_names = set()
for ipt in model.graph.input:
output_tensor_names.add(ipt.name)
for node in model.graph.node:
for out in node.output:
output_tensor_names.add(out)
for origin_name in args.origin_names:
if origin_name not in output_tensor_names:
print(
f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph."
)
sys.exit(-1)
for new_name in args.new_names:
if new_name in output_tensor_names:
            print(
                f"[ERROR] The defined new_name '{new_name}' already exists in the onnx model, which is not allowed."
            )
sys.exit(-1)
# rename graph input
for i, ipt in enumerate(model.graph.input):
if ipt.name in args.origin_names:
idx = args.origin_names.index(ipt.name)
model.graph.input[i].name = args.new_names[idx]
# rename node input and output
for i, node in enumerate(model.graph.node):
for j, ipt in enumerate(node.input):
if ipt in args.origin_names:
idx = args.origin_names.index(ipt)
model.graph.node[i].input[j] = args.new_names[idx]
for j, out in enumerate(node.output):
if out in args.origin_names:
idx = args.origin_names.index(out)
model.graph.node[i].output[j] = args.new_names[idx]
# rename graph output
for i, out in enumerate(model.graph.output):
if out.name in args.origin_names:
idx = args.origin_names.index(out.name)
model.graph.output[i].name = args.new_names[idx]
# check onnx model
onnx.checker.check_model(model)
# save model
onnx.save(model, args.save_file)
print("[Finished] The new model saved in {}.".format(args.save_file))
print("[DEBUG INFO] The inputs of new model: {}".format(
[x.name for x in model.graph.input]))
print("[DEBUG INFO] The outputs of new model: {}".format(
[x.name for x in model.graph.output]))
#!/usr/bin/env python3
import argparse
import onnxruntime as ort
# onnxruntime optimizer.
# https://onnxruntime.ai/docs/performance/graph-optimizations.html
# https://onnxruntime.ai/docs/api/python/api_summary.html#api
def parse_arguments():
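    # Example (as invoked from run.sh):
    # ./local/ort_opt.py --model_in exp/model.onnx --opt_level 0 \
    #     --model_out exp/model.ort.opt.onnx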
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_in', required=True, type=str, help='Path to onnx model.')
    parser.add_argument(
        '--opt_level',
        required=True,
        type=int,
        choices=[0, 1, 2],
        help='Graph optimization level.')
parser.add_argument(
'--model_out', required=True, help='path to save the optimized model.')
    parser.add_argument('--debug', action='store_true', help='Output debug info.')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
sess_options = ort.SessionOptions()
# Set graph optimization level
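    # mapping: 0 -> ORT_ENABLE_BASIC, 1 -> ORT_ENABLE_EXTENDED, 2 -> ORT_ENABLE_ALL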
print(f"opt level: {args.opt_level}")
if args.opt_level == 0:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
elif args.opt_level == 1:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
else:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    # setting this path makes onnxruntime serialize the optimized graph to disk
sess_options.optimized_model_filepath = args.model_out
session = ort.InferenceSession(args.model_in, sess_options)
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#2-%E4%BF%AE%E6%94%B9paddle%E6%A8%A1%E5%9E%8B%E8%BE%93%E5%85%A5shape
import argparse
# paddle inference shape
def process_old_ops_desc(program):
"""set matmul op head_number attr to 1 is not exist.
Args:
program (_type_): _description_
"""
for i in range(len(program.blocks[0].ops)):
if program.blocks[0].ops[i].type == "matmul":
if not program.blocks[0].ops[i].has_attr("head_number"):
program.blocks[0].ops[i]._set_attr("head_number", 1)
def infer_shape(program, input_shape_dict):
    # model_version is an integer like 2002002
    model_version = program.desc._version()
    # paddle_version is a string like '2.2.2'
    paddle_version = paddle.__version__
major_ver = model_version // 1000000
minor_ver = (model_version - major_ver * 1000000) // 1000
patch_ver = model_version - major_ver * 1000000 - minor_ver * 1000
model_version = "{}.{}.{}".format(major_ver, minor_ver, patch_ver)
if model_version != paddle_version:
        print(
            f"[WARNING] The model was saved by paddlepaddle v{model_version}, but the current paddlepaddle version is {paddle_version}; this mismatch may cause errors, so it is recommended to install the paddlepaddle version the model was saved with."
        )
OP_WITHOUT_KERNEL_SET = {
'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv',
'c_wait_comm', 'c_wait_compute', 'c_gen_hccl_id', 'c_comm_init_hccl',
'copy_cross_scope'
}
for k, v in input_shape_dict.items():
program.blocks[0].var(k).desc.set_shape(v)
    for i in range(len(program.blocks)):
        for j in range(len(program.blocks[i].ops)):
            # ops without kernels cannot run infer_shape
            if program.blocks[i].ops[j].type in OP_WITHOUT_KERNEL_SET:
                print(f"skip infer: {program.blocks[i].ops[j].type} op")
                continue
            print(f"infer: {program.blocks[i].ops[j].type} op")
            program.blocks[i].ops[j].desc.infer_shape(program.blocks[i].desc)
def parse_arguments():
# python pd_infer_shape.py --model_dir data/exp/deepspeech2_online/checkpoints \
# --model_filename avg_1.jit.pdmodel\
# --params_filename avg_1.jit.pdiparams \
# --save_dir . \
# --input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_dir',
required=True,
        help='Path of the directory containing the input model.')
parser.add_argument(
'--model_filename', required=True, help='model.pdmodel.')
parser.add_argument(
'--params_filename', required=True, help='model.pdiparams.')
parser.add_argument(
'--save_dir',
required=True,
help='directory to save the exported model.')
parser.add_argument(
'--input_shape_dict', required=True, help="The new shape information.")
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
import paddle
paddle.enable_static()
import paddle.fluid as fluid
    import ast
    input_shape_dict_str = args.input_shape_dict
    # parse the dict literal safely rather than with bare eval()
    input_shape_dict = ast.literal_eval(input_shape_dict_str)
print("Start to load paddle model...")
exe = fluid.Executor(fluid.CPUPlace())
prog, ipts, outs = fluid.io.load_inference_model(
args.model_dir,
exe,
model_filename=args.model_filename,
params_filename=args.params_filename)
process_old_ops_desc(prog)
infer_shape(prog, input_shape_dict)
fluid.io.save_inference_model(
args.save_dir,
ipts,
outs,
exe,
prog,
model_filename=args.model_filename,
params_filename=args.params_filename)
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#1-%E8%A3%81%E5%89%AApaddle%E6%A8%A1%E5%9E%8B
import argparse
import sys
from typing import List
# paddle prune model.
def prepend_feed_ops(program,
feed_target_names: List[str],
feed_holder_name='feed'):
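    """Prepend feed ops to the program in place; inputs missing from the pruned program are skipped.
    Args:
        program (Program): inference program
        feed_target_names (List[str]): input variable names
        feed_holder_name (str, optional): feed op name. Defaults to 'feed'.
    """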
import paddle.fluid.core as core
if len(feed_target_names) == 0:
return
global_block = program.global_block()
feed_var = global_block.create_var(
name=feed_holder_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True, )
for i, name in enumerate(feed_target_names, 0):
if not global_block.has_var(name):
            print(
                f"The input[{i}]: '{name}' doesn't exist in the pruned inference program and will be ignored in the saved model."
            )
continue
out = global_block.var(name)
global_block._prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i}, )
def append_fetch_ops(program,
fetch_target_names: List[str],
fetch_holder_name='fetch'):
"""in the place, we will add the fetch op
Args:
program (_type_): inference program
fetch_target_names (List[str]): target names
fetch_holder_name (str, optional): fetch op name. Defaults to 'fetch'.
"""
import paddle.fluid.core as core
global_block = program.global_block()
fetch_var = global_block.create_var(
name=fetch_holder_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True, )
print(f"the len of fetch_target_names: {len(fetch_target_names)}")
for i, name in enumerate(fetch_target_names):
global_block.append_op(
type='fetch',
inputs={'X': [name]},
outputs={'Out': [fetch_var]},
attrs={'col': i}, )
def insert_fetch(program,
fetch_target_names: List[str],
fetch_holder_name='fetch'):
"""in the place, we will add the fetch op
Args:
program (_type_): inference program
fetch_target_names (List[str]): target names
fetch_holder_name (str, optional): fetch op name. Defaults to 'fetch'.
"""
global_block = program.global_block()
# remove fetch
need_to_remove_op_index = []
for i, op in enumerate(global_block.ops):
if op.type == 'fetch':
need_to_remove_op_index.append(i)
for index in reversed(need_to_remove_op_index):
global_block._remove_op(index)
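    # flush so the op removals are synced into the underlying program desc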
program.desc.flush()
# append new fetch
append_fetch_ops(program, fetch_target_names, fetch_holder_name)
def parse_arguments():
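    # Example (mirrors local/prune.sh):
    # python local/pd_prune_model.py --model_dir data/exp/deepspeech2_online/checkpoints \
    #     --model_filename avg_1.jit.pdmodel --params_filename avg_1.jit.pdiparams \
    #     --output_names softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
    #     --save_dir .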
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_dir',
required=True,
        help='Path of the directory containing the input model.')
parser.add_argument(
'--model_filename', required=True, help='model.pdmodel.')
parser.add_argument(
'--params_filename', required=True, help='model.pdiparams.')
parser.add_argument(
'--output_names',
required=True,
        help='The model output names, separated by commas.')
parser.add_argument(
'--save_dir',
required=True,
help='directory to save the exported model.')
    parser.add_argument('--debug', action='store_true', help='Output debug info.')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
args.output_names = args.output_names.split(",")
if len(set(args.output_names)) < len(args.output_names):
print(
f"[ERROR] There's dumplicate name in --output_names {args.output_names}, which is not allowed."
)
sys.exit(-1)
import paddle
paddle.enable_static()
    # monkey-patch prepend_feed_ops so inputs pruned from the program are skipped
paddle.fluid.io.prepend_feed_ops = prepend_feed_ops
import paddle.fluid as fluid
print("start to load paddle model")
exe = fluid.Executor(fluid.CPUPlace())
prog, ipts, outs = fluid.io.load_inference_model(
args.model_dir,
exe,
model_filename=args.model_filename,
params_filename=args.params_filename)
print("start to load insert fetch op")
new_outputs = []
insert_fetch(prog, args.output_names)
for out_name in args.output_names:
new_outputs.append(prog.global_block().var(out_name))
    # note: not equivalent to paddle.static.save_inference_model
fluid.io.save_inference_model(
args.save_dir,
ipts,
new_outputs,
exe,
prog,
model_filename=args.model_filename,
params_filename=args.params_filename)
if args.debug:
for op in prog.global_block().ops:
print(op)
#!/bin/bash
set -e
if [ $# != 5 ]; then
# local/prune.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 $PWD
echo "usage: $0 model_dir model_filename param_filename outputs_names save_dir"
exit 1
fi
dir=$1
model=$2
param=$3
outputs=$4
save_dir=$5
python local/pd_prune_model.py \
--model_dir $dir \
--model_filename $model \
--params_filename $param \
--output_names $outputs \
--save_dir $save_dir
#!/bin/bash
if [ $# != 4 ];then
# local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx
echo "usage: $0 model_dir model_name param_name onnx_output_name"
exit 1
fi
dir=$1
model=$2
param=$3
output=$4
pip install paddle2onnx
pip install onnx
# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2
paddle2onnx --model_dir $dir \
--model_filename $model \
--params_filename $param \
--save_file $output \
--enable_dev_version True \
--opset_version 9 \
--enable_onnx_checker True
# This contains the locations of the binaries built and required for running the examples.
MAIN_ROOT=`realpath $PWD/../../../../`
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure the project was built successfully."; }
export LC_ALL=C
export PATH=$PATH:$TOOLS_BIN
#!/bin/bash
set -e
. path.sh
stage=0
stop_stage=50
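# stage control follows the usual kaldi-style convention; override from the
# command line via utils/parse_options.sh, e.g.: ./run.sh --stage 2 --stop_stage 2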
tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz
#tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz
model_prefix=avg_10.jit
#model_prefix=avg_1.jit
model=${model_prefix}.pdmodel
param=${model_prefix}.pdiparams
. utils/parse_options.sh
data=data
exp=exp
mkdir -p $data $exp
dir=$data/exp/deepspeech2_online/checkpoints
# wenetspeech or aishell
model_type=$(echo $tarfile | cut -d '_' -f 4)
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
test -f $data/$tarfile || wget -P $data -c https://paddlespeech.bj.bcebos.com/s2t/$model_type/asr0/$tarfile
# wenetspeech ds2 model
pushd $data
tar zxvf $tarfile
popd
# ds2 model demo inputs
pushd $exp
wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle
popd
fi
output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
# prune model by outputs
mkdir -p $exp/prune
# pruning keeps only the subgraph needed for output_names.
./local/prune.sh $dir $model $param $output_names $exp/prune
fi
# aishell rnn hidden is 1024
# wenetspeech rnn hidden is 2048
if [ $model_type == 'aishell' ];then
input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
elif [ $model_type == 'wenetspeech' ];then
input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 2048], 'chunk_state_h_box':[5,1,2048]}"
else
echo "not support: $model_type"
exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ];then
# infer shape by new shape
mkdir -p $exp/shape
echo $input_shape_dict
python3 local/pd_infer_shape.py \
--model_dir $dir \
--model_filename $model \
--params_filename $param \
--save_dir $exp/shape \
--input_shape_dict="${input_shape_dict}"
fi
input_file=$exp/static_ds2online_inputs.pickle
test -e $input_file
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
# to onnx
./local/tonnx.sh $dir $model $param $exp/model.onnx
./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.onnx
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] ;then
# ort graph optimization
./local/ort_opt.py --model_in $exp/model.onnx --opt_level 0 --model_out $exp/model.ort.opt.onnx
./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.ort.opt.onnx
fi
# aishell rnn hidden is 1024
# wenetspeech rnn hidden is 2048
if [ $model_type == 'aishell' ];then
input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
elif [ $model_type == 'wenetspeech' ];then
input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,2048 chunk_state_h_box:5,1,2048"
else
echo "not support: $model_type"
exit -1
fi
if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ] ;then
# the wenetspeech ds2 model exceeds the onnx 2GB size limit, so this step will error.
# simplify the onnx model
./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape"
./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.opt.onnx
fi
../../../../utils/
...@@ -15,7 +15,6 @@ Result:
========================================================================== test session starts ==========================================================================
platform linux -- Python 3.7.7, pytest-7.0.1, pluggy-1.0.0
benchmark: 3.4.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
-rootdir: /ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddleaudio
plugins: typeguard-2.12.1, benchmark-3.4.1, anyio-3.5.0
collected 4 items
......
...@@ -17,15 +17,17 @@ import urllib.request
import librosa
import numpy as np
import paddle
-import paddleaudio
import torch
import torchaudio
+import paddlespeech.audio
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
if not os.path.isfile(os.path.basename(wav_url)):
    urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
-waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
+waveform, sr = paddlespeech.audio.load(
+    os.path.abspath(os.path.basename(wav_url)))
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
...@@ -55,7 +57,7 @@ def enable_gpu_device():
    paddle.set_device('gpu')
-log_mel_extractor = paddleaudio.features.LogMelSpectrogram(
+log_mel_extractor = paddlespeech.audio.features.LogMelSpectrogram(
    **mel_conf, f_min=0.0, top_db=80.0, dtype=waveform_tensor.dtype)
...@@ -65,20 +67,20 @@ def log_melspectrogram():
def test_log_melspect_cpu(benchmark):
    enable_cpu_device()
-    feature_paddleaudio = benchmark(log_melspectrogram)
+    feature_audio = benchmark(log_melspectrogram)
    feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
    feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
    np.testing.assert_array_almost_equal(
-        feature_librosa, feature_paddleaudio, decimal=3)
+        feature_librosa, feature_audio, decimal=3)
def test_log_melspect_gpu(benchmark):
    enable_gpu_device()
-    feature_paddleaudio = benchmark(log_melspectrogram)
+    feature_audio = benchmark(log_melspectrogram)
    feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
    feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
    np.testing.assert_array_almost_equal(
-        feature_librosa, feature_paddleaudio, decimal=2)
+        feature_librosa, feature_audio, decimal=2)
mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
...@@ -102,11 +104,11 @@ def test_log_melspect_cpu_torchaudio(benchmark):
    waveform_tensor_torch = waveform_tensor_torch.to('cpu')
    amplitude_to_DB = amplitude_to_DB.to('cpu')
-    feature_paddleaudio = benchmark(log_melspectrogram_torchaudio)
+    feature_audio = benchmark(log_melspectrogram_torchaudio)
    feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
    feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
    np.testing.assert_array_almost_equal(
-        feature_librosa, feature_paddleaudio, decimal=3)
+        feature_librosa, feature_audio, decimal=3)
def test_log_melspect_gpu_torchaudio(benchmark):
......
...@@ -17,15 +17,17 @@ import urllib.request
import librosa
import numpy as np
import paddle
-import paddleaudio
import torch
import torchaudio
+import paddlespeech.audio
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
if not os.path.isfile(os.path.basename(wav_url)):
    urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
-waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
+waveform, sr = paddlespeech.audio.load(
+    os.path.abspath(os.path.basename(wav_url)))
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
...@@ -55,7 +57,7 @@ def enable_gpu_device():
    paddle.set_device('gpu')
-mel_extractor = paddleaudio.features.MelSpectrogram(
+mel_extractor = paddlespeech.audio.features.MelSpectrogram(
    **mel_conf, f_min=0.0, dtype=waveform_tensor.dtype)
...@@ -65,18 +67,18 @@ def melspectrogram():
def test_melspect_cpu(benchmark):
    enable_cpu_device()
-    feature_paddleaudio = benchmark(melspectrogram)
+    feature_audio = benchmark(melspectrogram)
    feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
    np.testing.assert_array_almost_equal(
-        feature_librosa, feature_paddleaudio, decimal=3)
+        feature_librosa, feature_audio, decimal=3)
def test_melspect_gpu(benchmark):
    enable_gpu_device()
-    feature_paddleaudio = benchmark(melspectrogram)
+    feature_audio = benchmark(melspectrogram)
    feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
    np.testing.assert_array_almost_equal(
-        feature_librosa, feature_paddleaudio, decimal=3)
+        feature_librosa, feature_audio, decimal=3)
mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
...@@ -91,10 +93,10 @@ def test_melspect_cpu_torchaudio(benchmark):
    global waveform_tensor_torch, mel_extractor_torchaudio
    mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
    waveform_tensor_torch = waveform_tensor_torch.to('cpu')
-    feature_paddleaudio = benchmark(melspectrogram_torchaudio)
+    feature_audio = benchmark(melspectrogram_torchaudio)
    feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
    np.testing.assert_array_almost_equal(
-        feature_librosa, feature_paddleaudio, decimal=3)
+        feature_librosa, feature_audio, decimal=3)
def test_melspect_gpu_torchaudio(benchmark):
......
...@@ -17,15 +17,17 @@ import urllib.request
import librosa
import numpy as np
import paddle
-import paddleaudio
import torch
import torchaudio
+import paddlespeech.audio
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
if not os.path.isfile(os.path.basename(wav_url)):
    urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
-waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
+waveform, sr = paddlespeech.audio.load(
+    os.path.abspath(os.path.basename(wav_url)))
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
...@@ -64,7 +66,7 @@ def enable_gpu_device():
    paddle.set_device('gpu')
-mfcc_extractor = paddleaudio.features.MFCC(
+mfcc_extractor = paddlespeech.audio.features.MFCC(
    **mfcc_conf, f_min=0.0, dtype=waveform_tensor.dtype)
...@@ -74,18 +76,18 @@ def mfcc():
def test_mfcc_cpu(benchmark):
    enable_cpu_device()
-    feature_paddleaudio = benchmark(mfcc)
+    feature_audio = benchmark(mfcc)
    feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
    np.testing.assert_array_almost_equal(
-        feature_librosa, feature_paddleaudio, decimal=3)
+        feature_librosa, feature_audio, decimal=3)
def test_mfcc_gpu(benchmark):
    enable_gpu_device()
-    feature_paddleaudio = benchmark(mfcc)
+    feature_audio = benchmark(mfcc)
    feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
    np.testing.assert_array_almost_equal(
-        feature_librosa, feature_paddleaudio, decimal=3)
+        feature_librosa, feature_audio, decimal=3)
del mel_conf_torchaudio['sample_rate']
...@@ -103,10 +105,10 @@ def test_mfcc_cpu_torchaudio(benchmark):
    mel_extractor_torchaudio = mfcc_extractor_torchaudio.to('cpu')
    waveform_tensor_torch = waveform_tensor_torch.to('cpu')
-    feature_paddleaudio = benchmark(mfcc_torchaudio)
+    feature_audio = benchmark(mfcc_torchaudio)
    feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
    np.testing.assert_array_almost_equal(
-        feature_librosa, feature_paddleaudio, decimal=3)
+        feature_librosa, feature_audio, decimal=3)
def test_mfcc_gpu_torchaudio(benchmark):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -16,16 +16,16 @@ import os
import unittest
import numpy as np
-import paddleaudio
import soundfile as sf
+import paddlespeech.audio
from ..base import BackendTest
class TestIO(BackendTest):
    def test_load_mono_channel(self):
        sf_data, sf_sr = sf.read(self.files[0])
-        pa_data, pa_sr = paddleaudio.load(
+        pa_data, pa_sr = paddlespeech.audio.load(
            self.files[0], normal=False, dtype='float64')
        self.assertEqual(sf_data.dtype, pa_data.dtype)
...@@ -35,7 +35,7 @@ class TestIO(BackendTest):
    def test_load_multi_channels(self):
        sf_data, sf_sr = sf.read(self.files[1])
        sf_data = sf_data.T  # Channel dim first
-        pa_data, pa_sr = paddleaudio.load(
+        pa_data, pa_sr = paddlespeech.audio.load(
            self.files[1], mono=False, normal=False, dtype='float64')
        self.assertEqual(sf_data.dtype, pa_data.dtype)
...@@ -49,7 +49,7 @@ class TestIO(BackendTest):
        pa_tmp_file = 'pa_tmp.wav'
        sf.write(sf_tmp_file, waveform, sr)
-        paddleaudio.save(waveform, sr, pa_tmp_file)
+        paddlespeech.audio.save(waveform, sr, pa_tmp_file)
        self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file))
        for file in [sf_tmp_file, pa_tmp_file]:
...@@ -62,7 +62,7 @@ class TestIO(BackendTest):
        pa_tmp_file = 'pa_tmp.wav'
        sf.write(sf_tmp_file, waveform.T, sr)
-        paddleaudio.save(waveform.T, sr, pa_tmp_file)
+        paddlespeech.audio.save(waveform.T, sr, pa_tmp_file)
        self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file))
        for file in [sf_tmp_file, pa_tmp_file]:
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -17,7 +17,8 @@ import urllib.request
import numpy as np
import paddle
-from paddleaudio import load
+from paddlespeech.audio import load
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
......
...@@ -15,9 +15,9 @@ import unittest
import numpy as np
import paddle
-from paddleaudio.functional.window import get_window
from .base import FeatTest
+from paddlespeech.audio.functional.window import get_window
from paddlespeech.s2t.transform.spectrogram import IStft
from paddlespeech.s2t.transform.spectrogram import Stft
......
...@@ -15,10 +15,10 @@ import unittest
import numpy as np
import paddle
-import paddleaudio
import torch
import torchaudio
+import paddlespeech.audio
from .base import FeatTest
...@@ -40,17 +40,17 @@ class TestKaldi(FeatTest):
            self.window_size, periodic=False,
            dtype=eval(f'torch.{self.dtype}')).pow(0.85)
-        p_hann_window = paddleaudio.functional.window.get_window(
+        p_hann_window = paddlespeech.audio.functional.window.get_window(
            'hann',
            self.window_size,
            fftbins=False,
            dtype=eval(f'paddle.{self.dtype}'))
-        p_hamm_window = paddleaudio.functional.window.get_window(
+        p_hamm_window = paddlespeech.audio.functional.window.get_window(
            'hamming',
            self.window_size,
            fftbins=False,
            dtype=eval(f'paddle.{self.dtype}'))
-        p_povey_window = paddleaudio.functional.window.get_window(
+        p_povey_window = paddlespeech.audio.functional.window.get_window(
            'hann',
            self.window_size,
            fftbins=False,
...@@ -63,7 +63,7 @@ class TestKaldi(FeatTest):
    def test_fbank(self):
        ta_features = torchaudio.compliance.kaldi.fbank(
            torch.from_numpy(self.waveform.astype(self.dtype)))
-        pa_features = paddleaudio.compliance.kaldi.fbank(
+        pa_features = paddlespeech.audio.compliance.kaldi.fbank(
            paddle.to_tensor(self.waveform.astype(self.dtype)))
        np.testing.assert_array_almost_equal(
            ta_features, pa_features, decimal=4)
...@@ -71,7 +71,7 @@ class TestKaldi(FeatTest):
    def test_mfcc(self):
        ta_features = torchaudio.compliance.kaldi.mfcc(
            torch.from_numpy(self.waveform.astype(self.dtype)))
-        pa_features = paddleaudio.compliance.kaldi.mfcc(
+        pa_features = paddlespeech.audio.compliance.kaldi.mfcc(
            paddle.to_tensor(self.waveform.astype(self.dtype)))
        np.testing.assert_array_almost_equal(
            ta_features, pa_features, decimal=4)
......
...@@ -16,10 +16,10 @@ import unittest
import librosa
import numpy as np
import paddle
-import paddleaudio
-from paddleaudio.functional.window import get_window
+import paddlespeech.audio
from .base import FeatTest
+from paddlespeech.audio.functional.window import get_window
class TestLibrosa(FeatTest):
...@@ -117,7 +117,7 @@ class TestLibrosa(FeatTest):
            htk=False,
            norm='slaney',
            dtype=self.waveform.dtype, )
-        feature_compliance = paddleaudio.compliance.librosa.compute_fbank_matrix(
+        feature_compliance = paddlespeech.audio.compliance.librosa.compute_fbank_matrix(
            sr=self.sr,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
...@@ -127,7 +127,7 @@ class TestLibrosa(FeatTest):
            norm='slaney',
            dtype=self.waveform.dtype, )
        x = paddle.to_tensor(self.waveform)
-        feature_functional = paddleaudio.functional.compute_fbank_matrix(
+        feature_functional = paddlespeech.audio.functional.compute_fbank_matrix(
            sr=self.sr,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
...@@ -156,8 +156,8 @@ class TestLibrosa(FeatTest):
            n_mels=self.n_mels,
            fmin=self.fmin)
-        # paddleaudio.compliance.librosa:
-        feature_compliance = paddleaudio.compliance.librosa.melspectrogram(
+        # paddlespeech.audio.compliance.librosa:
+        feature_compliance = paddlespeech.audio.compliance.librosa.melspectrogram(
            x=self.waveform,
            sr=self.sr,
            window_size=self.n_fft,
...@@ -166,10 +166,10 @@ class TestLibrosa(FeatTest):
            fmin=self.fmin,
            to_db=False)
-        # paddleaudio.features.layer
+        # paddlespeech.audio.features.layer
        x = paddle.to_tensor(
            self.waveform, dtype=paddle.float64).unsqueeze(0)  # Add batch dim.
-        feature_extractor = paddleaudio.features.MelSpectrogram(
+        feature_extractor = paddlespeech.audio.features.MelSpectrogram(
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
...@@ -198,8 +198,8 @@ class TestLibrosa(FeatTest):
            n_mels=self.n_mels,
            fmin=self.fmin)
        feature_librosa = librosa.power_to_db(feature_librosa, top_db=None)
-        # paddleaudio.compliance.librosa:
-        feature_compliance = paddleaudio.compliance.librosa.melspectrogram(
+        # paddlespeech.audio.compliance.librosa:
+        feature_compliance = paddlespeech.audio.compliance.librosa.melspectrogram(
            x=self.waveform,
            sr=self.sr,
            window_size=self.n_fft,
...@@ -207,10 +207,10 @@ class TestLibrosa(FeatTest):
            n_mels=self.n_mels,
            fmin=self.fmin)
-        # paddleaudio.features.layer
+        # paddlespeech.audio.features.layer
        x = paddle.to_tensor(
            self.waveform, dtype=paddle.float64).unsqueeze(0)  # Add batch dim.
-        feature_extractor = paddleaudio.features.LogMelSpectrogram(
+        feature_extractor = paddlespeech.audio.features.LogMelSpectrogram(
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
...@@ -243,8 +243,8 @@ class TestLibrosa(FeatTest):
            n_mels=self.n_mels,
            fmin=self.fmin)
-        # paddleaudio.compliance.librosa:
-        feature_compliance = paddleaudio.compliance.librosa.mfcc(
+        # paddlespeech.audio.compliance.librosa:
+        feature_compliance = paddlespeech.audio.compliance.librosa.mfcc(
            x=self.waveform,
            sr=self.sr,
            n_mfcc=self.n_mfcc,
...@@ -257,10 +257,10 @@ class TestLibrosa(FeatTest):
            fmin=self.fmin,
            top_db=self.top_db)
-        # paddleaudio.features.layer
+        # paddlespeech.audio.features.layer
        x = paddle.to_tensor(
            self.waveform, dtype=paddle.float64).unsqueeze(0)  # Add batch dim.
-        feature_extractor = paddleaudio.features.MFCC(
+        feature_extractor = paddlespeech.audio.features.MFCC(
            sr=self.sr,
            n_mfcc=self.n_mfcc,
            n_fft=self.n_fft,
......
...@@ -15,8 +15,8 @@ import unittest
import numpy as np
import paddle
-import paddleaudio
+import paddlespeech.audio
from .base import FeatTest
from paddlespeech.s2t.transform.spectrogram import LogMelSpectrogram
...@@ -33,8 +33,7 @@ class TestLogMelSpectrogram(FeatTest):
        ps_res = ps_melspect(self.waveform.T).squeeze(1).T
        x = paddle.to_tensor(self.waveform)
-        # the paddlespeech.s2t features conflate magnitude and power spectra
-        ps_melspect = paddleaudio.features.LogMelSpectrogram(
+        ps_melspect = paddlespeech.audio.features.LogMelSpectrogram(
            self.sr,
            self.n_fft,
            self.hop_length,
......
...@@ -15,8 +15,8 @@ import unittest
import numpy as np
import paddle
-import paddleaudio
+import paddlespeech.audio
from .base import FeatTest
from paddlespeech.s2t.transform.spectrogram import Spectrogram
...@@ -31,7 +31,7 @@ class TestSpectrogram(FeatTest):
        ps_res = ps_spect(self.waveform.T).squeeze(1).T  # Magnitude
        x = paddle.to_tensor(self.waveform)
-        pa_spect = paddleaudio.features.Spectrogram(
+        pa_spect = paddlespeech.audio.features.Spectrogram(
            self.n_fft, self.hop_length, power=1.0)
        pa_res = pa_spect(x).squeeze(0).numpy()
......
...@@ -15,9 +15,9 @@ import unittest
import numpy as np
import paddle
-from paddleaudio.functional.window import get_window
from .base import FeatTest
+from paddlespeech.audio.functional.window import get_window
from paddlespeech.s2t.transform.spectrogram import Stft
......
...@@ -22,6 +22,9 @@ paddlespeech asr --model deepspeech2online_wenetspeech --input ./zh.wav
paddlespeech asr --model deepspeech2online_aishell --input ./zh.wav
paddlespeech asr --model deepspeech2offline_librispeech --lang en --input ./en.wav
+# Support setting num_decoding_left_chunks
+paddlespeech asr --model conformer_online_wenetspeech --num_decoding_left_chunks 3 --input ./zh.wav
# long audio restriction
{
wget -c https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/test_long_audio_01.wav
......
...@@ -747,7 +747,7 @@ def num2chn(number_string,
                previous_symbol, (CNU, type(None))):
            if next_symbol.power != 1 and (
                    (previous_symbol is None) or
-                    (previous_symbol.power != 1)):
+                    (previous_symbol.power != 1)):  # noqa: E129
                result_symbols[i] = liang
    # if big is True, '两' will not be used and `alt_two` has no impact on output
......