Unverified · Commit 7cab869d · authored by Hui Zhang, committed by GitHub

Merge pull request #3197 from PaddlePaddle/speechx

[engine] merge speechx
@@ -136,7 +136,7 @@ pull_request_rules:
        add: ["Docker"]
  - name: "auto add label=Deployment"
    conditions:
-      - files~=^speechx/
+      - files~=^runtime/
    actions:
      label:
        add: ["Deployment"]
@@ -3,8 +3,12 @@ repos:
    rev: v0.16.0
    hooks:
      - id: yapf
-       files: \.py$
-       exclude: (?=third_party).*(\.py)$
+       name: yapf
+       language: python
+       entry: yapf
+       args: [-i, -vv]
+       types: [python]
+       exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$
- repo: https://github.com/pre-commit/pre-commit-hooks
  rev: a11d9314b22d8f8c7556443875b731ef05965464
@@ -31,7 +35,7 @@ repos:
        - --ignore=E501,E228,E226,E261,E266,E128,E402,W503
        - --builtins=G,request
        - --jobs=1
-       exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$
+       exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$
- repo : https://github.com/Lucas-C/pre-commit-hooks
  rev: v1.0.1
@@ -53,16 +57,16 @@ repos:
        entry: bash .pre-commit-hooks/clang-format.hook -i
        language: system
        files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
-       exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party/kaldi-native-fbank/csrc|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
+       exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
      - id: cpplint
        name: cpplint
        description: Static code analysis of C/C++ files
        language: python
        files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
-       exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party/kaldi-native-fbank/csrc|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
+       exclude: (?=runtime/engine/kaldi|runtime/engine/common/matrix|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
        entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent
- repo: https://github.com/asottile/reorder_python_imports
  rev: v2.4.0
  hooks:
    - id: reorder-python-imports
-     exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$
+     exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$
@@ -193,7 +193,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
 - 👑 2022.11.18: Add [Whisper CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), support multi language recognition and translation.
 - 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](./demos/speech_ssl), Support ASR and Feature Extraction.
 - 🎉 2022.11.17: Add [male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660).
-- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](./speechx/examples/u2pp_ol/wenetspeech).
+- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/runtime/examples/u2pp_ol/wenetspeech).
 - 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3).
 - 🔥 2022.10.26: Add [Prosody Prediction](./examples/other/rhy) for TTS.
 - 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.
......
engine/common/base/flags.h
engine/common/base/log.h
tools/valgrind*
*log
fc_patch/*
test
# CMake >= 3.17 supports -DCMAKE_FIND_DEBUG_MODE=ON
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0077.cmake")
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
include(system)
project(paddlespeech VERSION 0.1)
set(PPS_VERSION_MAJOR 1)
set(PPS_VERSION_MINOR 0)
set(PPS_VERSION_PATCH 0)
set(PPS_VERSION "${PPS_VERSION_MAJOR}.${PPS_VERSION_MINOR}.${PPS_VERSION_PATCH}")
# compiler option
# Keep consistent with openfst: -fPIC or -fpic
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_FIND_DEBUG_MODE OFF)
set(PPS_CXX_STANDARD 14)
# set std-14
set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
# Ninja Generator will set CMAKE_BUILD_TYPE to Debug
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" FORCE)
endif()
# Make find_* (e.g. find_library) work when cross-compiling
if(ANDROID)
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
endif()
if(BUILD_IN_MACOS)
add_definitions("-DOS_MACOSX")
endif()
# install dir into `build/install`
set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/install)
include(FetchContent)
include(ExternalProject)
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
###############################################################################
# Option Configurations
###############################################################################
# https://github.com/google/brotli/pull/655
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_PPS_DEBUG "debug option" OFF)
if (WITH_PPS_DEBUG)
add_definitions("-DPPS_DEBUG")
endif()
option(WITH_ASR "build asr" ON)
option(WITH_CLS "build cls" ON)
option(WITH_VAD "build vad" ON)
option(WITH_GPU "NNet using GPU." OFF)
option(WITH_PROFILING "enable c++ profiling" OFF)
option(WITH_TESTING "unit test" ON)
option(WITH_ONNX "u2 support onnx runtime" OFF)
###############################################################################
# Include Third Party
###############################################################################
include(gflags)
include(glog)
include(pybind)
#onnx
if(WITH_ONNX)
add_definitions(-DUSE_ONNX)
endif()
# gtest
if(WITH_TESTING)
include(gtest) # download, build, install gtest
endif()
# fastdeploy
include(fastdeploy)
if(WITH_ASR)
# openfst
include(openfst)
add_dependencies(openfst gflags extern_glog)
endif()
###############################################################################
# Find Package
###############################################################################
# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207
find_package(Threads REQUIRED)
if(WITH_ASR)
# https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3
find_package(Python3 COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG)
if(Python3_FOUND)
message(STATUS "Python3_FOUND = ${Python3_FOUND}")
message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}")
message(STATUS "Python3_LIBRARIES = ${Python3_LIBRARIES}")
message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}")
message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}")
set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE)
set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE)
endif()
message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}")
include_directories(${PYTHON_INCLUDE_DIR})
if(pybind11_FOUND)
message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}")
message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}")
message(STATUS "pybind11_DEFINITIONS=${pybind11_DEFINITIONS}")
endif()
# paddle libpaddle.so
# paddle include and link option
# -L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so
set(EXECUTE_COMMAND "import os"
"import paddle"
"include_dir = paddle.sysconfig.get_include()"
"paddle_dir=os.path.split(include_dir)[0]"
"libs_dir=os.path.join(paddle_dir, 'libs')"
"fluid_dir=os.path.join(paddle_dir, 'fluid')"
"out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir])"
"out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out)"
)
execute_process(
COMMAND python -c "${EXECUTE_COMMAND}"
OUTPUT_VARIABLE PADDLE_LINK_FLAGS
RESULT_VARIABLE SUCCESS)
message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS})
string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS)
# paddle compile option
# -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include
set(EXECUTE_COMMAND "import paddle"
"include_dir = paddle.sysconfig.get_include()"
"print(f\"-I{include_dir}\")"
)
execute_process(
COMMAND python -c "${EXECUTE_COMMAND}"
OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS)
message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS})
string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
# for LD_LIBRARY_PATH
# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
set(EXECUTE_COMMAND "import os"
"import paddle"
"include_dir=paddle.sysconfig.get_include()"
"paddle_dir=os.path.split(include_dir)[0]"
"libs_dir=os.path.join(paddle_dir, 'libs')"
"fluid_dir=os.path.join(paddle_dir, 'fluid')"
"out=':'.join([libs_dir, fluid_dir]); print(out)"
)
execute_process(
COMMAND python -c "${EXECUTE_COMMAND}"
OUTPUT_VARIABLE PADDLE_LIB_DIRS)
message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS})
endif()
include(summary)
###############################################################################
# Add local library
###############################################################################
set(ENGINE_ROOT ${CMAKE_SOURCE_DIR}/engine)
add_subdirectory(engine)
###############################################################################
# CPack library
###############################################################################
# build a CPack driven installer package
include (InstallRequiredSystemLibraries)
set(CPACK_PACKAGE_NAME "paddlespeech_library")
set(CPACK_PACKAGE_VENDOR "paddlespeech")
set(CPACK_PACKAGE_VERSION_MAJOR 1)
set(CPACK_PACKAGE_VERSION_MINOR 0)
set(CPACK_PACKAGE_VERSION_PATCH 0)
set(CPACK_PACKAGE_DESCRIPTION "paddlespeech library")
set(CPACK_PACKAGE_CONTACT "paddlespeech@baidu.com")
set(CPACK_SOURCE_GENERATOR "TGZ")
include (CPack)
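The option blocks above communicate with the sources purely through compile definitions (`PPS_DEBUG`, `USE_ONNX`, `OS_MACOSX`, ...). As a hedged orientation sketch (the function name is hypothetical; glog is the logging library this build pulls in via `cmake/glog.cmake`), consuming them looks like:

```cpp
#include <glog/logging.h>

// Hypothetical helper: reports which build options were compiled in.
void LogBuildOptions() {
#ifdef USE_ONNX   // set when WITH_ONNX=ON
    LOG(INFO) << "u2 onnxruntime backend enabled";
#endif
#ifdef PPS_DEBUG  // set when WITH_PPS_DEBUG=ON
    LOG(INFO) << "engine debug diagnostics enabled";
#endif
#ifdef OS_MACOSX  // set when BUILD_IN_MACOS=ON
    LOG(INFO) << "building for macOS/iOS";
#endif
}
```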
# SpeechX -- All in One Speech Task Inference
## Environment
@@ -9,7 +8,7 @@ We develop under:
 * gcc/g++/gfortran - 8.2.0
 * cmake - 3.16.0
-> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build speechx.
+> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build engine.
 > We make sure all things work fine under docker, and recommend using it to develop and deploy.
@@ -33,7 +32,7 @@ docker run --privileged --net=host --ipc=host -it --rm -v /path/to/paddlespeech
 bash tools/venv.sh
 ```
-2. Build `speechx` and `examples`.
+2. Build `engine` and `examples`.
 For now we are using features under the `develop` branch of paddle, so we need to install the `paddlepaddle` nightly build version.
 For example:
@@ -113,3 +112,11 @@ apt-get install gfortran-8
 4. `Undefined reference to '_gfortran_concat_string'`
    Use gcc 8.2 and gfortran 8.2.
5. `./boost/python/detail/wrap_python.hpp:57:11: fatal error: pyconfig.h: No such file or directory`
```
apt-get install python3-dev
```
For more info, please see [here](https://github.com/okfn/piati/issues/65).
#!/usr/bin/env bash
set -xe
BUILD_ROOT=build/Linux
BUILD_DIR=${BUILD_ROOT}/x86_64
mkdir -p ${BUILD_DIR}
BUILD_TYPE=Release
#BUILD_TYPE=Debug
BUILD_SO=OFF
BUILD_ONNX=ON
BUILD_ASR=ON
BUILD_CLS=ON
BUILD_VAD=ON
PPS_DEBUG=OFF
FASTDEPLOY_INSTALL_DIR=""
# This build script has been verified in the PaddlePaddle docker image.
# Please follow the instructions below to install the PaddlePaddle docker image.
# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
#cmake -B build -DBUILD_SHARED_LIBS=OFF -DWITH_ASR=OFF -DWITH_CLS=OFF -DWITH_VAD=ON -DFASTDEPLOY_INSTALL_DIR=/workspace/zhanghui/paddle/FastDeploy/build/Android/arm64-v8a-api-21/install
cmake -B ${BUILD_DIR} \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DBUILD_SHARED_LIBS=${BUILD_SO} \
-DWITH_ONNX=${BUILD_ONNX} \
-DWITH_ASR=${BUILD_ASR} \
-DWITH_CLS=${BUILD_CLS} \
-DWITH_VAD=${BUILD_VAD} \
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
-DWITH_PPS_DEBUG=${PPS_DEBUG}
cmake --build ${BUILD_DIR} -j
#!/bin/bash
set -ex
ANDROID_NDK=/mnt/masimeng/workspace/software/android-ndk-r25b/
# Setting up the Android toolchain
ANDROID_ABI=arm64-v8a # 'arm64-v8a', 'armeabi-v7a'
ANDROID_PLATFORM="android-21" # API >= 21
ANDROID_STL=c++_shared # 'c++_shared', 'c++_static'
ANDROID_TOOLCHAIN=clang # 'clang' only
TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
# Create build directory
BUILD_ROOT=build/Android
BUILD_DIR=${BUILD_ROOT}/${ANDROID_ABI}-api-21
FASTDEPLOY_INSTALL_DIR="/mnt/masimeng/workspace/FastDeploy/build/Android/arm64-v8a-api-21/install"
mkdir -p ${BUILD_DIR}
cd ${BUILD_DIR}
# CMake configuration with Android toolchain
cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DANDROID_ABI=${ANDROID_ABI} \
-DANDROID_NDK=${ANDROID_NDK} \
-DANDROID_PLATFORM=${ANDROID_PLATFORM} \
-DANDROID_STL=${ANDROID_STL} \
-DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \
-DBUILD_SHARED_LIBS=OFF \
-DWITH_ASR=OFF \
-DWITH_CLS=OFF \
-DWITH_VAD=ON \
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
-DCMAKE_FIND_DEBUG_MODE=OFF \
-Wno-dev ../../..
# Build FastDeploy Android C++ SDK
make
# https://www.jianshu.com/p/33672fb819f5
PATH="/Applications/CMake.app/Contents/bin":"$PATH"
tools_dir=$1
ios_toolchain_cmake=${tools_dir}/"/ios-cmake-4.2.0/ios.toolchain.cmake"
fastdeploy_dir=${tools_dir}"/fastdeploy-ort-mac-build/"
build_targets=("OS64")
build_type_array=("Release")
#static_name="libocr"
#lib_name="libocr"
# Switch to work path
current_path=`cd $(dirname $0);pwd`
work_path=${current_path}/
build_path=${current_path}/build/
output_path=${current_path}/output/
cd ${work_path}
# Clean
rm -rf ${build_path}
rm -rf ${output_path}
if [ "$1"x = "clean"x ]; then
exit 0
fi
# Build Every Target
for target in "${build_targets[@]}"
do
for build_type in "${build_type_array[@]}"
do
echo -e "\033[1;36;40mBuilding ${build_type} ${target} ... \033[0m"
target_build_path=${build_path}/${target}/${build_type}/
mkdir -p ${target_build_path}
cd ${target_build_path}
if [ $? -ne 0 ];then
echo -e "\033[1;31;40mcd ${target_build_path} failed \033[0m"
exit -1
fi
if [ ${target} == "OS64" ];then
fastdeploy_install_dir=${fastdeploy_dir}/arm64
else
fastdeploy_install_dir=""
echo "fastdeploy_install_dir is null"
exit -1
fi
cmake -DCMAKE_TOOLCHAIN_FILE=${ios_toolchain_cmake} \
-DBUILD_IN_MACOS=ON \
-DBUILD_SHARED_LIBS=OFF \
-DWITH_ASR=OFF \
-DWITH_CLS=OFF \
-DWITH_VAD=ON \
-DFASTDEPLOY_INSTALL_DIR=${fastdeploy_install_dir} \
-DPLATFORM=${target} ../../../
cmake --build . --config ${build_type}
mkdir output
cp engine/vad/interface/libpps_vad_interface.a output
cp engine/vad/interface/vad_interface_main.app/vad_interface_main output
cp ${fastdeploy_install_dir}/lib/libfastdeploy.dylib output
cp ${fastdeploy_install_dir}/third_libs/install/onnxruntime/lib/libonnxruntime.dylib output
done
done
## combine all ios libraries
#DEVROOT=/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/
#LIPO_TOOL=${DEVROOT}/usr/bin/lipo
#LIBRARY_PATH=${build_path}
#LIBRARY_OUTPUT_PATH=${output_path}/IOS
#mkdir -p ${LIBRARY_OUTPUT_PATH}
#
#${LIPO_TOOL} \
# -arch i386 ${LIBRARY_PATH}/ios_x86/Release/${lib_name}.a \
# -arch x86_64 ${LIBRARY_PATH}/ios_x86_64/Release/${lib_name}.a \
# -arch armv7 ${LIBRARY_PATH}/ios_armv7/Release/${lib_name}.a \
# -arch armv7s ${LIBRARY_PATH}/ios_armv7s/Release/${lib_name}.a \
# -arch arm64 ${LIBRARY_PATH}/ios_armv8/Release/${lib_name}.a \
# -output ${LIBRARY_OUTPUT_PATH}/${lib_name}.a -create
#
#cp ${work_path}/lib/houyi/lib/ios/libhouyi_score.a ${LIBRARY_OUTPUT_PATH}/
#cp ${work_path}/interface/ocr-interface.h ${output_path}
#cp ${work_path}/version/release.v ${output_path}
#
#echo -e "\033[1;36;40mBuild All Target Success At:\n${output_path}\033[0m"
#exit 0
cmake_policy(SET CMP0077 NEW)
include(FetchContent)
set(EXTERNAL_PROJECT_LOG_ARGS
LOG_DOWNLOAD 1 # Wrap download in script to log output
LOG_UPDATE 1 # Wrap update in script to log output
LOG_PATCH 1
LOG_CONFIGURE 1 # Wrap configure in script to log output
LOG_BUILD 1 # Wrap build in script to log output
LOG_INSTALL 1
LOG_TEST 1 # Wrap test in script to log output
LOG_MERGED_STDOUTERR 1
LOG_OUTPUT_ON_FAILURE 1
)
if(NOT FASTDEPLOY_INSTALL_DIR)
if(ANDROID)
FetchContent_Declare(
fastdeploy
URL https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.4-shared.tgz
URL_HASH MD5=2a15301158e9eb157a4f11283689e7ba
${EXTERNAL_PROJECT_LOG_ARGS}
)
add_definitions("-DUSE_PADDLE_LITE_BAKEND")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
else() # Linux
FetchContent_Declare(
fastdeploy
URL https://paddlespeech.bj.bcebos.com/speechx/fastdeploy/fastdeploy-1.0.5-x86_64-onnx.tar.gz
URL_HASH MD5=33900d986ea71aa78635e52f0733227c
${EXTERNAL_PROJECT_LOG_ARGS}
)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3")
endif()
FetchContent_MakeAvailable(fastdeploy)
set(FASTDEPLOY_INSTALL_DIR ${fc_patch}/fastdeploy-src)
endif()
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# fix compiler flags conflict, since fastdeploy uses c++11 for its project
# this line must come after `include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)`
set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
include_directories(${FASTDEPLOY_INCS})
# install fastdeploy and dependents lib
# install_fastdeploy_libraries(${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
# No dynamic libs need to be installed when using the
# FastDeploy static lib.
if(ANDROID AND WITH_ANDROID_STATIC_LIB)
return()
endif()
set(DYN_LIB_SUFFIX "*.so*")
if(WIN32)
set(DYN_LIB_SUFFIX "*.dll")
elseif(APPLE)
set(DYN_LIB_SUFFIX "*.dylib*")
endif()
if(FastDeploy_DIR)
set(DYN_SEARCH_DIR ${FastDeploy_DIR})
elseif(FASTDEPLOY_INSTALL_DIR)
set(DYN_SEARCH_DIR ${FASTDEPLOY_INSTALL_DIR})
else()
message(FATAL_ERROR "Please set FastDeploy_DIR/FASTDEPLOY_INSTALL_DIR before calling install_fastdeploy_libraries.")
endif()
file(GLOB_RECURSE ALL_NEED_DYN_LIBS ${DYN_SEARCH_DIR}/lib/${DYN_LIB_SUFFIX})
file(GLOB_RECURSE ALL_DEPS_DYN_LIBS ${DYN_SEARCH_DIR}/third_libs/${DYN_LIB_SUFFIX})
if(ENABLE_VISION)
# OpenCV
if(ANDROID)
file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${DYN_LIB_SUFFIX})
else()
file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_DIR}/../../${DYN_LIB_SUFFIX})
endif()
list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_OPENCV_DYN_LIBS})
if(WIN32)
file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/x64/vc15/bin/${DYN_LIB_SUFFIX})
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
elseif(ANDROID AND (NOT WITH_ANDROID_OPENCV_STATIC))
file(GLOB OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${ANDROID_ABI}/${DYN_LIB_SUFFIX})
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
else() # linux/mac
file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/lib/${DYN_LIB_SUFFIX})
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
endif()
# FlyCV
if(ENABLE_FLYCV)
file(GLOB_RECURSE ALL_FLYCV_DYN_LIBS ${FLYCV_LIB_DIR}/${DYN_LIB_SUFFIX})
list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_FLYCV_DYN_LIBS})
if(ANDROID AND (NOT WITH_ANDROID_FLYCV_STATIC))
install(FILES ${ALL_FLYCV_DYN_LIBS} DESTINATION lib)
endif()
endif()
endif()
if(ENABLE_OPENVINO_BACKEND)
# need plugins.xml for openvino backend
set(OPENVINO_RUNTIME_BIN_DIR ${OPENVINO_DIR}/bin)
file(GLOB OPENVINO_PLUGIN_XML ${OPENVINO_RUNTIME_BIN_DIR}/*.xml)
install(FILES ${OPENVINO_PLUGIN_XML} DESTINATION lib)
endif()
# Install other libraries
install(FILES ${ALL_NEED_DYN_LIBS} DESTINATION lib)
install(FILES ${ALL_DEPS_DYN_LIBS} DESTINATION lib)
@@ -2,10 +2,13 @@ include(FetchContent)
 FetchContent_Declare(
   gflags
-  URL https://github.com/gflags/gflags/archive/v2.2.2.zip
+  URL https://paddleaudio.bj.bcebos.com/build/gflag-2.2.2.zip
   URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
 )
 FetchContent_MakeAvailable(gflags)
 # openfst need
 include_directories(${gflags_BINARY_DIR}/include)
+link_directories(${gflags_BINARY_DIR})
+#install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib)
include(FetchContent)
if(ANDROID)
else() # UNIX
add_definitions(-DWITH_GLOG)
FetchContent_Declare(
glog
URL https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DWITH_GFLAGS=OFF
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
)
set(BUILD_TESTING OFF)
FetchContent_MakeAvailable(glog)
include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
endif()
if(ANDROID)
add_library(extern_glog INTERFACE)
add_dependencies(extern_glog gflags)
else() # UNIX
add_library(extern_glog ALIAS glog)
add_dependencies(glog gflags)
endif()
\ No newline at end of file
include(FetchContent)
if(ANDROID)
else() # UNIX
FetchContent_Declare(
gtest
URL https://paddleaudio.bj.bcebos.com/build/gtest-release-1.11.0.zip
URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
)
FetchContent_MakeAvailable(gtest)
include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src)
endif()
if(ANDROID)
add_library(extern_gtest INTERFACE)
else() # UNIX
add_dependencies(gtest gflags extern_glog)
add_library(extern_gtest ALIAS gtest)
endif()
if(WITH_TESTING)
enable_testing()
endif()
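With `WITH_TESTING=ON` the top-level `CMakeLists.txt` calls `include(gtest)` and `enable_testing()`, so targets can link `extern_gtest`. A minimal sketch of a test that would build against it (test name and assertion are illustrative, not from the repo):

```cpp
#include <gtest/gtest.h>

// Illustrative unit test; link the target against extern_gtest.
TEST(EngineSmoke, BlankIdDefaultsToZero) {
    const int blank = 0;  // mirrors CTCBeamSearchOptions' default
    EXPECT_EQ(blank, 0);
}

int main(int argc, char** argv) {
    ::testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}
```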
+include(FetchContent)
 set(openfst_PREFIX_DIR ${fc_patch}/openfst)
 set(openfst_SOURCE_DIR ${fc_patch}/openfst-src)
 set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
-include(FetchContent)
 # openfst Acknowledgments:
 #Cyril Allauzen, Michael Riley, Johan Schalkwyk, Wojciech Skut and Mehryar Mohri,
 #"OpenFst: A General and Efficient Weighted Finite-State Transducer Library",
@@ -10,18 +10,33 @@ set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
 #Application of Automata, (CIAA 2007), volume 4783 of Lecture Notes in
 #Computer Science, pages 11-23. Springer, 2007. http://www.openfst.org.
+set(EXTERNAL_PROJECT_LOG_ARGS
+    LOG_DOWNLOAD 1  # Wrap download in script to log output
+    LOG_UPDATE 1    # Wrap update in script to log output
+    LOG_CONFIGURE 1 # Wrap configure in script to log output
+    LOG_BUILD 1     # Wrap build in script to log output
+    LOG_TEST 1      # Wrap test in script to log output
+    LOG_INSTALL 1   # Wrap install in script to log output
+)
 ExternalProject_Add(openfst
   URL https://paddleaudio.bj.bcebos.com/build/openfst_1.7.2.zip
   URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
+  ${EXTERNAL_PROJECT_LOG_ARGS}
   PREFIX ${openfst_PREFIX_DIR}
   SOURCE_DIR ${openfst_SOURCE_DIR}
   BINARY_DIR ${openfst_BINARY_DIR}
+  BUILD_ALWAYS 0
   CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
     "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
     "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
-    "LIBS=-lgflags_nothreads -lglog -lpthread"
+    "LIBS=-lgflags_nothreads -lglog -lpthread -fPIC"
   COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
   BUILD_COMMAND make -j 4
 )
 link_directories(${openfst_PREFIX_DIR}/lib)
 include_directories(${openfst_PREFIX_DIR}/include)
+message(STATUS "OpenFST inc dir: ${openfst_PREFIX_DIR}/include")
+message(STATUS "OpenFST lib dir: ${openfst_PREFIX_DIR}/lib")
# pybind11 is from: https://github.com/pybind/pybind11
# Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>, All rights reserved.
SET(PYBIND_ZIP "v2.10.0.zip")
SET(LOCAL_PYBIND_ZIP ${FETCHCONTENT_BASE_DIR}/${PYBIND_ZIP})
SET(PYBIND_SRC ${FETCHCONTENT_BASE_DIR}/pybind11)
SET(DOWNLOAD_URL "https://paddleaudio.bj.bcebos.com/build/v2.10.0.zip")
SET(PYBIND_TIMEOUT 600 CACHE STRING "Timeout in seconds when downloading pybind.")
IF(NOT EXISTS ${LOCAL_PYBIND_ZIP})
FILE(DOWNLOAD ${DOWNLOAD_URL}
${LOCAL_PYBIND_ZIP}
TIMEOUT ${PYBIND_TIMEOUT}
STATUS ERR
SHOW_PROGRESS
)
IF(ERR EQUAL 0)
MESSAGE(STATUS "download pybind success")
ELSE()
MESSAGE(FATAL_ERROR "download pybind fail")
ENDIF()
ENDIF()
IF(NOT EXISTS ${PYBIND_SRC})
EXECUTE_PROCESS(
COMMAND ${CMAKE_COMMAND} -E tar xfz ${LOCAL_PYBIND_ZIP}
WORKING_DIRECTORY ${FETCHCONTENT_BASE_DIR}
RESULT_VARIABLE tar_result
)
file(RENAME ${FETCHCONTENT_BASE_DIR}/pybind11-2.10.0 ${PYBIND_SRC})
IF (tar_result MATCHES 0)
MESSAGE(STATUS "unzip pybind success")
ELSE()
MESSAGE(FATAL_ERROR "unzip pybind fail")
ENDIF()
ENDIF()
include_directories(${PYBIND_SRC}/include)
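pybind11 is used header-only here (`include_directories(${PYBIND_SRC}/include)`); the ASR build combines it with the Paddle compile/link flags probed earlier. A minimal hypothetical module against these headers (module and function names are illustrative):

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Illustrative binding; the real engine bindings live in the sources.
int Add(int a, int b) { return a + b; }

PYBIND11_MODULE(pps_example, m) {
    m.doc() = "toy module built against the pybind11 headers above";
    m.def("add", &Add, "add two ints");
}
```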
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(pps_summary)
message(STATUS "")
message(STATUS "*************PaddleSpeech Building Summary**********")
message(STATUS " PPS_VERSION : ${PPS_VERSION}")
message(STATUS " CMake version : ${CMAKE_VERSION}")
message(STATUS " CMake command : ${CMAKE_COMMAND}")
message(STATUS " UNIX : ${UNIX}")
message(STATUS " ANDROID : ${ANDROID}")
message(STATUS " System : ${CMAKE_SYSTEM_NAME}")
message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}")
message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}")
message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}")
message(STATUS " Build type : ${CMAKE_BUILD_TYPE}")
message(STATUS " BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}")
get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS)
message(STATUS " Compile definitions : ${tmp}")
message(STATUS " CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}")
message(STATUS " CMAKE_CURRENT_BINARY_DIR : ${CMAKE_CURRENT_BINARY_DIR}")
message(STATUS " CMAKE_INSTALL_PREFIX : ${CMAKE_INSTALL_PREFIX}")
message(STATUS " CMAKE_INSTALL_LIBDIR : ${CMAKE_INSTALL_LIBDIR}")
message(STATUS " CMAKE_MODULE_PATH : ${CMAKE_MODULE_PATH}")
message(STATUS " CMAKE_SYSTEM_NAME : ${CMAKE_SYSTEM_NAME}")
message(STATUS "")
message(STATUS " WITH_ASR : ${WITH_ASR}")
message(STATUS " WITH_CLS : ${WITH_CLS}")
message(STATUS " WITH_VAD : ${WITH_VAD}")
message(STATUS " WITH_GPU : ${WITH_GPU}")
message(STATUS " WITH_TESTING : ${WITH_TESTING}")
message(STATUS " WITH_PROFILING : ${WITH_PROFILING}")
message(STATUS " FASTDEPLOY_INSTALL_DIR : ${FASTDEPLOY_INSTALL_DIR}")
message(STATUS " FASTDEPLOY_INCS : ${FASTDEPLOY_INCS}")
message(STATUS " FASTDEPLOY_LIBS : ${FASTDEPLOY_LIBS}")
if(WITH_GPU)
message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}")
endif()
if(ANDROID)
message(STATUS " ANDROID_ABI : ${ANDROID_ABI}")
message(STATUS " ANDROID_PLATFORM : ${ANDROID_PLATFORM}")
message(STATUS " ANDROID_NDK : ${ANDROID_NDK}")
message(STATUS " ANDROID_NDK_VERSION : ${CMAKE_ANDROID_NDK_VERSION}")
endif()
if (WITH_ASR)
message(STATUS " Python executable : ${PYTHON_EXECUTABLE}")
message(STATUS " Python includes : ${PYTHON_INCLUDE_DIR}")
endif()
endfunction()
pps_summary()
\ No newline at end of file
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Detects the OS and sets appropriate variables.
# CMAKE_SYSTEM_NAME only gives us a coarse-grained name of the OS CMake is
# building for, but the host system name, such as centos, is necessary
# in some scenarios to distinguish systems for customization.
#
# for instance, protobuf libs path is <install_dir>/lib64
# on CentOS, but <install_dir>/lib on other systems.
if(UNIX AND NOT APPLE)
# exclude Apple from the *nix OS family
set(LINUX TRUE)
endif()
if(WIN32)
set(HOST_SYSTEM "win32")
else()
if(APPLE)
set(HOST_SYSTEM "macosx")
exec_program(
sw_vers ARGS
-productVersion
OUTPUT_VARIABLE HOST_SYSTEM_VERSION)
string(REGEX MATCH "[0-9]+.[0-9]+" MACOS_VERSION "${HOST_SYSTEM_VERSION}")
if(NOT DEFINED ENV{MACOSX_DEPLOYMENT_TARGET})
# Set cache variable - end user may change this during ccmake or cmake-gui configure.
set(CMAKE_OSX_DEPLOYMENT_TARGET
${MACOS_VERSION}
CACHE
STRING
"Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value."
)
endif()
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
else()
if(EXISTS "/etc/issue")
file(READ "/etc/issue" LINUX_ISSUE)
if(LINUX_ISSUE MATCHES "CentOS")
set(HOST_SYSTEM "centos")
elseif(LINUX_ISSUE MATCHES "Debian")
set(HOST_SYSTEM "debian")
elseif(LINUX_ISSUE MATCHES "Ubuntu")
set(HOST_SYSTEM "ubuntu")
elseif(LINUX_ISSUE MATCHES "Red Hat")
set(HOST_SYSTEM "redhat")
elseif(LINUX_ISSUE MATCHES "Fedora")
set(HOST_SYSTEM "fedora")
endif()
string(REGEX MATCH "(([0-9]+)\\.)+([0-9]+)" HOST_SYSTEM_VERSION
"${LINUX_ISSUE}")
endif()
if(EXISTS "/etc/redhat-release")
file(READ "/etc/redhat-release" LINUX_ISSUE)
if(LINUX_ISSUE MATCHES "CentOS")
set(HOST_SYSTEM "centos")
endif()
endif()
if(NOT HOST_SYSTEM)
set(HOST_SYSTEM ${CMAKE_SYSTEM_NAME})
endif()
endif()
endif()
# query number of logical cores
cmake_host_system_information(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES)
mark_as_advanced(HOST_SYSTEM CPU_CORES)
message(
STATUS
"Found Paddle host system: ${HOST_SYSTEM}, version: ${HOST_SYSTEM_VERSION}")
message(STATUS "Found Paddle host system's CPU: ${CPU_CORES} cores")
# external dependencies log output
set(EXTERNAL_PROJECT_LOG_ARGS
LOG_DOWNLOAD
0 # Wrap download in script to log output
LOG_UPDATE
1 # Wrap update in script to log output
LOG_CONFIGURE
1 # Wrap configure in script to log output
LOG_BUILD
0 # Wrap build in script to log output
LOG_TEST
1 # Wrap test in script to log output
LOG_INSTALL
0 # Wrap install in script to log output
)
\ No newline at end of file
project(speechx LANGUAGES CXX)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kaldi)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common)
add_subdirectory(kaldi)
add_subdirectory(common)
if(WITH_ASR)
add_subdirectory(asr)
endif()
if(WITH_CLS)
add_subdirectory(audio_classification)
endif()
if(WITH_VAD)
add_subdirectory(vad)
endif()
add_subdirectory(codelab)
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(ASR LANGUAGES CXX)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/server)
add_subdirectory(decoder)
add_subdirectory(recognizer)
add_subdirectory(nnet)
add_subdirectory(server)
-set(srcs decodable.cc)
-
-if(USING_DS2)
-  list(APPEND srcs ds2_nnet.cc)
-endif()
-
-if(USING_U2)
-  list(APPEND srcs u2_nnet.cc)
-endif()
-
-add_library(nnet STATIC ${srcs})
-target_link_libraries(nnet absl::strings)
-
-if(USING_U2)
-  target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
-  target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
-endif()
-
-if(USING_DS2)
-  set(bin_name ds2_nnet_main)
-  add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
-  target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
-  target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
-  target_link_libraries(${bin_name} ${DEPS})
-endif()
-
-# test bin
-if(USING_U2)
-  set(bin_name u2_nnet_main)
-  add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
-  target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
-  target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
-  target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
-  target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
-  target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
-endif()
+set(srcs)
+list(APPEND srcs
+  ctc_prefix_beam_search_decoder.cc
+  ctc_tlg_decoder.cc
+)
+
+add_library(decoder STATIC ${srcs})
+target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder)
+
+# test
+set(TEST_BINS
+  ctc_prefix_beam_search_decoder_main
+  ctc_tlg_decoder_main
+)
+
+foreach(bin_name IN LISTS TEST_BINS)
+  add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
+  target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+  target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
+  target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
+  target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
+  target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)
+endforeach()
@@ -22,51 +22,22 @@ namespace ppspeech {
 struct CTCBeamSearchOptions {
     // common
     int blank;
+    std::string word_symbol_table;

-    // ds2
-    std::string dict_file;
-    std::string lm_path;
-    int beam_size;
-    BaseFloat alpha;
-    BaseFloat beta;
-    BaseFloat cutoff_prob;
-    int cutoff_top_n;
-    int num_proc_bsearch;

     // u2
     int first_beam_size;
     int second_beam_size;

     CTCBeamSearchOptions()
         : blank(0),
-          dict_file("vocab.txt"),
-          lm_path(""),
-          beam_size(300),
-          alpha(1.9f),
-          beta(5.0),
-          cutoff_prob(0.99f),
-          cutoff_top_n(40),
-          num_proc_bsearch(10),
+          word_symbol_table("vocab.txt"),
           first_beam_size(10),
           second_beam_size(10) {}

     void Register(kaldi::OptionsItf* opts) {
-        std::string module = "Ds2BeamSearchConfig: ";
-        opts->Register("dict", &dict_file, module + "vocab file path.");
-        opts->Register(
-            "lm-path", &lm_path, module + "ngram language model path.");
-        opts->Register("alpha", &alpha, module + "alpha");
-        opts->Register("beta", &beta, module + "beta");
-        opts->Register("beam-size",
-                       &beam_size,
-                       module + "beam size for beam search method");
-        opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
-        opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
-        opts->Register(
-            "num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");
+        std::string module = "CTCBeamSearchOptions: ";
+        opts->Register("word_symbol_table", &word_symbol_table, module + "vocab file path.");
         opts->Register("blank", &blank, "blank id, default is 0.");
-        module = "U2BeamSearchConfig: ";
         opts->Register(
             "first-beam-size", &first_beam_size, module + "first beam size.");
         opts->Register("second-beam-size",
......
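After this cleanup `CTCBeamSearchOptions` carries only what U2 uses: `blank`, `word_symbol_table`, and the two beam sizes. A hedged sketch of wiring it to Kaldi option parsing (`kaldi::ParseOptions` implements the `OptionsItf` that `Register` expects; `util/parse-options.h` is already used elsewhere in this PR):

```cpp
#include "decoder/ctc_beam_search_opt.h"
#include "util/parse-options.h"

int main(int argc, char* argv[]) {
    kaldi::ParseOptions po("ctc prefix beam search (illustrative usage)");
    ppspeech::CTCBeamSearchOptions opts;
    opts.Register(&po);   // exposes --word_symbol_table, --blank, ...
    po.Read(argc, argv);  // fills the struct from the command line
    return 0;
}
```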
@@ -17,13 +17,12 @@
 #include "decoder/ctc_prefix_beam_search_decoder.h"

-#include "absl/strings/str_join.h"
 #include "base/common.h"
 #include "decoder/ctc_beam_search_opt.h"
 #include "decoder/ctc_prefix_beam_search_score.h"
 #include "utils/math.h"

-#ifdef USE_PROFILING
+#ifdef WITH_PROFILING
 #include "paddle/fluid/platform/profiler.h"
 using paddle::platform::RecordEvent;
 using paddle::platform::TracerEventType;
@@ -31,11 +30,10 @@ using paddle::platform::TracerEventType;
 namespace ppspeech {

-CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string& vocab_path,
-                                         const CTCBeamSearchOptions& opts)
+CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts)
     : opts_(opts) {
     unit_table_ = std::shared_ptr<fst::SymbolTable>(
-        fst::SymbolTable::ReadText(vocab_path));
+        fst::SymbolTable::ReadText(opts.word_symbol_table));
     CHECK(unit_table_ != nullptr);
     Reset();
@@ -66,7 +64,6 @@ void CTCPrefixBeamSearch::Reset() {
 void CTCPrefixBeamSearch::InitDecoder() { Reset(); }

 void CTCPrefixBeamSearch::AdvanceDecode(
     const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
     double search_cost = 0.0;
@@ -78,21 +75,21 @@ void CTCPrefixBeamSearch::AdvanceDecode(
         bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
         feat_nnet_cost += timer.Elapsed();
         if (flag == false) {
-            VLOG(3) << "decoder advance decode exit." << frame_prob.size();
+            VLOG(2) << "decoder advance decode exit." << frame_prob.size();
             break;
         }

         timer.Reset();
         std::vector<std::vector<kaldi::BaseFloat>> likelihood;
-        likelihood.push_back(frame_prob);
+        likelihood.push_back(std::move(frame_prob));
         AdvanceDecoding(likelihood);
         search_cost += timer.Elapsed();

-        VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
+        VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_;
     }
-    VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
+    VLOG(2) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
             << " sec.";
-    VLOG(1) << "AdvanceDecode search cost: " << search_cost << " sec.";
+    VLOG(2) << "AdvanceDecode search cost: " << search_cost << " sec.";
 }

 static bool PrefixScoreCompare(
@@ -105,7 +102,7 @@ static bool PrefixScoreCompare(
 void CTCPrefixBeamSearch::AdvanceDecoding(
     const std::vector<std::vector<kaldi::BaseFloat>>& logp) {
-#ifdef USE_PROFILING
+#ifdef WITH_PROFILING
     RecordEvent event("CtcPrefixBeamSearch::AdvanceDecoding",
                       TracerEventType::UserDefined,
                       1);
......
@@ -27,8 +27,7 @@ namespace ppspeech {
 class ContextGraph;
 class CTCPrefixBeamSearch : public DecoderBase {
   public:
-    CTCPrefixBeamSearch(const std::string& vocab_path,
-                        const CTCBeamSearchOptions& opts);
+    CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
     ~CTCPrefixBeamSearch() {}

     SearchType Type() const { return SearchType::kPrefixBeamSearch; }
@@ -45,7 +44,7 @@ class CTCPrefixBeamSearch : public DecoderBase {
     void FinalizeSearch();

-    const std::shared_ptr<fst::SymbolTable> VocabTable() const {
+    const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
         return unit_table_;
     }
@@ -57,7 +56,6 @@ class CTCPrefixBeamSearch : public DecoderBase {
     }
     const std::vector<std::vector<int>>& Times() const { return times_; }

   protected:
     std::string GetBestPath() override;
     std::vector<std::pair<double, std::string>> GetNBestPath() override;
......
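The constructor change means the vocab path now travels inside the options struct instead of as a separate argument. A sketch of the new call pattern, using only methods visible in this diff (the decodable is whatever feeds `FrameLikelihood`; the path is a placeholder):

```cpp
#include <memory>

#include "decoder/ctc_prefix_beam_search_decoder.h"

void RunPrefixBeamSearch(
    const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
    ppspeech::CTCBeamSearchOptions opts;
    opts.blank = 0;
    opts.first_beam_size = 10;
    opts.second_beam_size = 10;
    opts.word_symbol_table = "words.txt";  // placeholder path

    // Old API: CTCPrefixBeamSearch decoder(vocab_path, opts);
    ppspeech::CTCPrefixBeamSearch decoder(opts);
    decoder.InitDecoder();
    decoder.AdvanceDecode(decodable);  // consumes frames until exhausted
    decoder.FinalizeSearch();
}
```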
@@ -12,18 +12,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "absl/strings/str_split.h"
 #include "base/common.h"
 #include "decoder/ctc_prefix_beam_search_decoder.h"
-#include "frontend/audio/data_cache.h"
+#include "frontend/data_cache.h"
 #include "fst/symbol-table.h"
 #include "kaldi/util/table-types.h"
 #include "nnet/decodable.h"
+#include "nnet/nnet_producer.h"
 #include "nnet/u2_nnet.h"

 DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
 DEFINE_string(result_wspecifier, "", "test result wspecifier");
-DEFINE_string(vocab_path, "", "vocab path");
+DEFINE_string(word_symbol_table, "", "vocab path");
 DEFINE_string(model_path, "", "paddle nnet model");
@@ -40,7 +40,7 @@ using kaldi::BaseFloat;
 using kaldi::Matrix;
 using std::vector;

-// test ds2 online decoder by feeding speech feature
+// test u2 online decoder by feeding speech feature
 int main(int argc, char* argv[]) {
     gflags::SetUsageMessage("Usage:");
     gflags::ParseCommandLineFlags(&argc, &argv, false);
@@ -52,10 +52,10 @@ int main(int argc, char* argv[]) {
     CHECK_NE(FLAGS_result_wspecifier, "");
     CHECK_NE(FLAGS_feature_rspecifier, "");
-    CHECK_NE(FLAGS_vocab_path, "");
+    CHECK_NE(FLAGS_word_symbol_table, "");
     CHECK_NE(FLAGS_model_path, "");
     LOG(INFO) << "model path: " << FLAGS_model_path;
-    LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
+    LOG(INFO) << "Reading vocab table " << FLAGS_word_symbol_table;

     kaldi::SequentialBaseFloatMatrixReader feature_reader(
         FLAGS_feature_rspecifier);
@@ -70,15 +70,18 @@ int main(int argc, char* argv[]) {
     // decodeable
     std::shared_ptr<ppspeech::DataCache> raw_data =
         std::make_shared<ppspeech::DataCache>();
+    std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
+        std::make_shared<ppspeech::NnetProducer>(nnet, raw_data, 1.0);
     std::shared_ptr<ppspeech::Decodable> decodable =
-        std::make_shared<ppspeech::Decodable>(nnet, raw_data);
+        std::make_shared<ppspeech::Decodable>(nnet_producer);

     // decoder
     ppspeech::CTCBeamSearchOptions opts;
     opts.blank = 0;
     opts.first_beam_size = 10;
     opts.second_beam_size = 10;
-    ppspeech::CTCPrefixBeamSearch decoder(FLAGS_vocab_path, opts);
+    opts.word_symbol_table = FLAGS_word_symbol_table;
+    ppspeech::CTCPrefixBeamSearch decoder(opts);

     int32 chunk_size = FLAGS_receptive_field_length +
@@ -122,15 +125,14 @@ int main(int argc, char* argv[]) {
         }

-        kaldi::Vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
-                                                      feat_dim);
+        std::vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
+                                                    feat_dim);
         int32 start = chunk_idx * chunk_stride;
         for (int row_id = 0; row_id < this_chunk_size; ++row_id) {
             kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start);
-            kaldi::SubVector<kaldi::BaseFloat> feature_chunk_row(
-                feature_chunk.Data() + row_id * feat_dim, feat_dim);
-            feature_chunk_row.CopyFromVec(feat_row);
+            std::memcpy(feature_chunk.data() + row_id * feat_dim,
+                        feat_row.Data(),
+                        feat_dim * sizeof(kaldi::BaseFloat));
             ++start;
         }
......
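The chunk loop feeds features in windows sized by the model's receptive field and stride; the switch from `kaldi::Vector` to `std::vector` is what allows the plain `memcpy` row copy. The arithmetic behind `chunk_size` is truncated above; a hedged reconstruction of the usual U2 streaming relation (parameter names mirror the flags; treat the exact formula as an assumption):

```cpp
// First chunk must cover the receptive field; every further decoder step
// consumes subsampling_rate frames, nnet_decoder_chunk steps per chunk.
int ChunkSize(int receptive_field_length,
              int subsampling_rate,
              int nnet_decoder_chunk) {
    return receptive_field_length +
           (nnet_decoder_chunk - 1) * subsampling_rate;
}

int ChunkStride(int subsampling_rate, int nnet_decoder_chunk) {
    return subsampling_rate * nnet_decoder_chunk;
}
```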
@@ -13,12 +13,14 @@
 // limitations under the License.

 #include "decoder/ctc_tlg_decoder.h"

 namespace ppspeech {

-TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
-    fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
+TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) {
+    fst_ = opts.fst_ptr;
     CHECK(fst_ != nullptr);
+    CHECK(!opts.word_symbol_table.empty());
     word_symbol_table_.reset(
         fst::SymbolTable::ReadText(opts.word_symbol_table));
@@ -29,6 +31,11 @@ TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
 void TLGDecoder::Reset() {
     decoder_->InitDecoding();
+    hypotheses_.clear();
+    likelihood_.clear();
+    olabels_.clear();
+    times_.clear();
     num_frame_decoded_ = 0;
     return;
 }
@@ -68,14 +75,52 @@ std::string TLGDecoder::GetPartialResult() {
     return words;
 }

+void TLGDecoder::FinalizeSearch() {
+    decoder_->FinalizeDecoding();
+    kaldi::CompactLattice clat;
+    decoder_->GetLattice(&clat, true);
+    kaldi::Lattice lat, nbest_lat;
+    fst::ConvertLattice(clat, &lat);
+    fst::ShortestPath(lat, &nbest_lat, opts_.nbest);
+    std::vector<kaldi::Lattice> nbest_lats;
+    fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
+
+    hypotheses_.clear();
+    hypotheses_.reserve(nbest_lats.size());
+    likelihood_.clear();
+    likelihood_.reserve(nbest_lats.size());
+    times_.clear();
+    times_.reserve(nbest_lats.size());
+    for (auto lat : nbest_lats) {
+        kaldi::LatticeWeight weight;
+        std::vector<int> hypothese;
+        std::vector<int> time;
+        std::vector<int> alignment;
+        std::vector<int> words_id;
+        fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
+        int idx = 0;
+        for (; idx < alignment.size() - 1; ++idx) {
+            if (alignment[idx] == 0) continue;
+            if (alignment[idx] != alignment[idx + 1]) {
+                hypothese.push_back(alignment[idx] - 1);
+                time.push_back(idx);  // fake time, todo later
+            }
+        }
+        hypothese.push_back(alignment[idx] - 1);
+        time.push_back(idx);  // fake time, todo later
+        hypotheses_.push_back(hypothese);
+        times_.push_back(time);
+        olabels_.push_back(words_id);
+        likelihood_.push_back(-(weight.Value2() + weight.Value1()));
+    }
+}
+
 std::string TLGDecoder::GetFinalBestPath() {
     if (num_frame_decoded_ == 0) {
         // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
         // BestPathEnd if no frames were decoded.")
         return std::string("");
     }
-    decoder_->FinalizeDecoding();
+
     kaldi::Lattice lat;
     kaldi::LatticeWeight weight;
     std::vector<int> alignment;
......
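The new `TLGDecoder::FinalizeSearch` collapses each n-best frame alignment into unit ids: zeros (epsilon) are skipped and a unit is emitted whenever the label changes between consecutive frames, with the `- 1` presumably undoing a +1 label offset applied when the TLG graph was built. A standalone sketch of that collapse rule, mirroring the loop above:

```cpp
#include <vector>

// Collapse a frame-level alignment into unit ids: skip 0 (epsilon), emit on
// label change, shift labels down by one. The final frame always emits, as
// in TLGDecoder::FinalizeSearch above.
std::vector<int> CollapseAlignment(const std::vector<int>& alignment) {
    std::vector<int> units;
    if (alignment.empty()) return units;
    size_t idx = 0;
    for (; idx + 1 < alignment.size(); ++idx) {
        if (alignment[idx] == 0) continue;
        if (alignment[idx] != alignment[idx + 1]) {
            units.push_back(alignment[idx] - 1);
        }
    }
    units.push_back(alignment[idx] - 1);
    return units;
}
```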
...@@ -18,13 +18,14 @@ ...@@ -18,13 +18,14 @@
#include "decoder/decoder_itf.h" #include "decoder/decoder_itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h" #include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h" #include "util/parse-options.h"
#include "utils/file_utils.h"
DECLARE_string(graph_path);
DECLARE_string(word_symbol_table); DECLARE_string(word_symbol_table);
DECLARE_string(graph_path);
DECLARE_int32(max_active); DECLARE_int32(max_active);
DECLARE_double(beam); DECLARE_double(beam);
DECLARE_double(lattice_beam); DECLARE_double(lattice_beam);
DECLARE_int32(nbest);
namespace ppspeech { namespace ppspeech {
...@@ -33,17 +34,27 @@ struct TLGDecoderOptions { ...@@ -33,17 +34,27 @@ struct TLGDecoderOptions {
// todo remove later, add into decode resource // todo remove later, add into decode resource
std::string word_symbol_table; std::string word_symbol_table;
std::string fst_path; std::string fst_path;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_ptr;
int nbest;
TLGDecoderOptions() : word_symbol_table(""), fst_path(""), fst_ptr(nullptr), nbest(10) {}
static TLGDecoderOptions InitFromFlags() { static TLGDecoderOptions InitFromFlags() {
TLGDecoderOptions decoder_opts; TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table; decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path; decoder_opts.fst_path = FLAGS_graph_path;
LOG(INFO) << "fst path: " << decoder_opts.fst_path; LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table; LOG(INFO) << "symbole table: " << decoder_opts.word_symbol_table;
if (!decoder_opts.fst_path.empty()) {
CHECK(FileExists(decoder_opts.fst_path));
decoder_opts.fst_ptr.reset(fst::Fst<fst::StdArc>::Read(FLAGS_graph_path));
}
decoder_opts.opts.max_active = FLAGS_max_active; decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam; decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
decoder_opts.nbest = FLAGS_nbest;
LOG(INFO) << "LatticeFasterDecoder max active: " LOG(INFO) << "LatticeFasterDecoder max active: "
<< decoder_opts.opts.max_active; << decoder_opts.opts.max_active;
LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam; LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam;
...@@ -59,20 +70,38 @@ class TLGDecoder : public DecoderBase { ...@@ -59,20 +70,38 @@ class TLGDecoder : public DecoderBase {
explicit TLGDecoder(TLGDecoderOptions opts); explicit TLGDecoder(TLGDecoderOptions opts);
~TLGDecoder() = default; ~TLGDecoder() = default;
void InitDecoder(); void InitDecoder() override;
void Reset(); void Reset() override;
void AdvanceDecode( void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable); const std::shared_ptr<kaldi::DecodableInterface>& decodable) override;
void Decode(); void Decode();
std::string GetFinalBestPath() override; std::string GetFinalBestPath() override;
std::string GetPartialResult() override; std::string GetPartialResult() override;
const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
return word_symbol_table_;
}
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
const std::vector<std::string>& nbest_words); const std::vector<std::string>& nbest_words);
void FinalizeSearch() override;
const std::vector<std::vector<int>>& Inputs() const override {
return hypotheses_;
}
const std::vector<std::vector<int>>& Outputs() const override {
return olabels_;
}  // outputs_;
const std::vector<float>& Likelihood() const override {
return likelihood_;
}
const std::vector<std::vector<int>>& Times() const override {
return times_;
}
protected: protected:
std::string GetBestPath() override { std::string GetBestPath() override {
CHECK(false); CHECK(false);
...@@ -90,10 +119,17 @@ class TLGDecoder : public DecoderBase { ...@@ -90,10 +119,17 @@ class TLGDecoder : public DecoderBase {
private: private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable); void AdvanceDecoding(kaldi::DecodableInterface* decodable);
int num_frame_decoded_;
std::vector<std::vector<int>> hypotheses_;
std::vector<std::vector<int>> olabels_;
std::vector<float> likelihood_;
std::vector<std::vector<int>> times_;
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_; std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_; std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_; std::shared_ptr<fst::SymbolTable> word_symbol_table_;
TLGDecoderOptions opts_;
}; };
} // namespace ppspeech } // namespace ppspeech
\ No newline at end of file
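For orientation, a minimal usage sketch (not part of this diff) of the decoder declared above; it mirrors the refactored test program that follows, and assumes the gflags declared in this header have been parsed and that `prob` (a hypothetical variable here) holds a (frames x vocab) matrix of nnet posteriors:

    // Sketch: offline TLG decoding over precomputed nnet posteriors.
    ppspeech::TLGDecoderOptions opts =
        ppspeech::TLGDecoderOptions::InitFromFlags();  // FLAGS_graph_path etc.
    ppspeech::TLGDecoder decoder(opts);
    auto nnet_producer =
        std::make_shared<ppspeech::NnetProducer>(nullptr, nullptr, 1.0);
    auto decodable = std::make_shared<ppspeech::Decodable>(
        nnet_producer, FLAGS_acoustic_scale);
    decoder.InitDecoder();
    decodable->Acceptlikelihood(prob);  // push posterior rows to the producer
    decoder.AdvanceDecode(decodable);   // consume all ready frames
    std::string result = decoder.GetFinalBestPath();
    decodable->Reset();
    decoder.Reset();  // ready for the next utterance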
...@@ -14,21 +14,24 @@ ...@@ -14,21 +14,24 @@
// todo refactor, replace with gtest // todo refactor, replace with gtest
#include "base/flags.h" #include "base/common.h"
#include "base/log.h" #include "decoder/ctc_tlg_decoder.h"
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/param.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/nnet_producer.h"
DEFINE_string(nnet_prob_rspecifier, "", "test nnet prob rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
// test decoder by feeding nnet posterior probability // test TLG decoder by feeding nnet posterior probability.
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:"); gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
...@@ -36,41 +39,51 @@ int main(int argc, char* argv[]) { ...@@ -36,41 +39,51 @@ int main(int argc, char* argv[]) {
google::InstallFailureSignalHandler(); google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1; FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader likelihood_reader( kaldi::SequentialBaseFloatMatrixReader nnet_prob_reader(
FLAGS_nnet_prob_respecifier); FLAGS_nnet_prob_rspecifier);
std::string dict_file = FLAGS_dict_file; kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts; ppspeech::TLGDecoderOptions opts =
opts.dict_file = dict_file; ppspeech::TLGDecoderOptions::InitFromFlags();
opts.lm_path = lm_path; opts.opts.beam = 15.0;
ppspeech::CTCBeamSearch decoder(opts); opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nullptr, nullptr, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable( std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nullptr, nullptr)); new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale));
decoder.InitDecoder(); decoder.InitDecoder();
kaldi::Timer timer;
for (; !likelihood_reader.Done(); likelihood_reader.Next()) { for (; !nnet_prob_reader.Done(); nnet_prob_reader.Next()) {
string utt = likelihood_reader.Key(); string utt = nnet_prob_reader.Key();
const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value(); kaldi::Matrix<BaseFloat> prob = nnet_prob_reader.Value();
LOG(INFO) << "process utt: " << utt; decodable->Acceptlikelihood(prob);
LOG(INFO) << "rows: " << likelihood.NumRows();
LOG(INFO) << "cols: " << likelihood.NumCols();
decodable->Acceptlikelihood(likelihood);
decoder.AdvanceDecode(decodable); decoder.AdvanceDecode(decodable);
std::string result; std::string result;
result = decoder.GetFinalBestPath(); result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset(); decodable->Reset();
decoder.Reset(); decoder.Reset();
if (result.empty()) {
// the TokenWriter cannot write an empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done; ++num_done;
} }
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "Done " << num_done << " utterances, " << num_err KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors."; << " with errors.";
return (num_done != 0 ? 0 : 1); return (num_done != 0 ? 0 : 1);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -16,6 +15,7 @@ ...@@ -16,6 +15,7 @@
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "fst/symbol-table.h"
#include "kaldi/decoder/decodable-itf.h" #include "kaldi/decoder/decodable-itf.h"
namespace ppspeech { namespace ppspeech {
...@@ -41,6 +41,14 @@ class DecoderInterface { ...@@ -41,6 +41,14 @@ class DecoderInterface {
virtual std::string GetPartialResult() = 0; virtual std::string GetPartialResult() = 0;
virtual const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const = 0;
virtual void FinalizeSearch() = 0;
virtual const std::vector<std::vector<int>>& Inputs() const = 0;
virtual const std::vector<std::vector<int>>& Outputs() const = 0;
virtual const std::vector<float>& Likelihood() const = 0;
virtual const std::vector<std::vector<int>>& Times() const = 0;
protected: protected:
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0; // virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
......
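The added pure-virtual methods expose n-best material once FinalizeSearch() has run. A hedged consumption sketch (`decoder` is any implementation of this interface; fst::SymbolTable::Find() maps word ids back to strings, and hypotheses are concatenated without separators as in character-level output):

    // Sketch: dump n-best hypotheses through the extended interface.
    void DumpNbest(const ppspeech::DecoderInterface& decoder) {
        const auto& olabels = decoder.Outputs();    // word ids per hypothesis
        const auto& scores = decoder.Likelihood();  // one score per hypothesis
        auto symbol_table = decoder.WordSymbolTable();
        for (size_t i = 0; i < olabels.size(); ++i) {
            std::string sentence;
            for (int word_id : olabels[i]) {
                sentence += symbol_table->Find(word_id);
            }
            LOG(INFO) << "hyp " << i << " score " << scores[i] << ": "
                      << sentence;
        }
    }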
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
// feature // feature
DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
...@@ -37,36 +35,22 @@ DEFINE_int32(subsampling_rate, ...@@ -37,36 +35,22 @@ DEFINE_int32(subsampling_rate,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet // nnet
DEFINE_string(vocab_path, "", "nnet vocab path.");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); #ifdef USE_ONNX
DEFINE_string( DEFINE_bool(with_onnx_model, false, "True means the model path is an onnx model path");
model_input_names, #endif
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
// decoder // decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_string(graph_path, "", "decoder graph");
DEFINE_string(graph_path, "TLG", "decoder graph"); DEFINE_string(word_symbol_table, "", "word symbol table");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_int32(max_active, 7500, "max active"); DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam"); DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_double(blank_threshold, 0.98, "blank skip threshold");
// DecodeOptions flags // DecodeOptions flags
// DEFINE_int32(chunk_size, -1, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding"); DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight, DEFINE_double(ctc_weight,
0.5, 0.5,
......
set(srcs decodable.cc nnet_producer.cc)
list(APPEND srcs u2_nnet.cc)
if(WITH_ONNX)
list(APPEND srcs u2_onnx_nnet.cc)
endif()
add_library(nnet STATIC ${srcs})
target_link_libraries(nnet utils)
if(WITH_ONNX)
target_link_libraries(nnet ${FASTDEPLOY_LIBS})
endif()
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
# test bin
#set(bin_name u2_nnet_main)
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
\ No newline at end of file
...@@ -21,29 +21,25 @@ using kaldi::Matrix; ...@@ -21,29 +21,25 @@ using kaldi::Matrix;
using kaldi::Vector; using kaldi::Vector;
using std::vector; using std::vector;
Decodable::Decodable(const std::shared_ptr<NnetBase>& nnet, Decodable::Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale) kaldi::BaseFloat acoustic_scale)
: frontend_(frontend), : nnet_producer_(nnet_producer),
nnet_(nnet),
frame_offset_(0), frame_offset_(0),
frames_ready_(0), frames_ready_(0),
acoustic_scale_(acoustic_scale) {} acoustic_scale_(acoustic_scale) {}
// for debug // for debug
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_out_cache_ = likelihood; nnet_producer_->Acceptlikelihood(likelihood);
frames_ready_ += likelihood.NumRows();
} }
// return the number of frames computed so far. // return the number of frames computed so far.
int32 Decodable::NumFramesReady() const { return frames_ready_; } int32 Decodable::NumFramesReady() const { return frames_ready_; }
// frame idx ranges from 0 to frames_ready_ - 1; // frame idx ranges from 0 to frames_ready_ - 1;
bool Decodable::IsLastFrame(int32 frame) { bool Decodable::IsLastFrame(int32 frame) {
bool flag = EnsureFrameHaveComputed(frame); EnsureFrameHaveComputed(frame);
return frame >= frames_ready_; return frame >= frames_ready_;
} }
...@@ -64,32 +60,10 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) { ...@@ -64,32 +60,10 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) {
bool Decodable::AdvanceChunk() { bool Decodable::AdvanceChunk() {
kaldi::Timer timer; kaldi::Timer timer;
// read feats bool flag = nnet_producer_->Read(&framelikelihood_);
Vector<BaseFloat> features; if (flag == false) return false;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
VLOG(3) << "decodable exit;";
return false;
}
CHECK_GE(frontend_->Dim(), 0);
VLOG(1) << "AdvanceChunk feat cost: " << timer.Elapsed() << " sec.";
VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats.";
// forward feats
NnetOut out;
nnet_->FeedForward(features, frontend_->Dim(), &out);
int32& vocab_dim = out.vocab_dim;
Vector<BaseFloat>& logprobs = out.logprobs;
VLOG(2) << "Forward out " << logprobs.Dim() / vocab_dim
<< " decoder frames.";
// cache nnet outputs
nnet_out_cache_.Resize(logprobs.Dim() / vocab_dim, vocab_dim);
nnet_out_cache_.CopyRowsFromVec(logprobs);
// update state, decoding frame.
frame_offset_ = frames_ready_; frame_offset_ = frames_ready_;
frames_ready_ += nnet_out_cache_.NumRows(); frames_ready_ += 1;
VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed() VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed()
<< " sec."; << " sec.";
return true; return true;
...@@ -101,17 +75,17 @@ bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs, ...@@ -101,17 +75,17 @@ bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
return false; return false;
} }
int nrows = nnet_out_cache_.NumRows(); if (framelikelihood_.empty()) {
CHECK(nrows == (frames_ready_ - frame_offset_));
if (nrows <= 0) {
LOG(WARNING) << "No new nnet out in cache."; LOG(WARNING) << "No new nnet out in cache.";
return false; return false;
} }
logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols()); size_t dim = framelikelihood_.size();
logprobs->CopyRowsFromMat(nnet_out_cache_); logprobs->Resize(framelikelihood_.size());
std::memcpy(logprobs->Data(),
*vocab_dim = nnet_out_cache_.NumCols(); framelikelihood_.data(),
dim * sizeof(kaldi::BaseFloat));
*vocab_dim = framelikelihood_.size();
return true; return true;
} }
...@@ -122,19 +96,8 @@ bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) { ...@@ -122,19 +96,8 @@ bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
return false; return false;
} }
int nrows = nnet_out_cache_.NumRows(); CHECK_EQ(1, (frames_ready_ - frame_offset_));
CHECK(nrows == (frames_ready_ - frame_offset_)); *likelihood = framelikelihood_;
int vocab_size = nnet_out_cache_.NumCols();
likelihood->resize(vocab_size);
for (int32 idx = 0; idx < vocab_size; ++idx) {
(*likelihood)[idx] =
nnet_out_cache_(frame - frame_offset_, idx) * acoustic_scale_;
VLOG(4) << "nnet out: " << frame << " offset:" << frame_offset_ << " "
<< nnet_out_cache_.NumRows()
<< " logprob: " << nnet_out_cache_(frame - frame_offset_, idx);
}
return true; return true;
} }
...@@ -143,37 +106,31 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { ...@@ -143,37 +106,31 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
return false; return false;
} }
CHECK_LE(index, nnet_out_cache_.NumCols()); CHECK_LE(index, framelikelihood_.size());
CHECK_LE(frame, frames_ready_); CHECK_LE(frame, frames_ready_);
// the nnet output is prob rather than log prob // the nnet output is prob rather than log prob
// use index - 1, because FST ilabels are offset by one // use index - 1, because FST ilabels are offset by one
BaseFloat logprob = 0.0; BaseFloat logprob = 0.0;
int32 frame_idx = frame - frame_offset_; int32 frame_idx = frame - frame_offset_;
BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index)); CHECK_EQ(frame_idx, 0);
if (nnet_->IsLogProb()) { logprob = framelikelihood_[TokenId2NnetId(index)];
logprob = nnet_out;
} else {
logprob = std::log(nnet_out + std::numeric_limits<float>::epsilon());
}
CHECK(!std::isnan(logprob) && !std::isinf(logprob));
return acoustic_scale_ * logprob; return acoustic_scale_ * logprob;
} }
void Decodable::Reset() { void Decodable::Reset() {
if (frontend_ != nullptr) frontend_->Reset(); if (nnet_producer_ != nullptr) nnet_producer_->Reset();
if (nnet_ != nullptr) nnet_->Reset();
frame_offset_ = 0; frame_offset_ = 0;
frames_ready_ = 0; frames_ready_ = 0;
nnet_out_cache_.Resize(0, 0); framelikelihood_.clear();
} }
void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps, void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight, float reverse_weight,
std::vector<float>* rescoring_score) { std::vector<float>* rescoring_score) {
kaldi::Timer timer; kaldi::Timer timer;
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score); nnet_producer_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec."; VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec.";
} }
} // namespace ppspeech } // namespace ppspeech
\ No newline at end of file
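A note on the new contract: AdvanceChunk() now advances frames_ready_ by exactly one nnet frame per call instead of caching a whole matrix. A sketch of draining Decodable after Acceptlikelihood(), assuming the exporting overload shown above takes an int32* for the vocab size and that the producer's queue pop returns false when empty (both only partially visible in this diff):

    // Sketch: pull decoded frames out of Decodable one at a time.
    kaldi::Vector<kaldi::BaseFloat> logprobs;
    int32 vocab_dim = 0;
    while (decodable->AdvanceChunk(&logprobs, &vocab_dim)) {
        // logprobs.Dim() == vocab_dim; one decoder frame per iteration
    }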
...@@ -12,11 +12,13 @@ ...@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once
#include "base/common.h" #include "base/common.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/decoder/decodable-itf.h" #include "kaldi/decoder/decodable-itf.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h" #include "nnet/nnet_itf.h"
#include "nnet/nnet_producer.h"
namespace ppspeech { namespace ppspeech {
...@@ -24,12 +26,9 @@ struct DecodableOpts; ...@@ -24,12 +26,9 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface { class Decodable : public kaldi::DecodableInterface {
public: public:
explicit Decodable(const std::shared_ptr<NnetBase>& nnet, explicit Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale = 1.0); kaldi::BaseFloat acoustic_scale = 1.0);
// void Init(DecodableOpts config);
// nnet logprob output, used by wfst // nnet logprob output, used by wfst
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
...@@ -57,23 +56,17 @@ class Decodable : public kaldi::DecodableInterface { ...@@ -57,23 +56,17 @@ class Decodable : public kaldi::DecodableInterface {
void Reset(); void Reset();
bool IsInputFinished() const { return frontend_->IsFinished(); } bool IsInputFinished() const { return nnet_producer_->IsFinished(); }
bool EnsureFrameHaveComputed(int32 frame); bool EnsureFrameHaveComputed(int32 frame);
int32 TokenId2NnetId(int32 token_id); int32 TokenId2NnetId(int32 token_id);
std::shared_ptr<NnetBase> Nnet() { return nnet_; }
// for offline test // for offline test
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
private: private:
std::shared_ptr<FrontendInterface> frontend_; std::shared_ptr<NnetProducer> nnet_producer_;
std::shared_ptr<NnetBase> nnet_;
// nnet outputs' cache
kaldi::Matrix<kaldi::BaseFloat> nnet_out_cache_;
// the frame is nnet prob frame rather than audio feature frame // the frame is nnet prob frame rather than audio feature frame
// nnet frame subsample the feature frame // nnet frame subsample the feature frame
...@@ -85,6 +78,7 @@ class Decodable : public kaldi::DecodableInterface { ...@@ -85,6 +78,7 @@ class Decodable : public kaldi::DecodableInterface {
// so use subsampled_frame // so use subsampled_frame
int32 current_log_post_subsampled_offset_; int32 current_log_post_subsampled_offset_;
int32 num_chunk_computed_; int32 num_chunk_computed_;
std::vector<kaldi::BaseFloat> framelikelihood_;
kaldi::BaseFloat acoustic_scale_; kaldi::BaseFloat acoustic_scale_;
}; };
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include "base/basic_types.h" #include "base/basic_types.h"
#include "kaldi/base/kaldi-types.h" #include "kaldi/base/kaldi-types.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h" #include "kaldi/util/options-itf.h"
DECLARE_int32(subsampling_rate); DECLARE_int32(subsampling_rate);
...@@ -25,26 +24,20 @@ DECLARE_string(model_input_names); ...@@ -25,26 +24,20 @@ DECLARE_string(model_input_names);
DECLARE_string(model_output_names); DECLARE_string(model_output_names);
DECLARE_string(model_cache_names); DECLARE_string(model_cache_names);
DECLARE_string(model_cache_shapes); DECLARE_string(model_cache_shapes);
#ifdef USE_ONNX
DECLARE_bool(with_onnx_model);
#endif
namespace ppspeech { namespace ppspeech {
struct ModelOptions { struct ModelOptions {
// common // common
int subsample_rate{1}; int subsample_rate{1};
int thread_num{1}; // predictor thread pool size for ds2;
bool use_gpu{false}; bool use_gpu{false};
std::string model_path; std::string model_path;
#ifdef USE_ONNX
std::string param_path; bool with_onnx_model{false};
#endif
// ds2 for inference
std::string input_names{};
std::string output_names{};
std::string cache_names{};
std::string cache_shape{};
bool switch_ir_optim{false};
bool enable_fc_padding{false};
bool enable_profile{false};
static ModelOptions InitFromFlags() { static ModelOptions InitFromFlags() {
ModelOptions opts; ModelOptions opts;
...@@ -52,26 +45,17 @@ struct ModelOptions { ...@@ -52,26 +45,17 @@ struct ModelOptions {
LOG(INFO) << "subsampling rate: " << opts.subsample_rate; LOG(INFO) << "subsampling rate: " << opts.subsample_rate;
opts.model_path = FLAGS_model_path; opts.model_path = FLAGS_model_path;
LOG(INFO) << "model path: " << opts.model_path; LOG(INFO) << "model path: " << opts.model_path;
#ifdef USE_ONNX
opts.param_path = FLAGS_param_path; opts.with_onnx_model = FLAGS_with_onnx_model;
LOG(INFO) << "param path: " << opts.param_path; LOG(INFO) << "with onnx model: " << opts.with_onnx_model;
#endif
LOG(INFO) << "DS2 param: ";
opts.cache_names = FLAGS_model_cache_names;
LOG(INFO) << " cache names: " << opts.cache_names;
opts.cache_shape = FLAGS_model_cache_shapes;
LOG(INFO) << " cache shape: " << opts.cache_shape;
opts.input_names = FLAGS_model_input_names;
LOG(INFO) << " input names: " << opts.input_names;
opts.output_names = FLAGS_model_output_names;
LOG(INFO) << " output names: " << opts.output_names;
return opts; return opts;
} }
}; };
struct NnetOut { struct NnetOut {
// nnet out, maybe logprob or prob. Most of the time this is logprob. // nnet out, maybe logprob or prob. Most of the time this is logprob.
kaldi::Vector<kaldi::BaseFloat> logprobs; std::vector<kaldi::BaseFloat> logprobs;
int32 vocab_dim; int32 vocab_dim;
// nnet state. Only used in attention models. // nnet state. Only used in attention models.
...@@ -89,7 +73,7 @@ class NnetInterface { ...@@ -89,7 +73,7 @@ class NnetInterface {
// nnet does not cache feats; feats are cached by the frontend. // nnet does not cache feats; feats are cached by the frontend.
// nnet caches model state, i.e. encoder_outs, att_cache, cnn_cache, // nnet caches model state, i.e. encoder_outs, att_cache, cnn_cache,
// frame_offset. // frame_offset.
virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features, virtual void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim, const int32& feature_dim,
NnetOut* out) = 0; NnetOut* out) = 0;
...@@ -105,14 +89,14 @@ class NnetInterface { ...@@ -105,14 +89,14 @@ class NnetInterface {
// used to get encoder outs, e.g. seq2seq with attention model. // used to get encoder outs, e.g. seq2seq with attention model.
virtual void EncoderOuts( virtual void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0; std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const = 0;
}; };
class NnetBase : public NnetInterface { class NnetBase : public NnetInterface {
public: public:
int SubsamplingRate() const { return subsampling_rate_; } int SubsamplingRate() const { return subsampling_rate_; }
virtual std::shared_ptr<NnetBase> Clone() const = 0;
protected: protected:
int subsampling_rate_{1}; int subsampling_rate_{1};
}; };
......
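Since NnetOut now packs all frames of one forward pass into a flat std::vector, per-frame rows are recovered with vocab_dim, exactly as NnetProducer::Compute() does below. A minimal sketch, assuming `nnet` is a loaded NnetBase and `features`/`feature_dim` come from the frontend:

    // Sketch: slice the flat logprobs buffer into per-frame rows.
    ppspeech::NnetOut out;
    nnet->FeedForward(features, feature_dim, &out);
    const size_t nframes = out.logprobs.size() / out.vocab_dim;
    for (size_t t = 0; t < nframes; ++t) {
        const kaldi::BaseFloat* row = out.logprobs.data() + t * out.vocab_dim;
        // row[0 .. vocab_dim) is the (log-)posterior for decoder frame t
    }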
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/nnet_producer.h"
#include "matrix/kaldi-matrix.h"
namespace ppspeech {
using kaldi::BaseFloat;
using std::vector;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend,
float blank_threshold)
: nnet_(nnet), frontend_(frontend), blank_threshold_(blank_threshold) {
Reset();
}
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
frontend_->Accept(inputs);
}
void NnetProducer::Acceptlikelihood(
const kaldi::Matrix<BaseFloat>& likelihood) {
std::vector<BaseFloat> prob;
prob.resize(likelihood.NumCols());
for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) {
for (size_t col = 0; col < likelihood.NumCols(); ++col) {
prob[col] = likelihood(idx, col);
}
cache_.push_back(prob);
}
}
bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
bool flag = cache_.pop(nnet_prob);
return flag;
}
bool NnetProducer::Compute() {
vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat, or frontend_ not initialized.
if (frontend_ != NULL && frontend_->IsFinished() == true) {
finished_ = true;
}
return false;
}
CHECK_GE(frontend_->Dim(), 0);
VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
NnetOut out;
nnet_->FeedForward(features, frontend_->Dim(), &out);
int32& vocab_dim = out.vocab_dim;
size_t nframes = out.logprobs.size() / vocab_dim;
VLOG(1) << "Forward out " << nframes << " decoder frames.";
for (size_t idx = 0; idx < nframes; ++idx) {
std::vector<BaseFloat> logprob(
out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim);
// process blank prob
float blank_prob = std::exp(logprob[0]);
if (blank_prob > blank_threshold_) {
last_frame_logprob_ = logprob;
is_last_frame_skip_ = true;
continue;
} else {
int cur_max = std::max_element(logprob.begin(), logprob.end()) - logprob.begin();
if (cur_max == last_max_elem_ && cur_max != 0 && is_last_frame_skip_) {
cache_.push_back(last_frame_logprob_);
last_max_elem_ = cur_max;
}
last_max_elem_ = cur_max;
is_last_frame_skip_ = false;
cache_.push_back(logprob);
}
}
return true;
}
void NnetProducer::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
}
} // namespace ppspeech
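The skip logic in Compute() is the subtle part of this file: frames whose blank (index 0) probability exceeds blank_threshold are dropped before they reach the decoder, but the most recent dropped frame is re-emitted when the next kept frame peaks at the same non-blank token, so CTC's repeat-collapsing still sees a boundary between consecutive identical tokens. A standalone sketch of the same rule over one utterance of precomputed rows (with the std::max_element fix applied above); this is an illustration, not code from the diff:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Drop mostly-blank frames, re-emitting the last dropped frame when
    // the next kept frame peaks at the same non-blank token.
    std::vector<std::vector<float>> SkipBlankFrames(
        const std::vector<std::vector<float>>& rows, float blank_threshold) {
        std::vector<std::vector<float>> kept;
        std::vector<float> last_skipped;
        bool skipped = false;
        int last_max = -1;
        for (const auto& row : rows) {
            if (std::exp(row[0]) > blank_threshold) {  // mostly-blank frame
                last_skipped = row;
                skipped = true;
                continue;
            }
            int cur_max =
                std::max_element(row.begin(), row.end()) - row.begin();
            if (skipped && cur_max != 0 && cur_max == last_max) {
                kept.push_back(last_skipped);  // keep the token boundary
            }
            last_max = cur_max;
            skipped = false;
            kept.push_back(row);
        }
        return kept;
    }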
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "base/safe_queue.h"
#include "frontend/frontend_itf.h"
#include "nnet/nnet_itf.h"
namespace ppspeech {
class NnetProducer {
public:
explicit NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend,
float blank_threshold);
// Feed feats or waves
void Accept(const std::vector<kaldi::BaseFloat>& inputs);
void Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood);
// nnet
bool Read(std::vector<kaldi::BaseFloat>* nnet_prob);
bool Empty() const { return cache_.empty(); }
void SetInputFinished() {
LOG(INFO) << "set finished";
frontend_->SetFinished();
}
// input is finished and the compute loop has drained it
bool IsFinished() const {
return (frontend_->IsFinished() && finished_);
}
~NnetProducer() {}
void Reset() {
if (frontend_ != NULL) frontend_->Reset();
if (nnet_ != NULL) nnet_->Reset();
cache_.clear();
finished_ = false;
}
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score);
bool Compute();
private:
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetBase> nnet_;
SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
std::vector<BaseFloat> last_frame_logprob_;
bool is_last_frame_skip_ = false;
int last_max_elem_ = -1;
float blank_threshold_ = 0.0;
bool finished_;
DISALLOW_COPY_AND_ASSIGN(NnetProducer);
};
} // namespace ppspeech
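Nothing in this header shows who calls Compute(); Read() only pops what Compute() has pushed. One plausible minimal driver, assuming a dedicated compute thread (the thread itself is not part of this diff and is purely hypothetical):

    #include <chrono>
    #include <memory>
    #include <thread>

    // Hypothetical driver: pump Compute() so the cache Read() pops from
    // actually fills; back off briefly when no features are ready.
    void ComputeLoop(std::shared_ptr<ppspeech::NnetProducer> producer) {
        while (!producer->IsFinished()) {
            if (!producer->Compute()) {
                std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
        }
    }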
...@@ -17,12 +17,13 @@ ...@@ -17,12 +17,13 @@
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.cc // https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.cc
#include "nnet/u2_nnet.h" #include "nnet/u2_nnet.h"
#include <type_traits>
#ifdef USE_PROFILING #ifdef WITH_PROFILING
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
using paddle::platform::RecordEvent; using paddle::platform::RecordEvent;
using paddle::platform::TracerEventType; using paddle::platform::TracerEventType;
#endif // end USE_PROFILING #endif // end WITH_PROFILING
namespace ppspeech { namespace ppspeech {
...@@ -30,7 +31,7 @@ namespace ppspeech { ...@@ -30,7 +31,7 @@ namespace ppspeech {
void U2Nnet::LoadModel(const std::string& model_path_w_prefix) { void U2Nnet::LoadModel(const std::string& model_path_w_prefix) {
paddle::jit::utils::InitKernelSignatureMap(); paddle::jit::utils::InitKernelSignatureMap();
#ifdef USE_GPU #ifdef WITH_GPU
dev_ = phi::GPUPlace(); dev_ = phi::GPUPlace();
#else #else
dev_ = phi::CPUPlace(); dev_ = phi::CPUPlace();
...@@ -62,12 +63,12 @@ void U2Nnet::LoadModel(const std::string& model_path_w_prefix) { ...@@ -62,12 +63,12 @@ void U2Nnet::LoadModel(const std::string& model_path_w_prefix) {
} }
void U2Nnet::Warmup() { void U2Nnet::Warmup() {
#ifdef USE_PROFILING #ifdef WITH_PROFILING
RecordEvent event("warmup", TracerEventType::UserDefined, 1); RecordEvent event("warmup", TracerEventType::UserDefined, 1);
#endif #endif
{ {
#ifdef USE_PROFILING #ifdef WITH_PROFILING
RecordEvent event( RecordEvent event(
"warmup-encoder-ctc", TracerEventType::UserDefined, 1); "warmup-encoder-ctc", TracerEventType::UserDefined, 1);
#endif #endif
...@@ -91,7 +92,7 @@ void U2Nnet::Warmup() { ...@@ -91,7 +92,7 @@ void U2Nnet::Warmup() {
} }
{ {
#ifdef USE_PROFILING #ifdef WITH_PROFILING
RecordEvent event("warmup-decoder", TracerEventType::UserDefined, 1); RecordEvent event("warmup-decoder", TracerEventType::UserDefined, 1);
#endif #endif
auto hyps = auto hyps =
...@@ -101,10 +102,10 @@ void U2Nnet::Warmup() { ...@@ -101,10 +102,10 @@ void U2Nnet::Warmup() {
auto encoder_out = paddle::ones( auto encoder_out = paddle::ones(
{1, 20, 512}, paddle::DataType::FLOAT32, phi::CPUPlace()); {1, 20, 512}, paddle::DataType::FLOAT32, phi::CPUPlace());
std::vector<paddle::experimental::Tensor> inputs{ std::vector<paddle::Tensor> inputs{
hyps, hyps_lens, encoder_out}; hyps, hyps_lens, encoder_out};
std::vector<paddle::experimental::Tensor> outputs = std::vector<paddle::Tensor> outputs =
forward_attention_decoder_(inputs); forward_attention_decoder_(inputs);
} }
...@@ -118,27 +119,46 @@ U2Nnet::U2Nnet(const ModelOptions& opts) : opts_(opts) { ...@@ -118,27 +119,46 @@ U2Nnet::U2Nnet(const ModelOptions& opts) : opts_(opts) {
// shallow copy // shallow copy
U2Nnet::U2Nnet(const U2Nnet& other) { U2Nnet::U2Nnet(const U2Nnet& other) {
// copy meta // copy meta
right_context_ = other.right_context_;
subsampling_rate_ = other.subsampling_rate_;
sos_ = other.sos_;
eos_ = other.eos_;
is_bidecoder_ = other.is_bidecoder_;
chunk_size_ = other.chunk_size_; chunk_size_ = other.chunk_size_;
num_left_chunks_ = other.num_left_chunks_; num_left_chunks_ = other.num_left_chunks_;
forward_encoder_chunk_ = other.forward_encoder_chunk_;
forward_attention_decoder_ = other.forward_attention_decoder_;
ctc_activation_ = other.ctc_activation_;
offset_ = other.offset_; offset_ = other.offset_;
// copy model ptr // copy model ptr
model_ = other.model_; // model_ = other.model_->Clone();
// hack, fix later
#ifdef WITH_GPU
dev_ = phi::GPUPlace();
#else
dev_ = phi::CPUPlace();
#endif
paddle::jit::Layer model = paddle::jit::Load(other.opts_.model_path, dev_);
model_ = std::make_shared<paddle::jit::Layer>(std::move(model));
subsampling_rate_ = model_->Attribute<int>("subsampling_rate");
right_context_ = model_->Attribute<int>("right_context");
sos_ = model_->Attribute<int>("sos_symbol");
eos_ = model_->Attribute<int>("eos_symbol");
is_bidecoder_ = model_->Attribute<int>("is_bidirectional_decoder");
forward_encoder_chunk_ = model_->Function("forward_encoder_chunk");
forward_attention_decoder_ = model_->Function("forward_attention_decoder");
ctc_activation_ = model_->Function("ctc_activation");
CHECK(forward_encoder_chunk_.IsValid());
CHECK(forward_attention_decoder_.IsValid());
CHECK(ctc_activation_.IsValid());
LOG(INFO) << "Paddle Model Info: ";
LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_;
LOG(INFO) << "\tright context " << right_context_;
LOG(INFO) << "\tsos " << sos_;
LOG(INFO) << "\teos " << eos_;
LOG(INFO) << "\tis bidecoder " << is_bidecoder_ << std::endl;
// ignore inner states // ignore inner states
} }
std::shared_ptr<NnetBase> U2Nnet::Copy() const { std::shared_ptr<NnetBase> U2Nnet::Clone() const {
auto asr_model = std::make_shared<U2Nnet>(*this); auto asr_model = std::make_shared<U2Nnet>(*this);
// reset inner state for new decoding // reset inner state for new decoding
asr_model->Reset(); asr_model->Reset();
...@@ -154,6 +174,7 @@ void U2Nnet::Reset() { ...@@ -154,6 +174,7 @@ void U2Nnet::Reset() {
std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32)); std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
encoder_outs_.clear(); encoder_outs_.clear();
VLOG(1) << "FeedForward cost: " << cost_time_ << " sec. ";
VLOG(3) << "u2nnet reset"; VLOG(3) << "u2nnet reset";
} }
...@@ -165,23 +186,18 @@ void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) { ...@@ -165,23 +186,18 @@ void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) {
} }
void U2Nnet::FeedForward(const kaldi::Vector<BaseFloat>& features, void U2Nnet::FeedForward(const std::vector<BaseFloat>& features,
const int32& feature_dim, const int32& feature_dim,
NnetOut* out) { NnetOut* out) {
kaldi::Timer timer; kaldi::Timer timer;
std::vector<kaldi::BaseFloat> chunk_feats(features.Data(),
features.Data() + features.Dim());
std::vector<kaldi::BaseFloat> ctc_probs; std::vector<kaldi::BaseFloat> ctc_probs;
ForwardEncoderChunkImpl( ForwardEncoderChunkImpl(
chunk_feats, feature_dim, &ctc_probs, &out->vocab_dim); features, feature_dim, &out->logprobs, &out->vocab_dim);
float forward_chunk_time = timer.Elapsed();
out->logprobs.Resize(ctc_probs.size(), kaldi::kSetZero); VLOG(1) << "FeedForward cost: " << forward_chunk_time << " sec. "
std::memcpy(out->logprobs.Data(), << features.size() / feature_dim << " frames.";
ctc_probs.data(), cost_time_ += forward_chunk_time;
ctc_probs.size() * sizeof(kaldi::BaseFloat));
VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
<< chunk_feats.size() / feature_dim << " frames.";
} }
...@@ -190,7 +206,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -190,7 +206,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
const int32& feat_dim, const int32& feat_dim,
std::vector<kaldi::BaseFloat>* out_prob, std::vector<kaldi::BaseFloat>* out_prob,
int32* vocab_dim) { int32* vocab_dim) {
#ifdef USE_PROFILING #ifdef WITH_PROFILING
RecordEvent event( RecordEvent event(
"ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1); "ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1);
#endif #endif
...@@ -210,7 +226,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -210,7 +226,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
// not cache feature in nnet // not cache feature in nnet
CHECK_EQ(cached_feats_.size(), 0); CHECK_EQ(cached_feats_.size(), 0);
// CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true); CHECK_EQ((std::is_same<float, kaldi::BaseFloat>::value), true);
std::memcpy(feats_ptr, std::memcpy(feats_ptr,
chunk_feats.data(), chunk_feats.data(),
chunk_feats.size() * sizeof(kaldi::BaseFloat)); chunk_feats.size() * sizeof(kaldi::BaseFloat));
...@@ -218,7 +234,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -218,7 +234,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
VLOG(3) << "feats shape: " << feats.shape()[0] << ", " << feats.shape()[1] VLOG(3) << "feats shape: " << feats.shape()[0] << ", " << feats.shape()[1]
<< ", " << feats.shape()[2]; << ", " << feats.shape()[2];
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("feat", std::ios_base::app | std::ios_base::out); std::stringstream path("feat", std::ios_base::app | std::ios_base::out);
path << offset_; path << offset_;
...@@ -237,7 +253,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -237,7 +253,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
#endif #endif
// Encoder chunk forward // Encoder chunk forward
#ifdef USE_GPU #ifdef WITH_GPU
feats = feats.copy_to(paddle::GPUPlace(), /*blocking*/ false); feats = feats.copy_to(paddle::GPUPlace(), /*blocking*/ false);
att_cache_ = att_cache_.copy_to(paddle::GPUPlace()), /*blocking*/ false; att_cache_ = att_cache_.copy_to(paddle::GPUPlace()), /*blocking*/ false;
cnn_cache_ = cnn_cache_.copy_to(Paddle::GPUPlace(), /*blocking*/ false); cnn_cache_ = cnn_cache_.copy_to(Paddle::GPUPlace(), /*blocking*/ false);
...@@ -254,7 +270,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -254,7 +270,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
std::vector<paddle::Tensor> outputs = forward_encoder_chunk_(inputs); std::vector<paddle::Tensor> outputs = forward_encoder_chunk_(inputs);
CHECK_EQ(outputs.size(), 3); CHECK_EQ(outputs.size(), 3);
#ifdef USE_GPU #ifdef WITH_GPU
paddle::Tensor chunk_out = outputs[0].copy_to(paddle::CPUPlace()); paddle::Tensor chunk_out = outputs[0].copy_to(paddle::CPUPlace());
att_cache_ = outputs[1].copy_to(paddle::CPUPlace()); att_cache_ = outputs[1].copy_to(paddle::CPUPlace());
cnn_cache_ = outputs[2].copy_to(paddle::CPUPlace()); cnn_cache_ = outputs[2].copy_to(paddle::CPUPlace());
...@@ -264,7 +280,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -264,7 +280,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
cnn_cache_ = outputs[2]; cnn_cache_ = outputs[2];
#endif #endif
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("encoder_logits", std::stringstream path("encoder_logits",
std::ios_base::app | std::ios_base::out); std::ios_base::app | std::ios_base::out);
...@@ -294,7 +310,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -294,7 +310,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
encoder_outs_.push_back(chunk_out); encoder_outs_.push_back(chunk_out);
VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size(); VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size();
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("encoder_logits_list", std::stringstream path("encoder_logits_list",
std::ios_base::app | std::ios_base::out); std::ios_base::app | std::ios_base::out);
...@@ -313,7 +329,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -313,7 +329,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
} }
#endif // end TEST_DEBUG #endif // end TEST_DEBUG
#ifdef USE_GPU #ifdef WITH_GPU
#error "Not implementation." #error "Not implementation."
...@@ -327,7 +343,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -327,7 +343,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
CHECK_EQ(outputs.size(), 1); CHECK_EQ(outputs.size(), 1);
paddle::Tensor ctc_log_probs = outputs[0]; paddle::Tensor ctc_log_probs = outputs[0];
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("encoder_logprob", std::stringstream path("encoder_logprob",
std::ios_base::app | std::ios_base::out); std::ios_base::app | std::ios_base::out);
...@@ -349,7 +365,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -349,7 +365,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
} }
#endif // end TEST_DEBUG #endif // end TEST_DEBUG
#endif // end USE_GPU #endif // end WITH_GPU
// Copy to output, (B=1,T,D) // Copy to output, (B=1,T,D)
std::vector<int64_t> ctc_log_probs_shape = ctc_log_probs.shape(); std::vector<int64_t> ctc_log_probs_shape = ctc_log_probs.shape();
...@@ -366,7 +382,7 @@ void U2Nnet::ForwardEncoderChunkImpl( ...@@ -366,7 +382,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
std::memcpy( std::memcpy(
out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat)); out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat));
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("encoder_logits_list_ctc", std::stringstream path("encoder_logits_list_ctc",
std::ios_base::app | std::ios_base::out); std::ios_base::app | std::ios_base::out);
...@@ -415,7 +431,7 @@ float U2Nnet::ComputePathScore(const paddle::Tensor& prob, ...@@ -415,7 +431,7 @@ float U2Nnet::ComputePathScore(const paddle::Tensor& prob,
void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight, float reverse_weight,
std::vector<float>* rescoring_score) { std::vector<float>* rescoring_score) {
#ifdef USE_PROFILING #ifdef WITH_PROFILING
RecordEvent event("AttentionRescoring", TracerEventType::UserDefined, 1); RecordEvent event("AttentionRescoring", TracerEventType::UserDefined, 1);
#endif #endif
CHECK(rescoring_score != nullptr); CHECK(rescoring_score != nullptr);
...@@ -457,7 +473,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, ...@@ -457,7 +473,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
} }
} }
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("encoder_logits_concat", std::stringstream path("encoder_logits_concat",
std::ios_base::app | std::ios_base::out); std::ios_base::app | std::ios_base::out);
...@@ -481,7 +497,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, ...@@ -481,7 +497,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
paddle::Tensor encoder_out = paddle::concat(encoder_outs_, 1); paddle::Tensor encoder_out = paddle::concat(encoder_outs_, 1);
VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size(); VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size();
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("encoder_out0", std::stringstream path("encoder_out0",
std::ios_base::app | std::ios_base::out); std::ios_base::app | std::ios_base::out);
...@@ -500,7 +516,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, ...@@ -500,7 +516,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
} }
#endif // end TEST_DEBUG #endif // end TEST_DEBUG
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("encoder_out", std::stringstream path("encoder_out",
std::ios_base::app | std::ios_base::out); std::ios_base::app | std::ios_base::out);
...@@ -519,7 +535,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, ...@@ -519,7 +535,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
} }
#endif // end TEST_DEBUG #endif // end TEST_DEBUG
std::vector<paddle::experimental::Tensor> inputs{ std::vector<paddle::Tensor> inputs{
hyps_tensor, hyps_lens, encoder_out}; hyps_tensor, hyps_lens, encoder_out};
std::vector<paddle::Tensor> outputs = forward_attention_decoder_(inputs); std::vector<paddle::Tensor> outputs = forward_attention_decoder_(inputs);
CHECK_EQ(outputs.size(), 2); CHECK_EQ(outputs.size(), 2);
...@@ -531,7 +547,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, ...@@ -531,7 +547,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
CHECK_EQ(probs_shape[0], num_hyps); CHECK_EQ(probs_shape[0], num_hyps);
CHECK_EQ(probs_shape[1], max_hyps_len); CHECK_EQ(probs_shape[1], max_hyps_len);
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("decoder_logprob", std::stringstream path("decoder_logprob",
std::ios_base::app | std::ios_base::out); std::ios_base::app | std::ios_base::out);
...@@ -549,7 +565,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, ...@@ -549,7 +565,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
} }
#endif // end TEST_DEBUG #endif // end TEST_DEBUG
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("hyps_lens", std::stringstream path("hyps_lens",
std::ios_base::app | std::ios_base::out); std::ios_base::app | std::ios_base::out);
...@@ -565,7 +581,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, ...@@ -565,7 +581,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
} }
#endif // end TEST_DEBUG #endif // end TEST_DEBUG
#ifdef TEST_DEBUG #ifdef PPS_DEBUG
{ {
std::stringstream path("hyps_tensor", std::stringstream path("hyps_tensor",
std::ios_base::app | std::ios_base::out); std::ios_base::app | std::ios_base::out);
...@@ -590,7 +606,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, ...@@ -590,7 +606,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
} else { } else {
// dump r_probs // dump r_probs
CHECK_EQ(r_probs_shape.size(), 1); CHECK_EQ(r_probs_shape.size(), 1);
CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0]; //CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0];
} }
// compute rescoring score // compute rescoring score
...@@ -600,15 +616,15 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, ...@@ -600,15 +616,15 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
VLOG(2) << "split prob: " << probs_v.size() << " " VLOG(2) << "split prob: " << probs_v.size() << " "
<< probs_v[0].shape().size() << " 0: " << probs_v[0].shape()[0] << probs_v[0].shape().size() << " 0: " << probs_v[0].shape()[0]
<< ", " << probs_v[0].shape()[1] << ", " << probs_v[0].shape()[2]; << ", " << probs_v[0].shape()[1] << ", " << probs_v[0].shape()[2];
CHECK(static_cast<int>(probs_v.size()) == num_hyps) //CHECK(static_cast<int>(probs_v.size()) == num_hyps)
<< ": is " << probs_v.size() << " expect: " << num_hyps; // << ": is " << probs_v.size() << " expect: " << num_hyps;
std::vector<paddle::Tensor> r_probs_v; std::vector<paddle::Tensor> r_probs_v;
if (is_bidecoder_ && reverse_weight > 0) { if (is_bidecoder_ && reverse_weight > 0) {
r_probs_v = paddle::experimental::split_with_num(r_probs, num_hyps, 0); r_probs_v = paddle::experimental::split_with_num(r_probs, num_hyps, 0);
CHECK(static_cast<int>(r_probs_v.size()) == num_hyps) //CHECK(static_cast<int>(r_probs_v.size()) == num_hyps)
<< "r_probs_v size: is " << r_probs_v.size() // << "r_probs_v size: is " << r_probs_v.size()
<< " expect: " << num_hyps; // << " expect: " << num_hyps;
} }
for (int i = 0; i < num_hyps; ++i) { for (int i = 0; i < num_hyps; ++i) {
...@@ -638,7 +654,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps, ...@@ -638,7 +654,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
void U2Nnet::EncoderOuts( void U2Nnet::EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const { std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const {
// list of (B=1,T,D) // list of (B=1,T,D)
int size = encoder_outs_.size(); int size = encoder_outs_.size();
VLOG(3) << "encoder_outs_ size: " << size; VLOG(3) << "encoder_outs_ size: " << size;
...@@ -650,18 +666,18 @@ void U2Nnet::EncoderOuts( ...@@ -650,18 +666,18 @@ void U2Nnet::EncoderOuts(
const int& B = shape[0]; const int& B = shape[0];
const int& T = shape[1]; const int& T = shape[1];
const int& D = shape[2]; const int& D = shape[2];
CHECK(B == 1) << "Only support batch one."; //CHECK(B == 1) << "Only support batch one.";
VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << "," VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << ","
<< D << ")"; << D << ")";
const float* this_tensor_ptr = item.data<float>(); const float* this_tensor_ptr = item.data<float>();
for (int j = 0; j < T; j++) { for (int j = 0; j < T; j++) {
const float* cur = this_tensor_ptr + j * D; const float* cur = this_tensor_ptr + j * D;
kaldi::Vector<kaldi::BaseFloat> out(D); std::vector<kaldi::BaseFloat> out(D);
std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat)); std::memcpy(out.data(), cur, D * sizeof(kaldi::BaseFloat));
encoder_out->emplace_back(out); encoder_out->emplace_back(out);
} }
} }
} }
} // namespace ppspeech } // namespace ppspeech
\ No newline at end of file
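With Copy() renamed to Clone(), each decoding stream can own an independent nnet; per the copy constructor above, a clone reloads the jit layer and starts with clean caches. A short sketch:

    // Sketch: per-stream nnet instances via the renamed Clone() API.
    auto base_nnet = std::make_shared<ppspeech::U2Nnet>(
        ppspeech::ModelOptions::InitFromFlags());
    std::shared_ptr<ppspeech::NnetBase> stream_nnet = base_nnet->Clone();
    // the clone shares no mutable decoding state with base_nnet and is
    // already Reset(), so concurrent streams stay independent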
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h" #include "nnet/nnet_itf.h"
#include "paddle/extension.h" #include "paddle/extension.h"
#include "paddle/jit/all.h" #include "paddle/jit/all.h"
...@@ -42,7 +42,7 @@ class U2NnetBase : public NnetBase { ...@@ -42,7 +42,7 @@ class U2NnetBase : public NnetBase {
num_left_chunks_ = num_left_chunks; num_left_chunks_ = num_left_chunks;
} }
virtual std::shared_ptr<NnetBase> Copy() const = 0; virtual std::shared_ptr<NnetBase> Clone() const = 0;
protected: protected:
virtual void ForwardEncoderChunkImpl( virtual void ForwardEncoderChunkImpl(
...@@ -76,7 +76,7 @@ class U2Nnet : public U2NnetBase { ...@@ -76,7 +76,7 @@ class U2Nnet : public U2NnetBase {
explicit U2Nnet(const ModelOptions& opts); explicit U2Nnet(const ModelOptions& opts);
U2Nnet(const U2Nnet& other); U2Nnet(const U2Nnet& other);
void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features, void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim, const int32& feature_dim,
NnetOut* out) override; NnetOut* out) override;
...@@ -91,7 +91,7 @@ class U2Nnet : public U2NnetBase { ...@@ -91,7 +91,7 @@ class U2Nnet : public U2NnetBase {
std::shared_ptr<paddle::jit::Layer> model() const { return model_; } std::shared_ptr<paddle::jit::Layer> model() const { return model_; }
std::shared_ptr<NnetBase> Copy() const override; std::shared_ptr<NnetBase> Clone() const override;
void ForwardEncoderChunkImpl( void ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats, const std::vector<kaldi::BaseFloat>& chunk_feats,
...@@ -111,10 +111,10 @@ class U2Nnet : public U2NnetBase { ...@@ -111,10 +111,10 @@ class U2Nnet : public U2NnetBase {
void FeedEncoderOuts(const paddle::Tensor& encoder_out); void FeedEncoderOuts(const paddle::Tensor& encoder_out);
void EncoderOuts( void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const; std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;
ModelOptions opts_; // hack, fix later
private: private:
ModelOptions opts_;
phi::Place dev_; phi::Place dev_;
std::shared_ptr<paddle::jit::Layer> model_{nullptr}; std::shared_ptr<paddle::jit::Layer> model_{nullptr};
...@@ -127,6 +127,7 @@ class U2Nnet : public U2NnetBase { ...@@ -127,6 +127,7 @@ class U2Nnet : public U2NnetBase {
paddle::jit::Function forward_encoder_chunk_; paddle::jit::Function forward_encoder_chunk_;
paddle::jit::Function forward_attention_decoder_; paddle::jit::Function forward_attention_decoder_;
paddle::jit::Function ctc_activation_; paddle::jit::Function ctc_activation_;
float cost_time_ = 0.0;
}; };
} // namespace ppspeech } // namespace ppspeech
\ No newline at end of file
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#include "base/common.h" #include "base/common.h"
#include "decoder/param.h" #include "decoder/param.h"
#include "frontend/audio/assembler.h" #include "frontend/assembler.h"
#include "frontend/audio/data_cache.h" #include "frontend/data_cache.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/u2_nnet.h" #include "nnet/u2_nnet.h"
......
...@@ -12,16 +12,28 @@ ...@@ -12,16 +12,28 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef USE_ONNX
#include "nnet/u2_nnet.h"
#else
#include "nnet/u2_onnx_nnet.h"
#endif
#include "base/common.h"
#include "decoder/param.h" #include "decoder/param.h"
#include "kaldi/feat/wave-reader.h" #include "frontend/feature_pipeline.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "recognizer/recognizer.h" #include "nnet/decodable.h"
#include "nnet/nnet_producer.h"
#include "nnet/u2_nnet.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(wav_rspecifier, "", "test wav rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(nnet_prob_wspecifier, "", "nnet prob wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate"); DEFINE_int32(sample_rate, 16000, "sample rate");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:"); gflags::SetUsageMessage("Usage:");
...@@ -30,76 +42,104 @@ int main(int argc, char* argv[]) { ...@@ -30,76 +42,104 @@ int main(int argc, char* argv[]) {
google::InstallFailureSignalHandler(); google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1; FLAGS_logtostderr = 1;
ppspeech::RecognizerResource resource = int32 num_done = 0, num_err = 0;
ppspeech::RecognizerResource::InitFromFlags();
ppspeech::Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int sample_rate = FLAGS_sample_rate; int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk; float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate; int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
int32 num_done = 0, num_err = 0; CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
double tot_wav_duration = 0.0; CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
CHECK_GT(FLAGS_model_path.size(), 0);
LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier;
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
LOG(INFO) << "model path: " << FLAGS_model_path;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
ppspeech::FeaturePipelineOptions feature_opts =
ppspeech::FeaturePipelineOptions::InitFromFlags();
feature_opts.assembler_opts.fill_zero = false;
#ifndef USE_ONNX
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
#else
std::shared_ptr<ppspeech::U2OnnxNnet> nnet(new ppspeech::U2OnnxNnet(model_opts));
#endif
std::shared_ptr<ppspeech::FeaturePipeline> feature_pipeline(
new ppspeech::FeaturePipeline(feature_opts));
std::shared_ptr<ppspeech::NnetProducer> nnet_producer(
new ppspeech::NnetProducer(nnet, feature_pipeline, FLAGS_blank_threshold));
kaldi::Timer timer; kaldi::Timer timer;
float tot_wav_duration = 0;
for (; !wav_reader.Done(); wav_reader.Next()) { for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key(); std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value(); const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0; int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel); this_channel);
int tot_samples = waveform.Dim(); int tot_samples = waveform.Dim();
tot_wav_duration += tot_samples * 1.0 / sample_rate;
LOG(INFO) << "wav len (sample): " << tot_samples; LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0; int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats; kaldi::Timer timer;
int feature_rows = 0;
while (sample_offset < tot_samples) { while (sample_offset < tot_samples) {
int cur_chunk_size = int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset); std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size); std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) { for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i); wav_chunk[i] = waveform(sample_offset + i);
} }
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer.Accept(wav_chunk); nnet_producer->Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) { if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished(); nnet_producer->SetInputFinished();
} }
recognizer.Decode();
// no overlap // no overlap
sample_offset += cur_chunk_size; sample_offset += cur_chunk_size;
} }
CHECK(sample_offset == tot_samples);
std::string result;
result = recognizer.GetFinalResult(); std::vector<std::vector<kaldi::BaseFloat>> prob_vec;
recognizer.Reset(); while (1) {
if (result.empty()) { std::vector<kaldi::BaseFloat> logprobs;
// the TokenWriter can not write empty string. bool isok = nnet_producer->Read(&logprobs);
++num_err; if (nnet_producer->IsFinished()) break;
KALDI_LOG << " the result of " << utt << " is empty"; if (isok == false) continue;
continue; prob_vec.push_back(logprobs);
} }
KALDI_LOG << " the result of " << utt << " is " << result; {
result_writer.Write(utt, result); // writer nnet output
++num_done; kaldi::MatrixIndexT nrow = prob_vec.size();
kaldi::MatrixIndexT ncol = prob_vec[0].size();
LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx];
}
}
nnet_out_writer.Write(utt, nnet_out);
}
nnet_producer->Reset();
} }
nnet_producer->Wait();
double elapsed = timer.Elapsed(); double elapsed = timer.Elapsed();
KALDI_LOG << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "Program cost:" << elapsed << " sec";
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s"; LOG(INFO) << "Done " << num_done << " utterances, " << num_err
KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration; << " with errors.";
return (num_done != 0 ? 0 : 1);
} }
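The chunk loop above is the whole streaming contract of this tool: fixed-size, non-overlapping chunks, with the final short chunk doubling as the end-of-input signal. A minimal sketch of that partitioning logic, where SplitIntoChunks is an illustrative helper name rather than part of the runtime:

#include <algorithm>
#include <utility>
#include <vector>

// Returns (offset, size) pairs covering [0, tot_samples) without overlap;
// with the default flags the chunk size is 0.36 s * 16000 Hz = 5760 samples,
// and only the last pair can be shorter.
std::vector<std::pair<int, int>> SplitIntoChunks(int tot_samples,
                                                 int chunk_sample_size) {
    std::vector<std::pair<int, int>> ranges;
    int sample_offset = 0;
    while (sample_offset < tot_samples) {
        int cur = std::min(chunk_sample_size, tot_samples - sample_offset);
        ranges.emplace_back(sample_offset, cur);
        sample_offset += cur;  // no overlap between consecutive chunks
    }
    return ranges;
}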
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.cc
#include "nnet/u2_onnx_nnet.h"
#include "common/base/config.h"
namespace ppspeech {
void U2OnnxNnet::LoadModel(const std::string& model_dir) {
std::string encoder_onnx_path = model_dir + "/encoder.onnx";
std::string rescore_onnx_path = model_dir + "/decoder.onnx";
std::string ctc_onnx_path = model_dir + "/ctc.onnx";
std::string param_path = model_dir + "/param.onnx";
// 1. Load sessions
try {
encoder_ = std::make_shared<fastdeploy::Runtime>();
ctc_ = std::make_shared<fastdeploy::Runtime>();
rescore_ = std::make_shared<fastdeploy::Runtime>();
fastdeploy::RuntimeOption runtime_option;
runtime_option.UseOrtBackend();
runtime_option.UseCpu();
runtime_option.SetCpuThreadNum(1);
runtime_option.SetModelPath(encoder_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
assert(encoder_->Init(runtime_option));
runtime_option.SetModelPath(rescore_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
assert(rescore_->Init(runtime_option));
runtime_option.SetModelPath(ctc_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
assert(ctc_->Init(runtime_option));
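        // Note: wrapping Init() in assert() means these calls compile out
        // entirely in an NDEBUG build; glog's CHECK() would keep the side
        // effect and still abort on failure.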
} catch (std::exception const& e) {
        LOG(ERROR) << "error when loading onnx model: " << e.what();
        exit(1);
}
Config conf(param_path);
encoder_output_size_ = conf.Read("output_size", encoder_output_size_);
num_blocks_ = conf.Read("num_blocks", num_blocks_);
head_ = conf.Read("head", head_);
cnn_module_kernel_ = conf.Read("cnn_module_kernel", cnn_module_kernel_);
subsampling_rate_ = conf.Read("subsampling_rate", subsampling_rate_);
right_context_ = conf.Read("right_context", right_context_);
    sos_ = conf.Read("sos_symbol", sos_);
    eos_ = conf.Read("eos_symbol", eos_);
    is_bidecoder_ = conf.Read("is_bidirectional_decoder", is_bidecoder_);
    chunk_size_ = conf.Read("chunk_size", chunk_size_);
    num_left_chunks_ = conf.Read("left_chunks", num_left_chunks_);
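    // Based on the Read() calls above, param.onnx is a plain key/value config
    // expected to carry entries such as "output_size", "num_blocks", "head",
    // "sos_symbol", "eos_symbol" and "chunk_size"; the exact file syntax is
    // whatever common/base/config.h parses, so this list is descriptive, not
    // a format specification.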
LOG(INFO) << "Onnx Model Info:";
LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
LOG(INFO) << "\tnum_blocks " << num_blocks_;
LOG(INFO) << "\thead " << head_;
LOG(INFO) << "\tcnn_module_kernel " << cnn_module_kernel_;
LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_;
LOG(INFO) << "\tright_context " << right_context_;
LOG(INFO) << "\tsos " << sos_;
LOG(INFO) << "\teos " << eos_;
LOG(INFO) << "\tis bidirectional decoder " << is_bidecoder_;
LOG(INFO) << "\tchunk_size " << chunk_size_;
LOG(INFO) << "\tnum_left_chunks " << num_left_chunks_;
// 3. Read model nodes
LOG(INFO) << "Onnx Encoder:";
GetInputOutputInfo(encoder_, &encoder_in_names_, &encoder_out_names_);
LOG(INFO) << "Onnx CTC:";
GetInputOutputInfo(ctc_, &ctc_in_names_, &ctc_out_names_);
LOG(INFO) << "Onnx Rescore:";
GetInputOutputInfo(rescore_, &rescore_in_names_, &rescore_out_names_);
}
U2OnnxNnet::U2OnnxNnet(const ModelOptions& opts) : opts_(opts) {
LoadModel(opts_.model_path);
}
// shallow copy
U2OnnxNnet::U2OnnxNnet(const U2OnnxNnet& other) {
// metadatas
encoder_output_size_ = other.encoder_output_size_;
num_blocks_ = other.num_blocks_;
head_ = other.head_;
cnn_module_kernel_ = other.cnn_module_kernel_;
right_context_ = other.right_context_;
subsampling_rate_ = other.subsampling_rate_;
sos_ = other.sos_;
eos_ = other.eos_;
is_bidecoder_ = other.is_bidecoder_;
chunk_size_ = other.chunk_size_;
num_left_chunks_ = other.num_left_chunks_;
offset_ = other.offset_;
// session
encoder_ = other.encoder_;
ctc_ = other.ctc_;
rescore_ = other.rescore_;
// node names
encoder_in_names_ = other.encoder_in_names_;
encoder_out_names_ = other.encoder_out_names_;
ctc_in_names_ = other.ctc_in_names_;
ctc_out_names_ = other.ctc_out_names_;
rescore_in_names_ = other.rescore_in_names_;
rescore_out_names_ = other.rescore_out_names_;
}
void U2OnnxNnet::GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
std::vector<std::string>* in_names, std::vector<std::string>* out_names) {
std::vector<fastdeploy::TensorInfo> inputs_info = runtime->GetInputInfos();
    in_names->resize(inputs_info.size());
    for (size_t i = 0; i < inputs_info.size(); ++i) {
fastdeploy::TensorInfo info = inputs_info[i];
std::stringstream shape;
        for (size_t j = 0; j < info.shape.size(); ++j) {
shape << info.shape[j];
shape << " ";
}
LOG(INFO) << "\tInput " << i << " : name=" << info.name << " type=" << info.dtype
<< " dims=" << shape.str();
(*in_names)[i] = info.name;
}
std::vector<fastdeploy::TensorInfo> outputs_info = runtime->GetOutputInfos();
    out_names->resize(outputs_info.size());
    for (size_t i = 0; i < outputs_info.size(); ++i) {
fastdeploy::TensorInfo info = outputs_info[i];
std::stringstream shape;
        for (size_t j = 0; j < info.shape.size(); ++j) {
shape << info.shape[j];
shape << " ";
}
LOG(INFO) << "\tOutput " << i << " : name=" << info.name << " type=" << info.dtype
<< " dims=" << shape.str();
(*out_names)[i] = info.name;
}
}
std::shared_ptr<NnetBase> U2OnnxNnet::Clone() const {
auto asr_model = std::make_shared<U2OnnxNnet>(*this);
// reset inner state for new decoding
asr_model->Reset();
return asr_model;
}
void U2OnnxNnet::Reset() {
offset_ = 0;
encoder_outs_.clear();
cached_feats_.clear();
// Reset att_cache
if (num_left_chunks_ > 0) {
int required_cache_size = chunk_size_ * num_left_chunks_;
offset_ = required_cache_size;
att_cache_.resize(num_blocks_ * head_ * required_cache_size *
encoder_output_size_ / head_ * 2,
0.0);
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, required_cache_size,
encoder_output_size_ / head_ * 2};
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
} else {
att_cache_.resize(0, 0.0);
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, 0,
encoder_output_size_ / head_ * 2};
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
}
// Reset cnn_cache
cnn_cache_.resize(
num_blocks_ * encoder_output_size_ * (cnn_module_kernel_ - 1), 0.0);
const std::vector<int64_t> cnn_cache_shape = {num_blocks_, 1, encoder_output_size_,
cnn_module_kernel_ - 1};
cnn_cache_ort_.SetExternalData(cnn_cache_shape, fastdeploy::FDDataType::FP32, cnn_cache_.data());
}
void U2OnnxNnet::FeedForward(const std::vector<BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) {
kaldi::Timer timer;
std::vector<kaldi::BaseFloat> ctc_probs;
ForwardEncoderChunkImpl(
features, feature_dim, &out->logprobs, &out->vocab_dim);
VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
<< features.size() / feature_dim << " frames.";
}
void U2OnnxNnet::ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats,
const int32& feat_dim,
std::vector<kaldi::BaseFloat>* out_prob,
int32* vocab_dim) {
// 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
// chunk
int num_frames = chunk_feats.size() / feat_dim;
VLOG(3) << "num_frames: " << num_frames;
VLOG(3) << "feat_dim: " << feat_dim;
const int feature_dim = feat_dim;
std::vector<float> feats;
feats.insert(feats.end(), chunk_feats.begin(), chunk_feats.end());
fastdeploy::FDTensor feats_ort;
const std::vector<int64_t> feats_shape = {1, num_frames, feature_dim};
feats_ort.SetExternalData(feats_shape, fastdeploy::FDDataType::FP32, feats.data());
// offset
int64_t offset_int64 = static_cast<int64_t>(offset_);
fastdeploy::FDTensor offset_ort;
offset_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &offset_int64);
// required_cache_size
int64_t required_cache_size = chunk_size_ * num_left_chunks_;
fastdeploy::FDTensor required_cache_size_ort("");
required_cache_size_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &required_cache_size);
// att_mask
fastdeploy::FDTensor att_mask_ort;
std::vector<uint8_t> att_mask(required_cache_size + chunk_size_, 1);
if (num_left_chunks_ > 0) {
int chunk_idx = offset_ / chunk_size_ - num_left_chunks_;
if (chunk_idx < num_left_chunks_) {
for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) {
att_mask[i] = 0;
}
}
const std::vector<int64_t> att_mask_shape = {1, 1, required_cache_size + chunk_size_};
att_mask_ort.SetExternalData(att_mask_shape, fastdeploy::FDDataType::BOOL, reinterpret_cast<bool*>(att_mask.data()));
}
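    // The mask is built as uint8_t (std::vector<bool> has no contiguous
    // storage) and reinterpreted as bool for the runtime, which assumes
    // sizeof(bool) == 1 on the target toolchain.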
// 2. Encoder chunk forward
std::vector<fastdeploy::FDTensor> inputs(encoder_in_names_.size());
for (int i = 0; i < encoder_in_names_.size(); ++i) {
std::string name = encoder_in_names_[i];
if (!strcmp(name.data(), "chunk")) {
inputs[i] = std::move(feats_ort);
inputs[i].name = "chunk";
} else if (!strcmp(name.data(), "offset")) {
inputs[i] = std::move(offset_ort);
inputs[i].name = "offset";
} else if (!strcmp(name.data(), "required_cache_size")) {
inputs[i] = std::move(required_cache_size_ort);
inputs[i].name = "required_cache_size";
} else if (!strcmp(name.data(), "att_cache")) {
inputs[i] = std::move(att_cache_ort_);
inputs[i].name = "att_cache";
} else if (!strcmp(name.data(), "cnn_cache")) {
inputs[i] = std::move(cnn_cache_ort_);
inputs[i].name = "cnn_cache";
} else if (!strcmp(name.data(), "att_mask")) {
inputs[i] = std::move(att_mask_ort);
inputs[i].name = "att_mask";
}
}
std::vector<fastdeploy::FDTensor> ort_outputs;
assert(encoder_->Infer(inputs, &ort_outputs));
offset_ += static_cast<int>(ort_outputs[0].shape[1]);
att_cache_ort_ = std::move(ort_outputs[1]);
cnn_cache_ort_ = std::move(ort_outputs[2]);
std::vector<fastdeploy::FDTensor> ctc_inputs;
ctc_inputs.emplace_back(std::move(ort_outputs[0]));
// ctc_inputs[0] = std::move(ort_outputs[0]);
ctc_inputs[0].name = ctc_in_names_[0];
std::vector<fastdeploy::FDTensor> ctc_ort_outputs;
assert(ctc_->Infer(ctc_inputs, &ctc_ort_outputs));
encoder_outs_.emplace_back(std::move(ctc_inputs[0])); // *****
float* logp_data = reinterpret_cast<float*>(ctc_ort_outputs[0].Data());
// Copy to output, (B=1,T,D)
std::vector<int64_t> ctc_log_probs_shape = ctc_ort_outputs[0].shape;
CHECK_EQ(ctc_log_probs_shape.size(), 3);
int B = ctc_log_probs_shape[0];
CHECK_EQ(B, 1);
int T = ctc_log_probs_shape[1];
int D = ctc_log_probs_shape[2];
*vocab_dim = D;
out_prob->resize(T * D);
std::memcpy(
out_prob->data(), logp_data, T * D * sizeof(kaldi::BaseFloat));
return;
}
float U2OnnxNnet::ComputeAttentionScore(const float* prob,
const std::vector<int>& hyp, int eos,
int decode_out_len) {
float score = 0.0f;
for (size_t j = 0; j < hyp.size(); ++j) {
score += *(prob + j * decode_out_len + hyp[j]);
}
score += *(prob + hyp.size() * decode_out_len + eos);
return score;
}
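// Illustrative walk-through: with hyp = {3, 7}, eos = 2 and a decoder output
// laid out as decode_out_len scores per step, the returned score is
//   prob[0 * decode_out_len + 3] + prob[1 * decode_out_len + 7]
//     + prob[2 * decode_out_len + 2],
// i.e. each token's score at its own step plus the end-of-sentence score at
// the step right after the last token.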
void U2OnnxNnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
CHECK(rescoring_score != nullptr);
int num_hyps = hyps.size();
rescoring_score->resize(num_hyps, 0.0f);
if (num_hyps == 0) {
return;
}
// No encoder output
if (encoder_outs_.size() == 0) {
return;
}
std::vector<int64_t> hyps_lens;
int max_hyps_len = 0;
for (size_t i = 0; i < num_hyps; ++i) {
int length = hyps[i].size() + 1;
max_hyps_len = std::max(length, max_hyps_len);
hyps_lens.emplace_back(static_cast<int64_t>(length));
}
std::vector<float> rescore_input;
int encoder_len = 0;
for (int i = 0; i < encoder_outs_.size(); i++) {
float* encoder_outs_data = reinterpret_cast<float*>(encoder_outs_[i].Data());
for (int j = 0; j < encoder_outs_[i].Numel(); j++) {
rescore_input.emplace_back(encoder_outs_data[j]);
}
encoder_len += encoder_outs_[i].shape[1];
}
std::vector<int64_t> hyps_pad;
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
hyps_pad.emplace_back(sos_);
size_t j = 0;
for (; j < hyp.size(); ++j) {
hyps_pad.emplace_back(hyp[j]);
}
if (j == max_hyps_len - 1) {
continue;
}
for (; j < max_hyps_len - 1; ++j) {
hyps_pad.emplace_back(0);
}
}
const std::vector<int64_t> hyps_pad_shape = {num_hyps, max_hyps_len};
const std::vector<int64_t> hyps_lens_shape = {num_hyps};
const std::vector<int64_t> decode_input_shape = {1, encoder_len, encoder_output_size_};
fastdeploy::FDTensor hyps_pad_tensor_;
hyps_pad_tensor_.SetExternalData(hyps_pad_shape, fastdeploy::FDDataType::INT64, hyps_pad.data());
fastdeploy::FDTensor hyps_lens_tensor_;
hyps_lens_tensor_.SetExternalData(hyps_lens_shape, fastdeploy::FDDataType::INT64, hyps_lens.data());
fastdeploy::FDTensor decode_input_tensor_;
decode_input_tensor_.SetExternalData(decode_input_shape, fastdeploy::FDDataType::FP32, rescore_input.data());
std::vector<fastdeploy::FDTensor> rescore_inputs(3);
rescore_inputs[0] = std::move(hyps_pad_tensor_);
rescore_inputs[0].name = rescore_in_names_[0];
rescore_inputs[1] = std::move(hyps_lens_tensor_);
rescore_inputs[1].name = rescore_in_names_[1];
rescore_inputs[2] = std::move(decode_input_tensor_);
rescore_inputs[2].name = rescore_in_names_[2];
std::vector<fastdeploy::FDTensor> rescore_outputs;
assert(rescore_->Infer(rescore_inputs, &rescore_outputs));
float* decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[0].Data());
float* r_decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[1].Data());
int decode_out_len = rescore_outputs[0].shape[2];
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
float score = 0.0f;
// left to right decoder score
score = ComputeAttentionScore(
decoder_outs_data + max_hyps_len * decode_out_len * i, hyp, eos_,
decode_out_len);
// Optional: Used for right to left score
float r_score = 0.0f;
if (is_bidecoder_ && reverse_weight > 0) {
std::vector<int> r_hyp(hyp.size());
std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
// right to left decoder score
r_score = ComputeAttentionScore(
r_decoder_outs_data + max_hyps_len * decode_out_len * i, r_hyp, eos_,
decode_out_len);
}
// combined left-to-right and right-to-left score
(*rescoring_score)[i] =
score * (1 - reverse_weight) + r_score * reverse_weight;
}
}
void U2OnnxNnet::EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const {
}
}  // namespace ppspeech
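For readers coming from the Paddle-native U2Nnet, here is a hedged sketch of driving this ONNX variant directly; the model directory name and the 80-dim fbank are assumptions for illustration, not values taken from the diff:

#include <vector>
#include "nnet/u2_onnx_nnet.h"

// Sketch: run one chunk of fbank features through the ONNX pipeline.
void ForwardOneChunk(const std::vector<kaldi::BaseFloat>& chunk_feats) {
    ppspeech::ModelOptions opts;
    opts.model_path = "exp/u2_onnx";  // hypothetical dir holding encoder.onnx,
                                      // ctc.onnx, decoder.onnx and param.onnx
    ppspeech::U2OnnxNnet nnet(opts);

    int32 feat_dim = 80;  // assumed fbank dimension of the exported model
    ppspeech::NnetOut out;
    nnet.FeedForward(chunk_feats, feat_dim, &out);
    // out.logprobs now holds T * out.vocab_dim CTC log-probabilities.
}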
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
...@@ -11,87 +12,86 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.h

#pragma once

#include "base/common.h"
#include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "nnet/u2_nnet.h"
#include "fastdeploy/runtime.h"

namespace ppspeech {

class U2OnnxNnet : public U2NnetBase {
  public:
    explicit U2OnnxNnet(const ModelOptions& opts);
    U2OnnxNnet(const U2OnnxNnet& other);

    void FeedForward(const std::vector<kaldi::BaseFloat>& features,
                     const int32& feature_dim,
                     NnetOut* out) override;

    void Reset() override;

    bool IsLogProb() override { return true; }

    void Dim();

    void LoadModel(const std::string& model_dir);

    std::shared_ptr<NnetBase> Clone() const override;

    void ForwardEncoderChunkImpl(
        const std::vector<kaldi::BaseFloat>& chunk_feats,
        const int32& feat_dim,
        std::vector<kaldi::BaseFloat>* ctc_probs,
        int32* vocab_dim) override;

    float ComputeAttentionScore(const float* prob, const std::vector<int>& hyp,
                                int eos, int decode_out_len);

    void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                            float reverse_weight,
                            std::vector<float>* rescoring_score) override;

    void EncoderOuts(
        std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;

    void GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
                            std::vector<std::string>* in_names,
                            std::vector<std::string>* out_names);

  private:
    ModelOptions opts_;

    int encoder_output_size_ = 0;
    int num_blocks_ = 0;
    int cnn_module_kernel_ = 0;
    int head_ = 0;

    // sessions
    std::shared_ptr<fastdeploy::Runtime> encoder_ = nullptr;
    std::shared_ptr<fastdeploy::Runtime> rescore_ = nullptr;
    std::shared_ptr<fastdeploy::Runtime> ctc_ = nullptr;

    // node names
    std::vector<std::string> encoder_in_names_, encoder_out_names_;
    std::vector<std::string> ctc_in_names_, ctc_out_names_;
    std::vector<std::string> rescore_in_names_, rescore_out_names_;

    // caches
    fastdeploy::FDTensor att_cache_ort_;
    fastdeploy::FDTensor cnn_cache_ort_;
    std::vector<fastdeploy::FDTensor> encoder_outs_;
    std::vector<float> att_cache_;
    std::vector<float> cnn_cache_;
};

}  // namespace ppspeech
set(srcs)
list(APPEND srcs
recognizer_controller.cc
recognizer_controller_impl.cc
recognizer_instance.cc
recognizer.cc
)
add_library(recognizer STATIC ${srcs})
target_link_libraries(recognizer PUBLIC decoder)
set(TEST_BINS
recognizer_batch_main
recognizer_batch_main2
recognizer_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)
endforeach()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...@@ -13,58 +13,34 @@
// limitations under the License.

#include "recognizer/recognizer.h"
#include "recognizer/recognizer_instance.h"

bool InitRecognizer(const std::string& model_file,
                    const std::string& word_symbol_table_file,
                    const std::string& fst_file,
                    int num_instance) {
    return ppspeech::RecognizerInstance::GetInstance().Init(
        model_file, word_symbol_table_file, fst_file, num_instance);
}

int GetRecognizerInstanceId() {
    return ppspeech::RecognizerInstance::GetInstance().GetRecognizerInstanceId();
}

void InitDecoder(int instance_id) {
    return ppspeech::RecognizerInstance::GetInstance().InitDecoder(instance_id);
}

void AcceptData(const std::vector<float>& waves, int instance_id) {
    return ppspeech::RecognizerInstance::GetInstance().Accept(waves, instance_id);
}

void SetInputFinished(int instance_id) {
    return ppspeech::RecognizerInstance::GetInstance().SetInputFinished(instance_id);
}

std::string GetFinalResult(int instance_id) {
    return ppspeech::RecognizerInstance::GetInstance().GetResult(instance_id);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
bool InitRecognizer(const std::string& model_file,
const std::string& word_symbol_table_file,
const std::string& fst_file,
int num_instance);
int GetRecognizerInstanceId();
void InitDecoder(int instance_id);
void AcceptData(const std::vector<float>& waves, int instance_id);
void SetInputFinished(int instance_id);
std::string GetFinalResult(int instance_id);
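Taken together, these free functions form a small C-style facade over the recognizer singleton. A hedged end-to-end sketch; the paths and chunk size are placeholders, not values from the diff:

#include <string>
#include <vector>
#include "recognizer/recognizer.h"

int main() {
    // An empty fst_file selects CTC prefix beam search; a non-empty one
    // selects TLG decoding.
    InitRecognizer("exp/model", "exp/words.txt", /*fst_file=*/"",
                   /*num_instance=*/1);

    int id = GetRecognizerInstanceId();  // -1 while all instances are busy
    InitDecoder(id);

    std::vector<float> chunk(5760, 0.0f);  // one 0.36 s chunk at 16 kHz
    AcceptData(chunk, id);
    SetInputFinished(id);

    std::string text = GetFinalResult(id);  // blocks until decoding finishes
    return 0;
}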
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/recognizer_controller.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(njob, 3, "number of parallel decoding jobs");
using std::string;
using std::vector;
void SplitUtt(string wavlist_file,
vector<vector<string>>* uttlists,
vector<vector<string>>* wavlists,
int njob) {
vector<string> wavlist;
wavlists->resize(njob);
uttlists->resize(njob);
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
string utt_str = wavlist[idx];
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
LOG(INFO) << utt_wav[0];
CHECK_EQ(utt_wav.size(), size_t(2));
uttlists->at(idx % njob).push_back(utt_wav[0]);
wavlists->at(idx % njob).push_back(utt_wav[1]);
}
}
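// Round-robin sharding: with njob = 3 and utterances u0..u4, worker 0 gets
// {u0, u3}, worker 1 gets {u1, u4} and worker 2 gets {u2}; each wavlist line
// is expected to look like "<utt-id> <wav-path>".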
void recognizer_func(ppspeech::RecognizerController* recognizer_controller,
std::vector<string> wavlist,
std::vector<string> uttlist,
std::vector<string>* results) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
if (wavlist.empty()) return;
results->reserve(wavlist.size());
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
std::string utt = uttlist[idx];
std::string wav_file = wavlist[idx];
std::ifstream infile;
infile.open(wav_file, std::ifstream::in);
kaldi::WaveData wave_data;
wave_data.Read(infile);
int32 recog_id = -1;
while (recog_id == -1) {
recog_id = recognizer_controller->GetRecognizerInstanceId();
}
recognizer_controller->InitDecoder(recog_id);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
recognizer_controller->Accept(wav_chunk, recog_id);
// no overlap
sample_offset += cur_chunk_size;
}
recognizer_controller->SetInputFinished(recog_id);
CHECK(sample_offset == tot_samples);
std::string result = recognizer_controller->GetFinalResult(recog_id);
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
result = " ";
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
results->push_back(result);
++num_done;
}
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int njob = FLAGS_njob;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::RecognizerResource resource =
ppspeech::RecognizerResource::InitFromFlags();
ppspeech::RecognizerController recognizer_controller(njob, resource);
ThreadPool threadpool(njob);
vector<vector<string>> wavlist;
vector<vector<string>> uttlist;
vector<vector<string>> resultlist(njob);
vector<std::future<void>> futurelist;
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
for (size_t i = 0; i < njob; ++i) {
std::future<void> f = threadpool.enqueue(recognizer_func,
&recognizer_controller,
wavlist[i],
uttlist[i],
&resultlist[i]);
futurelist.push_back(std::move(f));
}
for (size_t i = 0; i < njob; ++i) {
futurelist[i].get();
}
for (size_t idx = 0; idx < njob; ++idx) {
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
string utt = uttlist[idx][utt_idx];
string result = resultlist[idx][utt_idx];
result_writer.Write(utt, result);
}
}
return 0;
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/recognizer.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(njob, 3, "number of parallel decoding jobs");
using std::string;
using std::vector;
void SplitUtt(string wavlist_file,
vector<vector<string>>* uttlists,
vector<vector<string>>* wavlists,
int njob) {
vector<string> wavlist;
wavlists->resize(njob);
uttlists->resize(njob);
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
string utt_str = wavlist[idx];
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
LOG(INFO) << utt_wav[0];
CHECK_EQ(utt_wav.size(), size_t(2));
uttlists->at(idx % njob).push_back(utt_wav[0]);
wavlists->at(idx % njob).push_back(utt_wav[1]);
}
}
void recognizer_func(std::vector<string> wavlist,
std::vector<string> uttlist,
std::vector<string>* results) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
if (wavlist.empty()) return;
results->reserve(wavlist.size());
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
std::string utt = uttlist[idx];
std::string wav_file = wavlist[idx];
std::ifstream infile;
infile.open(wav_file, std::ifstream::in);
kaldi::WaveData wave_data;
wave_data.Read(infile);
int32 recog_id = -1;
while (recog_id == -1) {
recog_id = GetRecognizerInstanceId();
}
InitDecoder(recog_id);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
AcceptData(wav_chunk, recog_id);
// no overlap
sample_offset += cur_chunk_size;
}
SetInputFinished(recog_id);
CHECK(sample_offset == tot_samples);
std::string result = GetFinalResult(recog_id);
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
result = " ";
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
results->push_back(result);
++num_done;
}
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int njob = FLAGS_njob;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
InitRecognizer(FLAGS_model_path, FLAGS_word_symbol_table, FLAGS_graph_path, njob);
ThreadPool threadpool(njob);
vector<vector<string>> wavlist;
vector<vector<string>> uttlist;
vector<vector<string>> resultlist(njob);
vector<std::future<void>> futurelist;
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
for (size_t i = 0; i < njob; ++i) {
std::future<void> f = threadpool.enqueue(recognizer_func,
wavlist[i],
uttlist[i],
&resultlist[i]);
futurelist.push_back(std::move(f));
}
for (size_t i = 0; i < njob; ++i) {
futurelist[i].get();
}
for (size_t idx = 0; idx < njob; ++idx) {
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
string utt = uttlist[idx][utt_idx];
string result = resultlist[idx][utt_idx];
result_writer.Write(utt, result);
}
}
return 0;
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer_controller.h"
#include "nnet/u2_nnet.h"
namespace ppspeech {
RecognizerController::RecognizerController(int num_worker, RecognizerResource resource) {
recognizer_workers.resize(num_worker);
for (size_t i = 0; i < num_worker; ++i) {
recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource));
waiting_workers.push(i);
}
}
int RecognizerController::GetRecognizerInstanceId() {
    int idx = -1;
    {
        std::unique_lock<std::mutex> lock(mutex_);
        // Check and pop under the same lock; an unlocked empty() check could
        // let two callers race for the last waiting worker.
        if (waiting_workers.empty()) {
            return -1;
        }
        idx = waiting_workers.front();
        waiting_workers.pop();
    }
    return idx;
}
RecognizerController::~RecognizerController() {
for (size_t i = 0; i < recognizer_workers.size(); ++i) {
recognizer_workers[i]->WaitFinished();
}
}
void RecognizerController::InitDecoder(int idx) {
recognizer_workers[idx]->InitDecoder();
}
std::string RecognizerController::GetFinalResult(int idx) {
recognizer_workers[idx]->WaitDecoderFinished();
recognizer_workers[idx]->AttentionRescoring();
std::string result = recognizer_workers[idx]->GetFinalResult();
{
std::unique_lock<std::mutex> lock(mutex_);
waiting_workers.push(idx);
}
return result;
}
void RecognizerController::Accept(std::vector<float> data, int idx) {
recognizer_workers[idx]->Accept(data);
}
void RecognizerController::SetInputFinished(int idx) {
recognizer_workers[idx]->SetInputFinished();
}
}  // namespace ppspeech
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include <memory>
#include "recognizer/recognizer_controller_impl.h"
namespace ppspeech {
class RecognizerController {
public:
explicit RecognizerController(int num_worker, RecognizerResource resource);
~RecognizerController();
int GetRecognizerInstanceId();
void InitDecoder(int idx);
void Accept(std::vector<float> data, int idx);
void SetInputFinished(int idx);
std::string GetFinalResult(int idx);
private:
std::queue<int> waiting_workers;
std::mutex mutex_;
std::vector<std::unique_ptr<ppspeech::RecognizerControllerImpl>> recognizer_workers;
DISALLOW_COPY_AND_ASSIGN(RecognizerController);
};
}  // namespace ppspeech
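The controller hands out worker indices and takes them back inside GetFinalResult(), so a typical caller loops until an index is free. A sketch of that check-out/check-in cycle, with RecognizeOne as an illustrative helper name:

#include <string>
#include <vector>
#include "recognizer/recognizer_controller.h"

std::string RecognizeOne(ppspeech::RecognizerController* controller,
                         const std::vector<float>& samples) {
    int id = -1;
    while ((id = controller->GetRecognizerInstanceId()) == -1) {
        // all workers are busy; the batch mains above spin the same way
    }
    controller->InitDecoder(id);
    controller->Accept(samples, id);
    controller->SetInputFinished(id);
    // GetFinalResult() waits for the decoder thread and returns the worker
    // index to the waiting queue.
    return controller->GetFinalResult(id);
}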
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...@@ -12,86 +12,180 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "recognizer/recognizer_controller_impl.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "common/utils/strings.h"

namespace ppspeech {

RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource)
    : opts_(resource) {
    BaseFloat am_scale = resource.acoustic_scale;
    BaseFloat blank_threshold = resource.blank_threshold;
    const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
    std::shared_ptr<FeaturePipeline> feature_pipeline(
        new FeaturePipeline(feature_opts));
    std::shared_ptr<NnetBase> nnet;
#ifndef USE_ONNX
    nnet = resource.nnet->Clone();
#else
    if (resource.model_opts.with_onnx_model) {
        nnet.reset(new U2OnnxNnet(resource.model_opts));
    } else {
        nnet = resource.nnet->Clone();
    }
#endif
    nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline, blank_threshold));
    nnet_thread_ = std::thread(RunNnetEvaluation, this);

    decodable_.reset(new Decodable(nnet_producer_, am_scale));
    if (resource.decoder_opts.tlg_decoder_opts.fst_path.empty()) {
        LOG(INFO) << "Init PrefixBeamSearch Decoder";
        decoder_ = std::make_unique<CTCPrefixBeamSearch>(
            resource.decoder_opts.ctc_prefix_search_opts);
    } else {
        LOG(INFO) << "Init TLGDecoder";
        decoder_ = std::make_unique<TLGDecoder>(
            resource.decoder_opts.tlg_decoder_opts);
    }

    symbol_table_ = decoder_->WordSymbolTable();
    global_frame_offset_ = 0;
    input_finished_ = false;
    num_frames_ = 0;
    result_.clear();
}

RecognizerControllerImpl::~RecognizerControllerImpl() {
    WaitFinished();
}

void RecognizerControllerImpl::Reset() {
    nnet_producer_->Reset();
}

void RecognizerControllerImpl::RunDecoder(RecognizerControllerImpl* me) {
    me->RunDecoderInternal();
}

void RecognizerControllerImpl::RunDecoderInternal() {
    LOG(INFO) << "DecoderInternal begin";
    while (!nnet_producer_->IsFinished()) {
        nnet_condition_.notify_one();
        decoder_->AdvanceDecode(decodable_);
    }
    decoder_->AdvanceDecode(decodable_);
    UpdateResult(false);
    LOG(INFO) << "DecoderInternal exit";
}

void RecognizerControllerImpl::WaitDecoderFinished() {
    if (decoder_thread_.joinable()) decoder_thread_.join();
}

void RecognizerControllerImpl::RunNnetEvaluation(RecognizerControllerImpl* me) {
    me->RunNnetEvaluationInternal();
}

void RecognizerControllerImpl::SetInputFinished() {
    nnet_producer_->SetInputFinished();
    nnet_condition_.notify_one();
    LOG(INFO) << "Set Input Finished";
}

void RecognizerControllerImpl::WaitFinished() {
    abort_ = true;
    LOG(INFO) << "nnet wait finished";
    nnet_condition_.notify_one();
    if (nnet_thread_.joinable()) {
        nnet_thread_.join();
    }
}

void RecognizerControllerImpl::RunNnetEvaluationInternal() {
    bool result = false;
    LOG(INFO) << "NnetEvaluationInternal begin";
    while (!abort_) {
        std::unique_lock<std::mutex> lock(nnet_mutex_);
        nnet_condition_.wait(lock);
        do {
            result = nnet_producer_->Compute();
            decoder_condition_.notify_one();
        } while (result);
    }
    LOG(INFO) << "NnetEvaluationInternal exit";
}

void RecognizerControllerImpl::Accept(std::vector<float> data) {
    nnet_producer_->Accept(data);
    nnet_condition_.notify_one();
}

void RecognizerControllerImpl::InitDecoder() {
    global_frame_offset_ = 0;
    input_finished_ = false;
    num_frames_ = 0;
    result_.clear();

    decodable_->Reset();
    decoder_->Reset();
    decoder_thread_ = std::thread(RunDecoder, this);
}

void RecognizerControllerImpl::AttentionRescoring() {
    decoder_->FinalizeSearch();
    UpdateResult(false);

    // No need to do rescoring
    if (0.0 == opts_.decoder_opts.rescoring_weight) {
        LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
        return;
    }
    LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";

    // Inputs() returns N-best input ids, which is the basic unit for rescoring
    // In CtcPrefixBeamSearch, inputs are the same to outputs
    const auto& hypotheses = decoder_->Inputs();
    int num_hyps = hypotheses.size();
    if (num_hyps <= 0) {
        return;
    }

    std::vector<float> rescoring_score;
    decodable_->AttentionRescoring(
        hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);

    // combine ctc score and rescoring score
    for (size_t i = 0; i < num_hyps; i++) {
        VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
                << " ctc_score: " << result_[i].score
                << " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
                << " ctc_weight: " << opts_.decoder_opts.ctc_weight;
        result_[i].score =
            opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
            opts_.decoder_opts.ctc_weight * result_[i].score;

        VLOG(3) << "hyp: " << result_[0].sentence
                << " score: " << result_[0].score;
    }

    std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
    VLOG(3) << "result: " << result_[0].sentence
            << " score: " << result_[0].score;
}

std::string RecognizerControllerImpl::GetFinalResult() { return result_[0].sentence; }

std::string RecognizerControllerImpl::GetPartialResult() { return result_[0].sentence; }

void RecognizerControllerImpl::UpdateResult(bool finish) {
    const auto& hypotheses = decoder_->Outputs();
    const auto& inputs = decoder_->Inputs();
    const auto& likelihood = decoder_->Likelihood();
    const auto& times = decoder_->Times();
    result_.clear();

    CHECK_EQ(inputs.size(), likelihood.size());
    for (size_t i = 0; i < hypotheses.size(); i++) {
        const std::vector<int>& hypothesis = hypotheses[i];
...@@ -99,21 +193,16 @@ void U2Recognizer::UpdateResult(bool finish) {
        path.score = likelihood[i];
        for (size_t j = 0; j < hypothesis.size(); j++) {
            std::string word = symbol_table_->Find(hypothesis[j]);
            path.sentence += (" " + word);
        }
        path.sentence = DelBlank(path.sentence);

        // TimeStamp is only supported in final result
        // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
        // various FST operations when building the decoding graph. So here we
        // use time stamp of the input(e2e model unit), which is more accurate,
        // and it requires the symbol table of the e2e model used in training.
        if (symbol_table_ != nullptr && finish) {
            int offset = global_frame_offset_ * FrameShiftInMs();

            const std::vector<int>& input = inputs[i];
...@@ -121,7 +210,7 @@ void U2Recognizer::UpdateResult(bool finish) {
            CHECK_EQ(input.size(), time_stamp.size());
            for (size_t j = 0; j < input.size(); j++) {
                std::string word = symbol_table_->Find(input[j]);

                int start =
                    time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
...@@ -163,56 +252,4 @@ void U2Recognizer::UpdateResult(bool finish) {
    }
}

}  // namespace ppspeech
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/common.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "nnet/u2_nnet.h"
#include "nnet/nnet_producer.h"
#ifdef USE_ONNX
#include "nnet/u2_onnx_nnet.h"
#endif
#include "nnet/decodable.h"
#include "recognizer/recognizer_resource.h"
#include <memory>
namespace ppspeech {
class RecognizerControllerImpl {
public:
explicit RecognizerControllerImpl(const RecognizerResource& resource);
~RecognizerControllerImpl();
void Accept(std::vector<float> data);
void InitDecoder();
void SetInputFinished();
std::string GetFinalResult();
std::string GetPartialResult();
void Rescoring();
void Reset();
void WaitDecoderFinished();
void WaitFinished();
void AttentionRescoring();
bool DecodedSomething() const {
return !result_.empty() && !result_[0].sentence.empty();
}
int FrameShiftInMs() const {
        return 1;  // TODO: return the actual frame shift in ms
}
private:
static void RunNnetEvaluation(RecognizerControllerImpl* me);
void RunNnetEvaluationInternal();
static void RunDecoder(RecognizerControllerImpl* me);
void RunDecoderInternal();
void UpdateResult(bool finish = false);
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<DecoderBase> decoder_;
std::shared_ptr<NnetProducer> nnet_producer_;
// e2e unit symbol table
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
std::vector<DecodeResult> result_;
RecognizerResource opts_;
bool abort_ = false;
// global decoded frame offset
int global_frame_offset_;
// cur decoded frame num
int num_frames_;
// timestamp gap between words in a sentence
const int time_stamp_gap_ = 100;
bool input_finished_;
std::mutex nnet_mutex_;
std::mutex decoder_mutex_;
std::condition_variable nnet_condition_;
std::condition_variable decoder_condition_;
std::thread nnet_thread_;
std::thread decoder_thread_;
DISALLOW_COPY_AND_ASSIGN(RecognizerControllerImpl);
};
}  // namespace ppspeech
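RecognizerControllerImpl runs one nnet thread for the lifetime of the object and one decoder thread per utterance; Accept() and SetInputFinished() only move data and poke the condition variables. A sketch of the per-utterance call order under that model, with DecodeUtterance as an illustrative name:

#include <string>
#include <vector>
#include "recognizer/recognizer_controller_impl.h"

std::string DecodeUtterance(const ppspeech::RecognizerResource& resource,
                            const std::vector<float>& samples) {
    ppspeech::RecognizerControllerImpl recognizer(resource);  // spawns nnet thread
    recognizer.InitDecoder();         // spawns the decoder thread for this utt
    recognizer.Accept(samples);       // wakes the nnet thread to compute posteriors
    recognizer.SetInputFinished();    // lets both threads drain their queues
    recognizer.WaitDecoderFinished();
    recognizer.AttentionRescoring();  // optional second pass, weight-controlled
    return recognizer.GetFinalResult();
}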
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer_instance.h"
namespace ppspeech {
RecognizerInstance& RecognizerInstance::GetInstance() {
static RecognizerInstance instance;
return instance;
}
bool RecognizerInstance::Init(const std::string& model_file,
const std::string& word_symbol_table_file,
const std::string& fst_file,
int num_instance) {
RecognizerResource resource = RecognizerResource::InitFromFlags();
resource.model_opts.model_path = model_file;
//resource.vocab_path = word_symbol_table_file;
if (!fst_file.empty()) {
resource.decoder_opts.tlg_decoder_opts.fst_path = fst_file;
        resource.decoder_opts.tlg_decoder_opts.word_symbol_table =
            word_symbol_table_file;
} else {
resource.decoder_opts.ctc_prefix_search_opts.word_symbol_table =
word_symbol_table_file;
}
recognizer_controller_ = std::make_unique<RecognizerController>(num_instance, resource);
return true;
}
void RecognizerInstance::InitDecoder(int idx) {
recognizer_controller_->InitDecoder(idx);
return;
}
int RecognizerInstance::GetRecognizerInstanceId() {
return recognizer_controller_->GetRecognizerInstanceId();
}
void RecognizerInstance::Accept(const std::vector<float>& waves, int idx) const {
recognizer_controller_->Accept(waves, idx);
return;
}
void RecognizerInstance::SetInputFinished(int idx) const {
recognizer_controller_->SetInputFinished(idx);
return;
}
std::string RecognizerInstance::GetResult(int idx) const {
return recognizer_controller_->GetFinalResult(idx);
}
}  // namespace ppspeech
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...@@ -15,36 +15,28 @@
#pragma once

#include "base/common.h"
#include "recognizer/recognizer_controller.h"

namespace ppspeech {

class RecognizerInstance {
  public:
    static RecognizerInstance& GetInstance();
    RecognizerInstance() {}
    ~RecognizerInstance() {}
    bool Init(const std::string& model_file,
              const std::string& word_symbol_table_file,
              const std::string& fst_file,
              int num_instance);

    int GetRecognizerInstanceId();
    void InitDecoder(int idx);
    void Accept(const std::vector<float>& waves, int idx) const;
    void SetInputFinished(int idx) const;
    std::string GetResult(int idx) const;

  private:
    std::unique_ptr<RecognizerController> recognizer_controller_;
};

}  // namespace ppspeech
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
// limitations under the License. // limitations under the License.
#include "decoder/param.h" #include "decoder/param.h"
#include "kaldi/feat/wave-reader.h" #include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "recognizer/u2_recognizer.h" #include "recognizer/recognizer_controller.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier");
...@@ -31,6 +31,7 @@ int main(int argc, char* argv[]) { ...@@ -31,6 +31,7 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0; double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0; double tot_decode_time = 0.0;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader( kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
...@@ -44,11 +45,13 @@ int main(int argc, char* argv[]) { ...@@ -44,11 +45,13 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size; LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::U2RecognizerResource resource = ppspeech::RecognizerResource resource =
ppspeech::U2RecognizerResource::InitFromFlags(); ppspeech::RecognizerResource::InitFromFlags();
ppspeech::U2Recognizer recognizer(resource); std::shared_ptr<ppspeech::RecognizerControllerImpl> recognizer_ptr(
new ppspeech::RecognizerControllerImpl(resource));
for (; !wav_reader.Done(); wav_reader.Next()) { for (; !wav_reader.Done(); wav_reader.Next()) {
recognizer_ptr->InitDecoder();
std::string utt = wav_reader.Key(); std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value(); const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt; LOG(INFO) << "utt: " << utt;
...@@ -63,45 +66,32 @@ int main(int argc, char* argv[]) { ...@@ -63,45 +66,32 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "wav len (sample): " << tot_samples; LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0; int sample_offset = 0;
int cnt = 0;
kaldi::Timer timer;
kaldi::Timer local_timer; kaldi::Timer local_timer;
while (sample_offset < tot_samples) { while (sample_offset < tot_samples) {
int cur_chunk_size = int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset); std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size); std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) { for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i); wav_chunk[i] = waveform(sample_offset + i);
} }
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer.Accept(wav_chunk); recognizer_ptr->Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished();
}
recognizer.Decode();
if (recognizer.DecodedSomething()) {
LOG(INFO) << "Pratial result: " << cnt << " "
<< recognizer.GetPartialResult();
}
// no overlap // no overlap
sample_offset += cur_chunk_size; sample_offset += cur_chunk_size;
cnt++;
} }
CHECK(sample_offset == tot_samples); CHECK(sample_offset == tot_samples);
recognizer_ptr->SetInputFinished();
recognizer_ptr->WaitDecoderFinished();
// second pass decoding kaldi::Timer timer;
recognizer.Rescoring(); recognizer_ptr->AttentionRescoring();
float rescore_time = timer.Elapsed();
tot_decode_time += timer.Elapsed(); tot_attention_rescore_time += rescore_time;
std::string result = recognizer.GetFinalResult();
recognizer.Reset();
std::string result = recognizer_ptr->GetFinalResult();
if (result.empty()) { if (result.empty()) {
// the TokenWriter can not write empty string. // the TokenWriter can not write empty string.
++num_err; ++num_err;
...@@ -109,17 +99,20 @@ int main(int argc, char* argv[]) { ...@@ -109,17 +99,20 @@ int main(int argc, char* argv[]) {
continue; continue;
} }
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result; LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed(); << " cost: " << local_timer.Elapsed() << " rescore:" << rescore_time;
result_writer.Write(utt, result); result_writer.Write(utt, result);
++num_done; ++num_done;
} }
recognizer_ptr->WaitFinished();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
} }
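To make the final RTF log concrete: decoding 360 s of audio with a total decode cost of 45 s yields RTF = 45 / 360 = 0.125, i.e., the pipeline runs eight times faster than real time.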
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "decoder/common.h"
#include "decoder/ctc_beam_search_opt.h" #include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_decoder.h" #include "decoder/ctc_tlg_decoder.h"
#include "decoder/decoder_itf.h" #include "frontend/feature_pipeline.h"
#include "frontend/audio/feature_pipeline.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "nnet/decodable.h"
DECLARE_int32(nnet_decoder_chunk); DECLARE_int32(nnet_decoder_chunk);
DECLARE_int32(num_left_chunks); DECLARE_int32(num_left_chunks);
...@@ -30,9 +11,9 @@ DECLARE_double(rescoring_weight); ...@@ -30,9 +11,9 @@ DECLARE_double(rescoring_weight);
DECLARE_double(reverse_weight); DECLARE_double(reverse_weight);
DECLARE_int32(nbest); DECLARE_int32(nbest);
DECLARE_int32(blank); DECLARE_int32(blank);
DECLARE_double(acoustic_scale); DECLARE_double(acoustic_scale);
DECLARE_string(vocab_path); DECLARE_double(blank_threshold);
DECLARE_string(word_symbol_table);
namespace ppspeech { namespace ppspeech {
...@@ -59,6 +40,7 @@ struct DecodeOptions { ...@@ -59,6 +40,7 @@ struct DecodeOptions {
// CtcEndpointConfig ctc_endpoint_opts; // CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions ctc_prefix_search_opts{}; CTCBeamSearchOptions ctc_prefix_search_opts{};
TLGDecoderOptions tlg_decoder_opts{};
static DecodeOptions InitFromFlags() { static DecodeOptions InitFromFlags() {
DecodeOptions decoder_opts; DecodeOptions decoder_opts;
...@@ -70,6 +52,11 @@ struct DecodeOptions { ...@@ -70,6 +52,11 @@ struct DecodeOptions {
decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank; decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank;
decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest; decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest; decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
decoder_opts.ctc_prefix_search_opts.word_symbol_table =
FLAGS_word_symbol_table;
decoder_opts.tlg_decoder_opts =
ppspeech::TLGDecoderOptions::InitFromFlags();
LOG(INFO) << "chunk_size: " << decoder_opts.chunk_size; LOG(INFO) << "chunk_size: " << decoder_opts.chunk_size;
LOG(INFO) << "num_left_chunks: " << decoder_opts.num_left_chunks; LOG(INFO) << "num_left_chunks: " << decoder_opts.num_left_chunks;
LOG(INFO) << "ctc_weight: " << decoder_opts.ctc_weight; LOG(INFO) << "ctc_weight: " << decoder_opts.ctc_weight;
...@@ -82,19 +69,20 @@ struct DecodeOptions { ...@@ -82,19 +69,20 @@ struct DecodeOptions {
} }
}; };
struct U2RecognizerResource { struct RecognizerResource {
// decodable opt
kaldi::BaseFloat acoustic_scale{1.0}; kaldi::BaseFloat acoustic_scale{1.0};
std::string vocab_path{}; kaldi::BaseFloat blank_threshold{0.98};
FeaturePipelineOptions feature_pipeline_opts{}; FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{}; ModelOptions model_opts{};
DecodeOptions decoder_opts{}; DecodeOptions decoder_opts{};
std::shared_ptr<NnetBase> nnet;
static U2RecognizerResource InitFromFlags() { static RecognizerResource InitFromFlags() {
U2RecognizerResource resource; RecognizerResource resource;
resource.vocab_path = FLAGS_vocab_path;
resource.acoustic_scale = FLAGS_acoustic_scale; resource.acoustic_scale = FLAGS_acoustic_scale;
LOG(INFO) << "vocab path: " << resource.vocab_path; resource.blank_threshold = FLAGS_blank_threshold;
LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale; LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale;
resource.feature_pipeline_opts = resource.feature_pipeline_opts =
...@@ -104,69 +92,17 @@ struct U2RecognizerResource { ...@@ -104,69 +92,17 @@ struct U2RecognizerResource {
<< resource.feature_pipeline_opts.assembler_opts.fill_zero; << resource.feature_pipeline_opts.assembler_opts.fill_zero;
resource.model_opts = ppspeech::ModelOptions::InitFromFlags(); resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.decoder_opts = ppspeech::DecodeOptions::InitFromFlags(); resource.decoder_opts = ppspeech::DecodeOptions::InitFromFlags();
#ifndef USE_ONNX
resource.nnet.reset(new U2Nnet(resource.model_opts));
#else
if (resource.model_opts.with_onnx_model){
resource.nnet.reset(new U2OnnxNnet(resource.model_opts));
} else {
resource.nnet.reset(new U2Nnet(resource.model_opts));
}
#endif
return resource; return resource;
} }
}; };
} //namespace ppspeech
class U2Recognizer { \ No newline at end of file
public:
explicit U2Recognizer(const U2RecognizerResource& resouce);
void Reset();
void ResetContinuousDecoding();
void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
void Decode();
void Rescoring();
std::string GetFinalResult();
std::string GetPartialResult();
void SetFinished();
bool IsFinished() { return input_finished_; }
bool DecodedSomething() const {
return !result_.empty() && !result_[0].sentence.empty();
}
int FrameShiftInMs() const {
// one decoder frame length in ms
return decodable_->Nnet()->SubsamplingRate() *
feature_pipeline_->FrameShift();
}
const std::vector<DecodeResult>& Result() const { return result_; }
private:
void AttentionRescoring();
void UpdateResult(bool finish = false);
private:
U2RecognizerResource opts_;
// std::shared_ptr<U2RecognizerResource> resource_;
// U2RecognizerResource resource_;
std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<CTCPrefixBeamSearch> decoder_;
// e2e unit symbol table
std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
std::vector<DecodeResult> result_;
// global decoded frame offset
int global_frame_offset_;
// cur decoded frame num
int num_frames_;
// timestamp gap between words in a sentence
const int time_stamp_gap_ = 100;
bool input_finished_;
};
} // namespace ppspeech
\ No newline at end of file
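The DECLARE_* flags above are the knobs that InitFromFlags() reads; a binary built against this header would typically be launched with, e.g., --word_symbol_table=words.txt --blank_threshold=0.98 --acoustic_scale=1.0 --nbest=10, where the two float values match the struct defaults visible above and the path and nbest value are illustrative placeholders.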
...@@ -10,4 +10,4 @@ target_link_libraries(websocket_server_main PUBLIC fst websocket ${DEPS}) ...@@ -10,4 +10,4 @@ target_link_libraries(websocket_server_main PUBLIC fst websocket ${DEPS})
add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc) add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS}) target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS})
\ No newline at end of file
# add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND")
add_definitions("-DUSE_ORT_BACKEND")
add_subdirectory(nnet)
\ No newline at end of file
set(srcs
panns_nnet.cc
panns_interface.cc
)
add_library(cls SHARED ${srcs})
target_link_libraries(cls PRIVATE ${FASTDEPLOY_LIBS} kaldi-matrix kaldi-base frontend utils )
set(bin_name panns_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_link_libraries(${bin_name} gflags glog cls)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "audio_classification/nnet/panns_interface.h"
#include "audio_classification/nnet/panns_nnet.h"
#include "common/base/config.h"
namespace ppspeech {
void* ClsCreateInstance(const char* conf_path) {
Config conf(conf_path);
// cls init
ppspeech::ClsNnetConf cls_nnet_conf;
cls_nnet_conf.wav_normal_ = conf.Read("wav_normal", true);
cls_nnet_conf.wav_normal_type_ =
conf.Read("wav_normal_type", std::string("linear"));
cls_nnet_conf.wav_norm_mul_factor_ = conf.Read("wav_norm_mul_factor", 1.0);
cls_nnet_conf.model_file_path_ = conf.Read("model_path", std::string(""));
cls_nnet_conf.param_file_path_ = conf.Read("param_path", std::string(""));
cls_nnet_conf.dict_file_path_ = conf.Read("dict_path", std::string(""));
cls_nnet_conf.num_cpu_thread_ = conf.Read("num_cpu_thread", 12);
cls_nnet_conf.samp_freq = conf.Read("samp_freq", 32000);
cls_nnet_conf.frame_length_ms = conf.Read("frame_length_ms", 32);
cls_nnet_conf.frame_shift_ms = conf.Read("frame_shift_ms", 10);
cls_nnet_conf.num_bins = conf.Read("num_bins", 64);
cls_nnet_conf.low_freq = conf.Read("low_freq", 50);
cls_nnet_conf.high_freq = conf.Read("high_freq", 14000);
cls_nnet_conf.dither = conf.Read("dither", 0.0);
ppspeech::ClsNnet* cls_model = new ppspeech::ClsNnet();
int ret = cls_model->Init(cls_nnet_conf);
if (ret != 0) {  // Init() returning nonzero signals failure; don't leak
    delete cls_model;
    return NULL;
}
return static_cast<void*>(cls_model);
}
int ClsDestroyInstance(void* instance) {
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
if (cls_model != NULL) {
delete cls_model;
cls_model = NULL;
}
return 0;
}
int ClsFeedForward(void* instance,
const char* wav_path,
int topk,
char* result,
int result_max_len) {
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
if (cls_model == NULL) {
printf("instance is null\n");
return -1;
}
int ret = cls_model->Forward(wav_path, topk, result, result_max_len);
return ret;  // propagate Forward's status
}
int ClsReset(void* instance) {
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
if (cls_model == NULL) {
printf("instance is null\n");
return -1;
}
cls_model->Reset();
return 0;
}
} // namespace ppspeech
\ No newline at end of file
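ClsCreateInstance pulls the keys above out of Config; assuming a plain one `key value` pair per line layout (the actual Config syntax lives in common/base/config.h and is not shown in this diff), a sketch of a matching conf file using the defaults from the code, with placeholder paths:

wav_normal true
wav_normal_type linear
wav_norm_mul_factor 1.0
model_path ./panns.pdmodel
param_path ./panns.pdiparams
dict_path ./label_dict.txt
num_cpu_thread 12
samp_freq 32000
frame_length_ms 32
frame_shift_ms 10
num_bins 64
low_freq 50
high_freq 14000
dither 0.0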
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace ppspeech {
void* ClsCreateInstance(const char* conf_path);
int ClsDestroyInstance(void* instance);
int ClsFeedForward(void* instance,
const char* wav_path,
int topk,
char* result,
int result_max_len);
int ClsReset(void* instance);
} // namespace ppspeech
\ No newline at end of file
...@@ -12,59 +12,63 @@ ...@@ -12,59 +12,63 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// todo refactor later (SGoat)
#pragma once #pragma once
#include "decoder/ctc_beam_search_decoder.h" #include "common/frontend/data_cache.h"
#include "decoder/ctc_tlg_decoder.h" #include "common/frontend/fbank.h"
#include "frontend/audio/feature_pipeline.h" #include "common/frontend/feature-fbank.h"
#include "nnet/decodable.h" #include "common/frontend/frontend_itf.h"
#include "nnet/ds2_nnet.h" #include "common/frontend/wave-reader.h"
#include "common/utils/audio_process.h"
DECLARE_double(acoustic_scale); #include "common/utils/file_utils.h"
#include "fastdeploy/runtime.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
namespace ppspeech { namespace ppspeech {
struct ClsNnetConf {
struct RecognizerResource { // wav
kaldi::BaseFloat acoustic_scale{1.0}; bool wav_normal_;
FeaturePipelineOptions feature_pipeline_opts{}; std::string wav_normal_type_;
ModelOptions model_opts{}; float wav_norm_mul_factor_;
TLGDecoderOptions tlg_opts{}; // model
// CTCBeamSearchOptions beam_search_opts; std::string model_file_path_;
std::string param_file_path_;
static RecognizerResource InitFromFlags() { std::string dict_file_path_;
RecognizerResource resource; int num_cpu_thread_;
resource.acoustic_scale = FLAGS_acoustic_scale; // fbank
resource.feature_pipeline_opts = float samp_freq;
FeaturePipelineOptions::InitFromFlags(); float frame_length_ms;
resource.feature_pipeline_opts.assembler_opts.fill_zero = true; float frame_shift_ms;
LOG(INFO) << "ds2 need fill zero be true: " int num_bins;
<< resource.feature_pipeline_opts.assembler_opts.fill_zero; float low_freq;
resource.model_opts = ModelOptions::InitFromFlags(); float high_freq;
resource.tlg_opts = TLGDecoderOptions::InitFromFlags(); float dither;
return resource;
}
}; };
class Recognizer { class ClsNnet {
public: public:
explicit Recognizer(const RecognizerResource& resouce); ClsNnet();
void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves); int Init(const ClsNnetConf& conf);
void Decode(); int Forward(const char* wav_path,
std::string GetFinalResult(); int topk,
std::string GetPartialResult(); char* result,
void SetFinished(); int result_max_len);
bool IsFinished();
void Reset(); void Reset();
private: private:
// std::shared_ptr<RecognizerResource> resource_; int ModelForward(float* features,
// RecognizerResource resource_; const int num_frames,
std::shared_ptr<FeaturePipeline> feature_pipeline_; const int feat_dim,
std::shared_ptr<Decodable> decodable_; std::vector<float>* model_out);
std::unique_ptr<TLGDecoder> decoder_; int ModelForwardStream(std::vector<float>* feats);
bool input_finished_; int GetTopkResult(int k, const std::vector<float>& model_out);
ClsNnetConf conf_;
knf::FbankOptions fbank_opts_;
std::unique_ptr<fastdeploy::Runtime> runtime_;
std::vector<std::string> dict_;
std::stringstream ss_;
}; };
} // namespace ppspeech } // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <string>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "audio_classification/nnet/panns_interface.h"
DEFINE_string(conf_path, "", "config path");
DEFINE_string(scp_path, "", "wav scp path");
DEFINE_string(topk, "", "print topk results");
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK_GT(FLAGS_conf_path.size(), 0);
CHECK_GT(FLAGS_scp_path.size(), 0);
CHECK_GT(FLAGS_topk.size(), 0);
void* instance = ppspeech::ClsCreateInstance(FLAGS_conf_path.c_str());
int ret = 0;
// read wav
std::ifstream ifs(FLAGS_scp_path);
std::string line = "";
int topk = std::atoi(FLAGS_topk.c_str());
while (getline(ifs, line)) {
// read wav
char result[1024] = {0};
ret = ppspeech::ClsFeedForward(
instance, line.c_str(), topk, result, 1024);
printf("%s %s\n", line.c_str(), result);
ret = ppspeech::ClsReset(instance);
}
ret = ppspeech::ClsDestroyInstance(instance);
return 0;
}
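A hedged invocation sketch for the classification demo above (paths are placeholders; wav.scp holds one wav path per line, matching the getline loop): ./panns_nnet_main --conf_path=cls.conf --scp_path=wav.scp --topk=3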
cmake_minimum_required(VERSION 3.14 FATAL_ERROR) cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(glog) if(ANDROID)
add_subdirectory(nnet) else() #Unix
add_subdirectory(glog)
endif()
\ No newline at end of file
cmake_minimum_required(VERSION 3.14 FATAL_ERROR) cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc) add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc)
target_link_libraries(glog_main glog) target_link_libraries(glog_main extern_glog)
add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc) add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc)
target_link_libraries(glog_logtostderr_main glog) target_link_libraries(glog_logtostderr_main extern_glog)
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include <cstring> #include <cstring>
#include <deque> #include <deque>
#include <fstream> #include <fstream>
#include <functional>
#include <future>
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
#include <istream> #include <istream>
...@@ -48,4 +50,5 @@ ...@@ -48,4 +50,5 @@
#include "base/log.h" #include "base/log.h"
#include "base/macros.h" #include "base/macros.h"
#include "utils/file_utils.h" #include "utils/file_utils.h"
#include "utils/math.h" #include "utils/math.h"
\ No newline at end of file #include "utils/timer.h"
\ No newline at end of file