未验证 提交 0d766f4c 编写于 作者: J Jackwaterveg 提交者: GitHub

Merge pull request #1496 from PaddlePaddle/speechx

[speechx] high performance inference for ds2
......@@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker
name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:
......
......@@ -35,3 +35,7 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks
* [librosa](https://github.com/librosa/librosa/blob/main/LICENSE.md)
- ISC License
- Audio feature
* [ThreadPool](https://github.com/progschj/ThreadPool/blob/master/COPYING)
- zlib License
- ThreadPool
......@@ -2,18 +2,32 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(paddlespeech VERSION 0.1)
set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0048.cmake")
set(CMAKE_VERBOSE_MAKEFILE on)
# set std-14
set(CMAKE_CXX_STANDARD 14)
# include file
# cmake dir
set(speechx_cmake_dir ${PROJECT_SOURCE_DIR}/cmake)
# Modules
list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}/external)
list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir})
include(FetchContent)
include(ExternalProject)
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
# compiler option
# Keep the same with openfst, -fPIC or -fpic
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")
###############################################################################
# Option Configurations
......@@ -25,91 +39,92 @@ option(TEST_DEBUG "option for debug" OFF)
###############################################################################
# Include third party
###############################################################################
# #example for include third party
# FetchContent_Declare()
# # FetchContent_MakeAvailable was not added until CMake 3.14
# example for include third party
# FetchContent_MakeAvailable was not added until CMake 3.14
# FetchContent_MakeAvailable()
# include_directories()
# gflags
include(gflags)
# glog
include(glog)
# gtest
include(gtest)
# ABSEIL-CPP
include(FetchContent)
FetchContent_Declare(
absl
GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
GIT_TAG "20210324.1"
)
FetchContent_MakeAvailable(absl)
include(absl)
# libsndfile
include(FetchContent)
FetchContent_Declare(
libsndfile
GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
GIT_TAG "1.0.31"
)
FetchContent_MakeAvailable(libsndfile)
include(libsndfile)
# gflags
FetchContent_Declare(
gflags
URL https://github.com/gflags/gflags/archive/v2.2.1.zip
URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
)
FetchContent_MakeAvailable(gflags)
include_directories(${gflags_BINARY_DIR}/include)
# boost
# include(boost) # not work
set(boost_SOURCE_DIR ${fc_patch}/boost-src)
set(BOOST_ROOT ${boost_SOURCE_DIR})
# #find_package(boost REQUIRED PATHS ${BOOST_ROOT})
# glog
FetchContent_Declare(
glog
URL https://github.com/google/glog/archive/v0.4.0.zip
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
)
FetchContent_MakeAvailable(glog)
include_directories(${glog_BINARY_DIR})
# Eigen
include(eigen)
find_package(Eigen3 REQUIRED)
# gtest
FetchContent_Declare(googletest
URL https://github.com/google/googletest/archive/release-1.10.0.zip
URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
)
FetchContent_MakeAvailable(googletest)
# Kenlm
include(kenlm)
add_dependencies(kenlm eigen boost)
#openblas
include(openblas)
# openfst
set(openfst_SOURCE_DIR ${fc_patch}/openfst-src)
set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
set(openfst_PREFIX_DIR ${fc_patch}/openfst-subbuild/openfst-populate-prefix)
ExternalProject_Add(openfst
URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip
URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
SOURCE_DIR ${openfst_SOURCE_DIR}
BINARY_DIR ${openfst_BINARY_DIR}
CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
"CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
"LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
"LIBS=-lgflags_nothreads -lglog -lpthread"
BUILD_COMMAND make -j 4
)
include(openfst)
add_dependencies(openfst gflags glog)
link_directories(${openfst_PREFIX_DIR}/lib)
include_directories(${openfst_PREFIX_DIR}/include)
add_subdirectory(speechx)
#openblas
#set(OpenBLAS_INSTALL_PREFIX ${fc_patch}/OpenBLAS)
#set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src)
#ExternalProject_Add(
# OpenBLAS
# GIT_REPOSITORY https://github.com/xianyi/OpenBLAS
# GIT_TAG v0.3.13
# GIT_SHALLOW TRUE
# GIT_PROGRESS TRUE
# CONFIGURE_COMMAND ""
# BUILD_IN_SOURCE TRUE
# BUILD_COMMAND make USE_LOCKING=1 USE_THREAD=0
# INSTALL_COMMAND make PREFIX=${OpenBLAS_INSTALL_PREFIX} install
# UPDATE_DISCONNECTED TRUE
#)
# paddle lib
set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib)
set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
ExternalProject_Add(paddle
URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz
URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873
PREFIX ${paddle_PREFIX_DIR}
SOURCE_DIR ${paddle_SOURCE_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
)
set(PADDLE_LIB ${fc_patch}/paddle-lib)
include_directories("${PADDLE_LIB}/paddle/include")
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
link_directories("${PADDLE_LIB}/paddle/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib")
##paddle with mkl
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
include_directories("${MATH_LIB_PATH}/include")
set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
include_directories("${MKLDNN_PATH}/include")
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf xxhash cryptopp
${EXTERNAL_LIB})
###############################################################################
# Add local library
......@@ -121,4 +136,9 @@ add_subdirectory(speechx)
# if dir do not have CmakeLists.txt
#add_library(lib_name STATIC file.cc)
#target_link_libraries(lib_name item0 item1)
#add_dependencies(lib_name depend-target)
\ No newline at end of file
#add_dependencies(lib_name depend-target)
set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
add_subdirectory(speechx)
add_subdirectory(examples)
\ No newline at end of file
# SpeechX -- Speech Inference All in One
> Test under `Ubuntu 16.04.7 LTS`.
## Build
```
./build.sh
```l
## Valgrind
> If using docker please check `--privileged` is set when `docker run`.
1. Fatal error at startup: a function redirection which is mandatory for this platform-tool combination cannot be set up
```
apt-get install libc6-dbg
```
```
pushd tools
./setup_valgrind.sh
popd
```
# TODO
* DecibelNormalizer: there is a little bit difference between offline and online db norm. The computation of online db norm read feature chunk by chunk, which causes the feature size is different with offline db norm. In normalizer.cc:73, the samples.size() is different, which causes the difference of result.
#!/usr/bin/env bash
# the build script had verified in the paddlepaddle docker image.
# please follow the instruction below to install PaddlePaddle image.
# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
boost_SOURCE_DIR=$PWD/fc_patch/boost-src
if [ ! -d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
tar xzfv boost_1_75_0.tar.gz
mkdir -p $PWD/fc_patch
mv boost_1_75_0 ${boost_SOURCE_DIR}
cd ${boost_SOURCE_DIR}
bash ./bootstrap.sh
./b2
cd -
echo -e "\n"
fi
#rm -rf build
mkdir -p build
cd build
cmake .. -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
#cmake ..
make -j1
cd -
cmake_policy(SET CMP0048 NEW)
\ No newline at end of file
include(FetchContent)
set(BUILD_SHARED_LIBS OFF) # up to you
set(BUILD_TESTING OFF) # to disable abseil test, or gtest will fail.
set(ABSL_ENABLE_INSTALL ON) # now you can enable install rules even in subproject...
FetchContent_Declare(
absl
GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
GIT_TAG "20210324.1"
)
FetchContent_MakeAvailable(absl)
set(EIGEN3_INCLUDE_DIR ${Eigen3_SOURCE_DIR})
include_directories(${absl_SOURCE_DIR})
\ No newline at end of file
include(FetchContent)
set(Boost_DEBUG ON)
set(Boost_PREFIX_DIR ${fc_patch}/boost)
set(Boost_SOURCE_DIR ${fc_patch}/boost-src)
FetchContent_Declare(
Boost
URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
PREFIX ${Boost_PREFIX_DIR}
SOURCE_DIR ${Boost_SOURCE_DIR}
)
execute_process(COMMAND bootstrap.sh WORKING_DIRECTORY ${Boost_SOURCE_DIR})
execute_process(COMMAND b2 WORKING_DIRECTORY ${Boost_SOURCE_DIR})
FetchContent_MakeAvailable(Boost)
message(STATUS "boost src dir: ${Boost_SOURCE_DIR}")
message(STATUS "boost inc dir: ${Boost_INCLUDE_DIR}")
message(STATUS "boost bin dir: ${Boost_BINARY_DIR}")
set(BOOST_ROOT ${Boost_SOURCE_DIR})
message(STATUS "boost root dir: ${BOOST_ROOT}")
include_directories(${Boost_SOURCE_DIR})
\ No newline at end of file
include(FetchContent)
# update eigen to the commit id f612df27 on 03/16/2021
set(EIGEN_PREFIX_DIR ${fc_patch}/eigen3)
FetchContent_Declare(
Eigen3
GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git
GIT_TAG master
PREFIX ${EIGEN_PREFIX_DIR}
GIT_SHALLOW TRUE
GIT_PROGRESS TRUE)
set(EIGEN_BUILD_DOC OFF)
# note: To disable eigen tests,
# you should put this code in a add_subdirectory to avoid to change
# BUILD_TESTING for your own project too since variables are directory
# scoped
set(BUILD_TESTING OFF)
set(EIGEN_BUILD_PKGCONFIG OFF)
set( OFF)
FetchContent_MakeAvailable(Eigen3)
message(STATUS "eigen src dir: ${Eigen3_SOURCE_DIR}")
message(STATUS "eigen bin dir: ${Eigen3_BINARY_DIR}")
#include_directories(${Eigen3_SOURCE_DIR})
#link_directories(${Eigen3_BINARY_DIR})
\ No newline at end of file
include(FetchContent)
FetchContent_Declare(
gflags
URL https://github.com/gflags/gflags/archive/v2.2.1.zip
URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
)
FetchContent_MakeAvailable(gflags)
# openfst need
include_directories(${gflags_BINARY_DIR}/include)
\ No newline at end of file
include(FetchContent)
FetchContent_Declare(
glog
URL https://github.com/google/glog/archive/v0.4.0.zip
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
)
FetchContent_MakeAvailable(glog)
include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
include(FetchContent)
FetchContent_Declare(
gtest
URL https://github.com/google/googletest/archive/release-1.10.0.zip
URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
)
FetchContent_MakeAvailable(gtest)
include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src)
\ No newline at end of file
include(FetchContent)
FetchContent_Declare(
kenlm
GIT_REPOSITORY "https://github.com/kpu/kenlm.git"
GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a"
)
# https://github.com/kpu/kenlm/blob/master/cmake/modules/FindEigen3.cmake
set(EIGEN3_INCLUDE_DIR ${Eigen3_SOURCE_DIR})
FetchContent_MakeAvailable(kenlm)
include_directories(${kenlm_SOURCE_DIR})
\ No newline at end of file
include(FetchContent)
# https://github.com/pongasoft/vst-sam-spl-64/blob/master/libsndfile.cmake
# https://github.com/popojan/goban/blob/master/CMakeLists.txt#L38
# https://github.com/ddiakopoulos/libnyquist/blob/master/CMakeLists.txt
if(LIBSNDFILE_ROOT_DIR)
# instructs FetchContent to not download or update but use the location instead
set(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE ${LIBSNDFILE_ROOT_DIR})
else()
set(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE "")
endif()
set(LIBSNDFILE_GIT_REPO "https://github.com/libsndfile/libsndfile.git" CACHE STRING "libsndfile git repository url" FORCE)
set(LIBSNDFILE_GIT_TAG 1.0.31 CACHE STRING "libsndfile git tag" FORCE)
FetchContent_Declare(libsndfile
GIT_REPOSITORY ${LIBSNDFILE_GIT_REPO}
GIT_TAG ${LIBSNDFILE_GIT_TAG}
GIT_CONFIG advice.detachedHead=false
# GIT_SHALLOW true
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
FetchContent_GetProperties(libsndfile)
if(NOT libsndfile_POPULATED)
if(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE)
message(STATUS "Using libsndfile from local ${FETCHCONTENT_SOURCE_DIR_LIBSNDFILE}")
else()
message(STATUS "Fetching libsndfile ${LIBSNDFILE_GIT_REPO}/tree/${LIBSNDFILE_GIT_TAG}")
endif()
FetchContent_Populate(libsndfile)
endif()
set(LIBSNDFILE_ROOT_DIR ${libsndfile_SOURCE_DIR})
set(LIBSNDFILE_INCLUDE_DIR "${libsndfile_BINARY_DIR}/src")
function(libsndfile_build)
option(BUILD_PROGRAMS "Build programs" OFF)
option(BUILD_EXAMPLES "Build examples" OFF)
option(BUILD_TESTING "Build examples" OFF)
option(ENABLE_CPACK "Enable CPack support" OFF)
option(ENABLE_PACKAGE_CONFIG "Generate and install package config file" OFF)
option(BUILD_REGTEST "Build regtest" OFF)
# finally we include libsndfile itself
add_subdirectory(${libsndfile_SOURCE_DIR} ${libsndfile_BINARY_DIR} EXCLUDE_FROM_ALL)
# copying .hh for c++ support
#file(COPY "${libsndfile_SOURCE_DIR}/src/sndfile.hh" DESTINATION ${LIBSNDFILE_INCLUDE_DIR})
endfunction()
libsndfile_build()
include_directories(${LIBSNDFILE_INCLUDE_DIR})
\ No newline at end of file
include(FetchContent)
set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src)
set(OpenBLAS_PREFIX ${fc_patch}/OpenBLAS-prefix)
# ######################################################################################################################
# OPENBLAS https://github.com/lattice/quda/blob/develop/CMakeLists.txt#L575
# ######################################################################################################################
enable_language(Fortran)
#TODO: switch to CPM
include(GNUInstallDirs)
ExternalProject_Add(
OPENBLAS
GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git
GIT_TAG v0.3.10
GIT_SHALLOW YES
PREFIX ${OpenBLAS_PREFIX}
SOURCE_DIR ${OpenBLAS_SOURCE_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
CMAKE_GENERATOR "Unix Makefiles")
# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition
ExternalProject_Get_Property(OPENBLAS INSTALL_DIR)
set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR})
add_library(openblas STATIC IMPORTED)
add_dependencies(openblas OPENBLAS)
set_target_properties(openblas PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES Fortran)
# ${CMAKE_INSTALL_LIBDIR} lib
set_target_properties(openblas PROPERTIES IMPORTED_LOCATION ${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libopenblas.a)
# https://cmake.org/cmake/help/latest/command/install.html?highlight=cmake_install_libdir#installing-targets
# ${CMAKE_INSTALL_LIBDIR} lib
# ${CMAKE_INSTALL_INCLUDEDIR} include
link_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR})
\ No newline at end of file
include(FetchContent)
set(openfst_SOURCE_DIR ${fc_patch}/openfst-src)
set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
ExternalProject_Add(openfst
URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip
URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
# #PREFIX ${openfst_PREFIX_DIR}
# SOURCE_DIR ${openfst_SOURCE_DIR}
# BINARY_DIR ${openfst_BINARY_DIR}
CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
"CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
"LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
"LIBS=-lgflags_nothreads -lglog -lpthread"
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
BUILD_COMMAND make -j 4
)
link_directories(${openfst_PREFIX_DIR}/lib)
include_directories(${openfst_PREFIX_DIR}/include)
*.ark
paddle_asr_model/
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
# Examples
* decoder - offline decoder
* feat - mfcc, linear
* nnet - ds2 nn
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/raw_audio.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "test feature rspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::RawDataCache> raw_data(
new ppspeech::RawDataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
int32 row_idx = 0;
int32 num_chunks = feature.NumRows() / chunk_size;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, row_idx);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
row_idx++;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
# 3. run feat
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn
# 4. run decoder
offline_decoder_main \
--feature_respecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams \
--dict_file=$model_dir/vocab.txt \
--lm_path=$model_dir/avg_1.jit.klm
\ No newline at end of file
#!/bin/bash
# this script is for memory check, so please run ./run.sh first.
set +x
set -e
. ./path.sh
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
echo "please install valgrind in the speechx tools dir.\n"
exit 1
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
offline_decoder_main \
--feature_respecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams \
--dict_file=$model_dir/vocab.txt \
--lm_path=$model_dir/avg_1.jit.klm
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
\ No newline at end of file
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "frontend/linear_spectrogram.h"
#include "base/flags.h"
#include "base/log.h"
#include "frontend/feature_cache.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/normalizer.h"
#include "frontend/raw_audio.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
std::vector<float> mean_{
-13730251.531853663, -12982852.199316509, -13673844.299583456,
-13089406.559646806, -12673095.524938712, -12823859.223276224,
-13590267.158903603, -14257618.467152044, -14374605.116185192,
-14490009.21822485, -14849827.158924166, -15354435.470563512,
-15834149.206532761, -16172971.985514281, -16348740.496746974,
-16423536.699409386, -16556246.263649225, -16744088.772748645,
-16916184.08510357, -17054034.840031497, -17165612.509455364,
-17255955.470915023, -17322572.527648456, -17408943.862033736,
-17521554.799865916, -17620623.254924215, -17699792.395918526,
-17723364.411134344, -17741483.4433254, -17747426.888704527,
-17733315.928209435, -17748780.160905756, -17808336.883775543,
-17895918.671983004, -18009812.59173023, -18098188.66548325,
-18195798.958462656, -18293617.62980999, -18397432.92077201,
-18505834.787318766, -18585451.8100908, -18652438.235649142,
-18700960.306275308, -18734944.58792185, -18737426.313365128,
-18735347.165987637, -18738813.444170244, -18737086.848890636,
-18731576.2474336, -18717405.44095871, -18703089.25545657,
-18691014.546456724, -18692460.568905357, -18702119.628629155,
-18727710.621126678, -18761582.72034647, -18806745.835547544,
-18850674.8692112, -18884431.510951452, -18919999.992506847,
-18939303.799078144, -18952946.273760635, -18980289.22996379,
-19011610.17803294, -19040948.61805145, -19061021.429847397,
-19112055.53768819, -19149667.414264943, -19201127.05091321,
-19270250.82564605, -19334606.883057203, -19390513.336589377,
-19444176.259208687, -19502755.000038862, -19544333.014549147,
-19612668.183176614, -19681902.19006569, -19771969.951249883,
-19873329.723376893, -19996752.59235844, -20110031.131400537,
-20231658.612529557, -20319378.894054495, -20378534.45718066,
-20413332.089584175, -20438147.844177883, -20443710.248040095,
-20465457.02238927, -20488610.969337028, -20516295.16424432,
-20541423.795738827, -20553192.874953747, -20573605.50701977,
-20577871.61936797, -20571807.008916274, -20556242.38912231,
-20542199.30819195, -20521239.063551214, -20519150.80004532,
-20527204.80248933, -20536933.769257784, -20543470.522332076,
-20549700.089992985, -20551525.24958494, -20554873.406493705,
-20564277.65794227, -20572211.740052115, -20574305.69550465,
-20575494.450104576, -20567092.577932164, -20549302.929608088,
-20545445.11878376, -20546625.326603737, -20549190.03499401,
-20554824.947828256, -20568341.378989458, -20577582.331383612,
-20577980.519402675, -20566603.03458152, -20560131.592262644,
-20552166.469060015, -20549063.06763577, -20544490.562339947,
-20539817.82346569, -20528747.715731595, -20518026.24576161,
-20510977.844974525, -20506874.36087992, -20506731.11977665,
-20510482.133420516, -20507760.92101862, -20494644.834457114,
-20480107.89304893, -20461312.091867123, -20442941.75080173,
-20426123.02834838, -20424607.675283, -20426810.369107097,
-20434024.50097819, -20437404.75544205, -20447688.63916367,
-20460893.335563846, -20482922.735127095, -20503610.119434915,
-20527062.76448319, -20557830.035128627, -20593274.72068722,
-20632528.452965066, -20673637.471334763, -20733106.97143075,
-20842921.0447562, -21054357.83621519, -21416569.534189366,
-21978460.272811692, -22753170.052172784, -23671344.10563395,
-24613499.293358143, -25406477.12230188, -25884377.82156489,
-26049040.62791664, -26996879.104431007};
std::vector<float> variance_{
213747175.10846674, 188395815.34302503, 212706429.10966414,
199109025.81461075, 189235901.23864496, 194901336.53253657,
217481594.29306737, 238689869.12327808, 243977501.24115244,
248479623.6431067, 259766741.47116545, 275516766.7790273,
291271202.3691234, 302693239.8220509, 308627358.3997694,
311143911.38788426, 315446105.07731867, 321705430.9341829,
327458907.4659941, 332245072.43223983, 336251717.5935284,
339694069.7639722, 342188204.4322228, 345587110.31313115,
349903086.2875232, 353660214.20643026, 356700344.5270885,
357665362.3529641, 358493352.05658793, 358857951.620328,
358375239.52774596, 358899733.6342954, 361051818.3511561,
364361716.05025816, 368750322.3771452, 372047800.6462831,
375655861.1349018, 379358519.1980013, 383327605.3935181,
387458599.282341, 390434692.3406868, 392994486.35057056,
394874418.04603153, 396230525.79763395, 396365592.0414835,
396334819.8242737, 396488353.19250053, 396438877.00744957,
396197980.4459586, 395590921.6672991, 395001107.62072515,
394528291.7318225, 394593110.424006, 395018405.59353715,
396110577.5415993, 397506704.0371068, 399400197.4657644,
401243568.2468382, 402687134.7805103, 404136047.2872507,
404883170.001883, 405522253.219517, 406660365.3626476,
407919346.0991902, 409045348.5384909, 409759588.7889818,
411974821.8564483, 413489718.78201455, 415535392.56684107,
418466481.97674364, 421104678.35678065, 423405392.5200779,
425550570.40798235, 427929423.9579701, 429585274.253478,
432368493.55181056, 435193587.13513297, 438886855.20476013,
443058876.8633751, 448181232.5093362, 452883835.6332396,
458056721.77926534, 461816531.22735566, 464363620.1970998,
465886343.5057493, 466928872.0651, 467180536.42647296,
468111848.70714295, 469138695.3071312, 470378429.6930793,
471517958.7132626, 472109050.4262365, 473087417.0177867,
473381322.04648733, 473220195.85483915, 472666071.8998819,
472124669.87879956, 471298571.411737, 471251033.2902761,
471672676.43128747, 472177147.2193172, 472572361.7711908,
472968783.7751127, 473156295.4164052, 473398034.82676554,
473897703.5203811, 474328271.33112127, 474452670.98002136,
474549003.99284613, 474252887.13567275, 473557462.909069,
473483385.85193115, 473609738.04855174, 473746944.82085115,
474016729.91696435, 474617321.94138587, 475045097.237122,
475125402.586558, 474664112.9824912, 474426247.5800283,
474104075.42796475, 473978219.7273978, 473773171.7798875,
473578534.69508696, 473102924.16904145, 472651240.5232615,
472374383.1810912, 472209479.6956096, 472202298.8921673,
472370090.76781124, 472220933.99374026, 471625467.37106377,
470994646.51883453, 470182428.9637543, 469348211.5939578,
468570387.4467277, 468540442.7225135, 468672018.90414184,
468994346.9533251, 469138757.58201426, 469553915.95710236,
470134523.38582784, 471082421.62055486, 471962316.51804745,
472939745.1708408, 474250621.5944825, 475773933.43199486,
477465399.71087736, 479218782.61382693, 481752299.7930922,
486608947.8984568, 496119403.2067917, 512730085.5704984,
539048915.2641417, 576285298.3548826, 621610270.2240586,
669308196.4436442, 710656993.5957186, 736344437.3725077,
745481288.0241544, 801121432.9925804};
int count_ = 912592;
void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
WriteMatrix();
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
//std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
//ppspeech::RawDataCache());
std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
new ppspeech::RawAudioCache());
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opt,
std::move(base_feature_extractor)));
std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_write_path,
std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Vector<BaseFloat> features;
feature_cache.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
feature_cache.SetFinished();
}
feature_cache.Read(&features);
if (features.Dim() == 0) break;
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.Dim() / feature_cache.Dim();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feature_cache.Dim());
for (auto feat : feats) {
int num_rows = feat.Dim() / feature_cache.Dim();
for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
for (size_t col_idx = 0; col_idx < feature_cache.Dim();
++col_idx) {
features(cur_idx, col_idx) =
feat(row_idx * feature_cache.Dim() + col_idx);
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
#!/bin/bash
set +x
set -e
. ./path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
# 3. run feat
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn
#!/bin/bash
# this script is for memory check, so please run ./run.sh first.
set +x
set -e
. ./path.sh
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
echo "please install valgrind in the speechx tools dir.\n"
exit 1
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})
\ No newline at end of file
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <thread>
#include "paddle_inference_api.h"
using std::cout;
using std::endl;
DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35; // chunk_size in frame
int col_size = 161; // feat dim
cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl;
data->reserve(chunk_size);
data->back().reserve(col_size);
for (int row = 0; row < chunk_size; ++row) {
data->push_back(std::vector<float>());
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
data->back().push_back(0.201);
}
}
}
void model_forward_test() {
std::cout << "1. read the data" << std::endl;
std::vector<std::vector<float>> feats;
produce_data(&feats);
std::cout << "2. load the model" << std::endl;
;
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
cout << "model path: " << model_graph << endl;
cout << "model param path : " << model_params << endl;
paddle_infer::Config config;
config.SetModel(model_graph, model_params);
config.SwitchIrOptim(false);
cout << "SwitchIrOptim: " << false << endl;
config.DisableFCPadding();
cout << "DisableFCPadding: " << endl;
auto predictor = paddle_infer::CreatePredictor(config);
std::cout << "3. feat shape, row=" << feats.size()
<< ",col=" << feats[0].size() << std::endl;
std::vector<float> pp_input_mat;
for (const auto& item : feats) {
pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
}
std::cout << "4. fead the data to model" << std::endl;
int row = feats.size();
int col = feats[0].size();
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
for (auto name : input_names) {
cout << "model input names: " << name << endl;
}
for (auto name : output_names) {
cout << "model output names: " << name << endl;
}
// input
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(pp_input_mat.data());
// input length
std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
// state_h
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size =
std::accumulate(chunk_state_h_box_shape.begin(),
chunk_state_h_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
// state_c
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size =
std::accumulate(chunk_state_c_box_shape.begin(),
chunk_state_c_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
// run
bool success = predictor->Run();
// state_h out
std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]);
std::vector<int> h_out_shape = h_out->shape();
int h_out_size = std::accumulate(
h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> h_out_data(h_out_size);
h_out->CopyToCpu(h_out_data.data());
// stage_c out
std::unique_ptr<paddle_infer::Tensor> c_out =
predictor->GetOutputHandle(output_names[3]);
std::vector<int> c_out_shape = c_out->shape();
int c_out_size = std::accumulate(
c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> c_out_data(c_out_size);
c_out->CopyToCpu(c_out_data.data());
// output tensor
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
std::vector<float> output_probs;
int output_size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
output_probs.resize(output_size);
output_tensor->CopyToCpu(output_probs.data());
row = output_shape[1];
col = output_shape[2];
// probs
std::vector<std::vector<float>> probs;
probs.reserve(row);
for (int i = 0; i < row; i++) {
probs.push_back(std::vector<float>());
probs.back().reserve(col);
for (int j = 0; j < col; j++) {
probs.back().push_back(output_probs[i * col + j]);
}
}
std::vector<std::vector<float>> log_feat = probs;
std::cout << "probs, row: " << log_feat.size()
<< " col: " << log_feat[0].size() << std::endl;
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
++col_idx) {
std::cout << log_feat[row_idx][col_idx] << " ";
}
std::cout << std::endl;
}
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
model_forward_test();
return 0;
}
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
# 4. run decoder
pp-model-test \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams
#!/bin/bash
# this script is for memory check, so please run ./run.sh first.
set +x
set -e
. ./path.sh
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
echo "please install valgrind in the speechx tools dir.\n"
exit 1
fi
model_dir=../paddle_asr_model
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
pp-model-test \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams
\ No newline at end of file
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Google-style flag handling declarations and inline definitions.
#ifndef FST_LIB_FLAGS_H_
#define FST_LIB_FLAGS_H_
#include <cstdlib>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <fst/types.h>
#include <fst/lock.h>
#include "gflags/gflags.h"
#include "glog/logging.h"
using std::string;
// FLAGS USAGE:
//
// Definition example:
//
// DEFINE_int32(length, 0, "length");
//
// This defines variable FLAGS_length, initialized to 0.
//
// Declaration example:
//
// DECLARE_int32(length);
//
// SET_FLAGS() can be used to set flags from the command line
// using, for example, '--length=2'.
//
// ShowUsage() can be used to print out command and flag usage.
// #define DECLARE_bool(name) extern bool FLAGS_ ## name
// #define DECLARE_string(name) extern string FLAGS_ ## name
// #define DECLARE_int32(name) extern int32 FLAGS_ ## name
// #define DECLARE_int64(name) extern int64 FLAGS_ ## name
// #define DECLARE_double(name) extern double FLAGS_ ## name
template <typename T>
struct FlagDescription {
FlagDescription(T *addr, const char *doc, const char *type,
const char *file, const T val)
: address(addr),
doc_string(doc),
type_name(type),
file_name(file),
default_value(val) {}
T *address;
const char *doc_string;
const char *type_name;
const char *file_name;
const T default_value;
};
template <typename T>
class FlagRegister {
public:
static FlagRegister<T> *GetRegister() {
static auto reg = new FlagRegister<T>;
return reg;
}
const FlagDescription<T> &GetFlagDescription(const string &name) const {
fst::MutexLock l(&flag_lock_);
auto it = flag_table_.find(name);
return it != flag_table_.end() ? it->second : 0;
}
void SetDescription(const string &name,
const FlagDescription<T> &desc) {
fst::MutexLock l(&flag_lock_);
flag_table_.insert(make_pair(name, desc));
}
bool SetFlag(const string &val, bool *address) const {
if (val == "true" || val == "1" || val.empty()) {
*address = true;
return true;
} else if (val == "false" || val == "0") {
*address = false;
return true;
}
else {
return false;
}
}
bool SetFlag(const string &val, string *address) const {
*address = val;
return true;
}
bool SetFlag(const string &val, int32 *address) const {
char *p = 0;
*address = strtol(val.c_str(), &p, 0);
return !val.empty() && *p == '\0';
}
bool SetFlag(const string &val, int64 *address) const {
char *p = 0;
*address = strtoll(val.c_str(), &p, 0);
return !val.empty() && *p == '\0';
}
bool SetFlag(const string &val, double *address) const {
char *p = 0;
*address = strtod(val.c_str(), &p);
return !val.empty() && *p == '\0';
}
bool SetFlag(const string &arg, const string &val) const {
for (typename std::map< string, FlagDescription<T> >::const_iterator it =
flag_table_.begin();
it != flag_table_.end();
++it) {
const string &name = it->first;
const FlagDescription<T> &desc = it->second;
if (arg == name)
return SetFlag(val, desc.address);
}
return false;
}
void GetUsage(std::set<std::pair<string, string>> *usage_set) const {
for (auto it = flag_table_.begin(); it != flag_table_.end(); ++it) {
const string &name = it->first;
const FlagDescription<T> &desc = it->second;
string usage = " --" + name;
usage += ": type = ";
usage += desc.type_name;
usage += ", default = ";
usage += GetDefault(desc.default_value) + "\n ";
usage += desc.doc_string;
usage_set->insert(make_pair(desc.file_name, usage));
}
}
private:
string GetDefault(bool default_value) const {
return default_value ? "true" : "false";
}
string GetDefault(const string &default_value) const {
return "\"" + default_value + "\"";
}
template <class V>
string GetDefault(const V &default_value) const {
std::ostringstream strm;
strm << default_value;
return strm.str();
}
mutable fst::Mutex flag_lock_; // Multithreading lock.
std::map<string, FlagDescription<T>> flag_table_;
};
template <typename T>
class FlagRegisterer {
public:
FlagRegisterer(const string &name, const FlagDescription<T> &desc) {
auto registr = FlagRegister<T>::GetRegister();
registr->SetDescription(name, desc);
}
private:
FlagRegisterer(const FlagRegisterer &) = delete;
FlagRegisterer &operator=(const FlagRegisterer &) = delete;
};
#define DEFINE_VAR(type, name, value, doc) \
type FLAGS_ ## name = value; \
static FlagRegisterer<type> \
name ## _flags_registerer(#name, FlagDescription<type>(&FLAGS_ ## name, \
doc, \
#type, \
__FILE__, \
value))
// #define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc)
// #define DEFINE_string(name, value, doc) \
// DEFINE_VAR(string, name, value, doc)
// #define DEFINE_int32(name, value, doc) DEFINE_VAR(int32, name, value, doc)
// #define DEFINE_int64(name, value, doc) DEFINE_VAR(int64, name, value, doc)
// #define DEFINE_double(name, value, doc) DEFINE_VAR(double, name, value, doc)
// Temporary directory.
DECLARE_string(tmpdir);
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags,
const char *src = "");
#define SET_FLAGS(usage, argc, argv, rmflags) \
gflags::ParseCommandLineFlags(argc, argv, true)
// SetFlags(usage, argc, argv, rmflags, __FILE__)
// Deprecated; for backward compatibility.
inline void InitFst(const char *usage, int *argc, char ***argv, bool rmflags) {
return SetFlags(usage, argc, argv, rmflags);
}
void ShowUsage(bool long_usage = true);
#endif // FST_LIB_FLAGS_H_
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Google-style logging declarations and inline definitions.
#ifndef FST_LIB_LOG_H_
#define FST_LIB_LOG_H_
#include <cassert>
#include <iostream>
#include <string>
#include <fst/types.h>
#include <fst/flags.h>
using std::string;
DECLARE_int32(v);
class LogMessage {
public:
LogMessage(const string &type) : fatal_(type == "FATAL") {
std::cerr << type << ": ";
}
~LogMessage() {
std::cerr << std::endl;
if(fatal_)
exit(1);
}
std::ostream &stream() { return std::cerr; }
private:
bool fatal_;
};
// #define LOG(type) LogMessage(#type).stream()
// #define VLOG(level) if ((level) <= FLAGS_v) LOG(INFO)
// Checks
inline void FstCheck(bool x, const char* expr,
const char *file, int line) {
if (!x) {
LOG(FATAL) << "Check failed: \"" << expr
<< "\" file: " << file
<< " line: " << line;
}
}
// #define CHECK(x) FstCheck(static_cast<bool>(x), #x, __FILE__, __LINE__)
// #define CHECK_EQ(x, y) CHECK((x) == (y))
// #define CHECK_LT(x, y) CHECK((x) < (y))
// #define CHECK_GT(x, y) CHECK((x) > (y))
// #define CHECK_LE(x, y) CHECK((x) <= (y))
// #define CHECK_GE(x, y) CHECK((x) >= (y))
// #define CHECK_NE(x, y) CHECK((x) != (y))
// Debug checks
// #define DCHECK(x) assert(x)
// #define DCHECK_EQ(x, y) DCHECK((x) == (y))
// #define DCHECK_LT(x, y) DCHECK((x) < (y))
// #define DCHECK_GT(x, y) DCHECK((x) > (y))
// #define DCHECK_LE(x, y) DCHECK((x) <= (y))
// #define DCHECK_GE(x, y) DCHECK((x) >= (y))
// #define DCHECK_NE(x, y) DCHECK((x) != (y))
// Ports
#define ATTRIBUTE_DEPRECATED __attribute__((deprecated))
#endif // FST_LIB_LOG_H_
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Google-style flag handling definitions.
#include <cstring>
#if _MSC_VER
#include <io.h>
#include <fcntl.h>
#endif
#include <fst/compat.h>
#include <fst/flags.h>
static const char *private_tmpdir = getenv("TMPDIR");
// DEFINE_int32(v, 0, "verbosity level");
// DEFINE_bool(help, false, "show usage information");
// DEFINE_bool(helpshort, false, "show brief usage information");
#ifndef _MSC_VER
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : "/tmp",
"temporary directory");
#else
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : getenv("TEMP"),
"temporary directory");
#endif // !_MSC_VER
using namespace std;
static string flag_usage;
static string prog_src;
// Sets prog_src to src.
static void SetProgSrc(const char *src) {
prog_src = src;
#if _MSC_VER
// This common code is invoked by all FST binaries, and only by them. Switch
// stdin and stdout into "binary" mode, so that 0x0A won't be translated into
// a 0x0D 0x0A byte pair in a pipe or a shell redirect. Other streams are
// already using ios::binary where binary files are read or written.
// Kudos to @daanzu for the suggested fix.
// https://github.com/kkm000/openfst/issues/20
// https://github.com/kkm000/openfst/pull/23
// https://github.com/kkm000/openfst/pull/32
_setmode(_fileno(stdin), O_BINARY);
_setmode(_fileno(stdout), O_BINARY);
#endif
// Remove "-main" in src filename. Flags are defined in fstx.cc but SetFlags()
// is called in fstx-main.cc, which results in a filename mismatch in
// ShowUsageRestrict() below.
static constexpr char kMainSuffix[] = "-main.cc";
const int prefix_length = prog_src.size() - strlen(kMainSuffix);
if (prefix_length > 0 && prog_src.substr(prefix_length) == kMainSuffix) {
prog_src.erase(prefix_length, strlen("-main"));
}
}
void SetFlags(const char *usage, int *argc, char ***argv,
bool remove_flags, const char *src) {
flag_usage = usage;
SetProgSrc(src);
int index = 1;
for (; index < *argc; ++index) {
string argval = (*argv)[index];
if (argval[0] != '-' || argval == "-") break;
while (argval[0] == '-') argval = argval.substr(1); // Removes initial '-'.
string arg = argval;
string val = "";
// Splits argval (arg=val) into arg and val.
auto pos = argval.find("=");
if (pos != string::npos) {
arg = argval.substr(0, pos);
val = argval.substr(pos + 1);
}
auto bool_register = FlagRegister<bool>::GetRegister();
if (bool_register->SetFlag(arg, val))
continue;
auto string_register = FlagRegister<string>::GetRegister();
if (string_register->SetFlag(arg, val))
continue;
auto int32_register = FlagRegister<int32>::GetRegister();
if (int32_register->SetFlag(arg, val))
continue;
auto int64_register = FlagRegister<int64>::GetRegister();
if (int64_register->SetFlag(arg, val))
continue;
auto double_register = FlagRegister<double>::GetRegister();
if (double_register->SetFlag(arg, val))
continue;
LOG(FATAL) << "SetFlags: Bad option: " << (*argv)[index];
}
if (remove_flags) {
for (auto i = 0; i < *argc - index; ++i) {
(*argv)[i + 1] = (*argv)[i + index];
}
*argc -= index - 1;
}
// if (FLAGS_help) {
// ShowUsage(true);
// exit(1);
// }
// if (FLAGS_helpshort) {
// ShowUsage(false);
// exit(1);
// }
}
// If flag is defined in file 'src' and 'in_src' true or is not
// defined in file 'src' and 'in_src' is false, then print usage.
static void
ShowUsageRestrict(const std::set<pair<string, string>> &usage_set,
const string &src, bool in_src, bool show_file) {
string old_file;
bool file_out = false;
bool usage_out = false;
for (const auto &pair : usage_set) {
const auto &file = pair.first;
const auto &usage = pair.second;
bool match = file == src;
if ((match && !in_src) || (!match && in_src)) continue;
if (file != old_file) {
if (show_file) {
if (file_out) cout << "\n";
cout << "Flags from: " << file << "\n";
file_out = true;
}
old_file = file;
}
cout << usage << "\n";
usage_out = true;
}
if (usage_out) cout << "\n";
}
void ShowUsage(bool long_usage) {
std::set<pair<string, string>> usage_set;
cout << flag_usage << "\n";
auto bool_register = FlagRegister<bool>::GetRegister();
bool_register->GetUsage(&usage_set);
auto string_register = FlagRegister<string>::GetRegister();
string_register->GetUsage(&usage_set);
auto int32_register = FlagRegister<int32>::GetRegister();
int32_register->GetUsage(&usage_set);
auto int64_register = FlagRegister<int64>::GetRegister();
int64_register->GetUsage(&usage_set);
auto double_register = FlagRegister<double>::GetRegister();
double_register->GetUsage(&usage_set);
if (!prog_src.empty()) {
cout << "PROGRAM FLAGS:\n\n";
ShowUsageRestrict(usage_set, prog_src, true, false);
}
if (!long_usage) return;
if (!prog_src.empty()) cout << "LIBRARY FLAGS:\n\n";
ShowUsageRestrict(usage_set, prog_src, false, true);
}
......@@ -2,13 +2,32 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(speechx LANGUAGES CXX)
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/openblas)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/kaldi
)
add_subdirectory(kaldi)
add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc)
target_link_libraries(mfcc-test kaldi-mfcc)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/utils
)
add_subdirectory(utils)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/frontend
)
add_subdirectory(frontend)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/nnet
)
add_subdirectory(nnet)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder
)
add_subdirectory(decoder)
\ No newline at end of file
......@@ -16,45 +16,45 @@
#include "kaldi/base/kaldi-types.h"
#include <limits.h>
#include <limits>
typedef float BaseFloat;
typedef double double64;
typedef float BaseFloat;
typedef double double64;
typedef signed char int8;
typedef short int16;
typedef int int32;
typedef signed char int8;
typedef short int16;
typedef int int32;
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef long int64;
typedef long int64;
#else
typedef long long int64;
typedef long long int64;
#endif
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef unsigned long uint64;
#else
typedef unsigned long long uint64;
#endif
typedef signed int char32;
const uint8 kuint8max = (( uint8) 0xFF);
const uint16 kuint16max = ((uint16) 0xFFFF);
const uint32 kuint32max = ((uint32) 0xFFFFFFFF);
const uint64 kuint64max = ((uint64) (0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = (( int8) 0x80);
const int8 kint8max = (( int8) 0x7F);
const int16 kint16min = (( int16) 0x8000);
const int16 kint16max = (( int16) 0x7FFF);
const int32 kint32min = (( int32) 0x80000000);
const int32 kint32max = (( int32) 0x7FFFFFFF);
const int64 kint64min = (( int64) (0x8000000000000000LL));
const int64 kint64max = (( int64) (0x7FFFFFFFFFFFFFFFLL));
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
typedef signed int char32;
const uint8 kuint8max = ((uint8)0xFF);
const uint16 kuint16max = ((uint16)0xFFFF);
const uint32 kuint32max = ((uint32)0xFFFFFFFF);
const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = ((int8)0x80);
const int8 kint8max = ((int8)0x7F);
const int16 kint16min = ((int16)0x8000);
const int16 kint16max = ((int16)0x7FFF);
const int32 kint32min = ((int32)0x80000000);
const int32 kint32max = ((int32)0x7FFFFFFF);
const int64 kint64min = ((int64)(0x8000000000000000LL));
const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <deque>
#include <fstream>
#include <iostream>
#include <istream>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "base/basic_types.h"
#include "base/flags.h"
#include "base/log.h"
#include "base/macros.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fst/flags.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fst/log.h"
......@@ -16,8 +16,10 @@
namespace ppspeech {
#ifndef DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete
TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete
#endif
} // namespace pp_speech
\ No newline at end of file
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
// This software is provided 'as-is', without any express or implied
// warranty. In no event will the authors be held liable for any damages
// arising from the use of this software.
// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it
// freely, subject to the following restrictions:
// 1. The origin of this software must not be misrepresented; you must not
// claim that you wrote the original software. If you use this software
// in a product, an acknowledgment in the product documentation would be
// appreciated but is not required.
// 2. Altered source versions must be plainly marked as such, and must not be
// misrepresented as being the original software.
// 3. This notice may not be removed or altered from any source
// distribution.
// this code is from https://github.com/progschj/ThreadPool
#ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>
class ThreadPool {
public:
ThreadPool(size_t);
template <class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
// need to keep track of threads so we can join them
std::vector<std::thread> workers;
// the task queue
std::queue<std::function<void()>> tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
for (size_t i = 0; i < threads; ++i)
workers.emplace_back([this] {
for (;;) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock, [this] {
return this->stop || !this->tasks.empty();
});
if (this->stop && this->tasks.empty()) return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
});
}
// add new work item to the pool
template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> {
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool() {
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for (std::thread& worker : workers) worker.join();
}
#endif
# codelab
This directory is here for testing some funcitons temporaril.
aux_source_directory(. DIR_LIB_SRCS)
add_library(decoder STATIC ${DIR_LIB_SRCS})
project(decoder)
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
add_library(decoder STATIC
ctc_beam_search_decoder.cc
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
)
target_link_libraries(decoder PUBLIC kenlm utils fst)
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/basic_types.h"
struct DecoderResult {
BaseFloat acoustic_score;
std::vector<int32> words_idx;
std::vector<pair<int32, int32>> time_stamp;
};
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h"
#include "decoder/ctc_decoders/decoder_utils.h"
#include "utils/file_utils.h"
namespace ppspeech {
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts),
init_ext_scorer_(nullptr),
blank_id_(-1),
space_id_(-1),
num_frame_decoded_(0),
root_(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed";
}
LOG(INFO) << "read the vocabulary success, dict size: "
<< vocabulary_.size();
LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
blank_id_ = 0;
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id_ = it - vocabulary_.begin();
// if no space in vocabulary
if ((size_t)space_id_ >= vocabulary_.size()) {
space_id_ = -2;
}
}
void CTCBeamSearch::Reset() {
//num_frame_decoded_ = 0;
//ResetPrefixes();
InitDecoder();
}
void CTCBeamSearch::InitDecoder() {
num_frame_decoded_ = 0;
//ResetPrefixes();
prefixes_.clear();
root_ = std::make_shared<PathTrie>();
root_->score = root_->log_prob_b_prev = 0.0;
prefixes_.push_back(root_.get());
if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root_->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root_->set_matcher(matcher);
}
}
void CTCBeamSearch::Decode(
std::shared_ptr<kaldi::DecodableInterface> decodable) {
return;
}
int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; }
// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (1) {
vector<vector<BaseFloat>> likelihood;
vector<BaseFloat> frame_prob;
bool flag =
decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);
}
}
void CTCBeamSearch::ResetPrefixes() {
for (size_t i = 0; i < prefixes_.size(); i++) {
if (prefixes_[i] != nullptr) {
delete prefixes_[i];
prefixes_[i] = nullptr;
}
}
prefixes_.clear();
}
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
vector<string>& nbest_words) {
kaldi::Timer timer;
timer.Reset();
AdvanceDecoding(probs);
LOG(INFO) << "ctc decoding elapsed time(s) "
<< static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0;
}
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
}
string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
return result[0].second;
}
string CTCBeamSearch::GetFinalBestPath() {
CalculateApproxScore();
LMRescore();
return GetBestPath();
}
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
size_t num_time_steps = probs.size();
size_t beam_size = opts_.beam_size;
double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(),
vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size();
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
probs_seq[i][j] = static_cast<double>(probs[i][j]);
}
}
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
std::sort(prefixes_.begin(),
prefixes_.begin() + num_prefixes_,
prefix_compare);
if (num_prefixes_ == 0) {
continue;
}
min_cutoff = prefixes_[num_prefixes_ - 1]->score +
std::log(prob[blank_id_]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes_ == beam_size);
}
vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
size_t log_prob_idx_len = log_prob_idx.size();
for (size_t index = 0; index < log_prob_idx_len; index++) {
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
}
prefixes_.clear();
// update log probs
root_->iterate_to_vec(prefixes_);
// only preserve top beam_size prefixes_
if (prefixes_.size() >= beam_size) {
std::nth_element(prefixes_.begin(),
prefixes_.begin() + beam_size,
prefixes_.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes_.size(); ++i) {
prefixes_[i]->remove();
}
} // if
num_frame_decoded_++;
} // for probs_seq
}
int32 CTCBeamSearch::SearchOneChar(
const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff) {
size_t beam_size = opts_.beam_size;
const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes_.size(), beam_size);
for (size_t i = 0; i < prefixes_len; ++i) {
auto prefix = prefixes_[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
if (c == blank_id_) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id_ || init_ext_scorer_->is_character_based())) {
PathTrie* prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
vector<string> ngram;
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
// lm score: p_{lm}(W)^{\alpha} + \beta
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
log_p += score;
log_p += init_ext_scorer_->beta;
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
return 0;
}
void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = opts_.beam_size;
size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
std::sort(
prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
double approx_ctc = prefixes_[i]->score;
if (init_ext_scorer_ != nullptr) {
vector<int> output;
prefixes_[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = init_ext_scorer_->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove language model weight:
approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
init_ext_scorer_->alpha;
}
prefixes_[i]->approx_ctc = approx_ctc;
}
}
void CTCBeamSearch::LMRescore() {
size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
auto prefix = prefixes_[i];
if (!prefix->is_empty() && prefix->character != space_id_) {
float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
score += init_ext_scorer_->beta;
prefix->score += score;
}
}
}
}
} // namespace ppspeech
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h"
#include "util/parse-options.h"
#pragma once
namespace ppspeech {
struct CTCBeamSearchOptions {
std::string dict_file;
std::string lm_path;
BaseFloat alpha;
BaseFloat beta;
BaseFloat cutoff_prob;
int beam_size;
int cutoff_top_n;
int num_proc_bsearch;
CTCBeamSearchOptions()
: dict_file("vocab.txt"),
lm_path("lm.klm"),
alpha(1.9f),
beta(5.0),
beam_size(300),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(0) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta");
opts->Register(
"beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
}
};
class CTCBeamSearch {
public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {}
void InitDecoder();
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id_;
int space_id_;
std::shared_ptr<PathTrie> root_;
std::vector<PathTrie*> prefixes_;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
};
} // namespace basr
\ No newline at end of file
../../../third_party/ctc_decoders
\ No newline at end of file
project(frontend)
add_library(frontend STATIC
normalizer.cc
linear_spectrogram.cc
raw_audio.cc
feature_cache.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// wrap the fbank feat of kaldi, todo (SmileGoat)
#include "kaldi/feat/feature-mfcc.h"
#incldue "kaldi/matrix/kaldi-vector.h"
namespace ppspeech {
class FbankExtractor : FeatureExtractorInterface {
public:
explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor);
virtual void AcceptWaveform(
const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0;
private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
kaldi::Vector<kaldi::BaseFloat>* feat) const;
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/feature_cache.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
FeatureCache::FeatureCache(
int max_size, unique_ptr<FeatureExtractorInterface> base_extractor) {
max_size_ = max_size;
base_extractor_ = std::move(base_extractor);
}
void FeatureCache::Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
base_extractor_->Accept(inputs);
// feed current data
bool result = false;
do {
result = Compute();
} while (result);
}
// pop feature chunk
bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.empty() && base_extractor_->IsFinished() == false) {
ready_read_condition_.wait(lock);
BaseFloat elapsed = timer.Elapsed() * 1000;
// todo replace 1.0 with timeout_
if (elapsed > 1.0) {
return false;
}
usleep(1000); // sleep 1 ms
}
if (cache_.empty()) return false;
feats->Resize(cache_.front().Dim());
feats->CopyFromVec(cache_.front());
cache_.pop();
ready_feed_condition_.notify_one();
return true;
}
// read all data from base_feature_extractor_ into cache_
bool FeatureCache::Compute() {
// compute and feed
Vector<BaseFloat> feature_chunk;
bool result = base_extractor_->Read(&feature_chunk);
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
ready_feed_condition_.wait(lock);
}
if (feature_chunk.Dim() != 0) {
cache_.push(feature_chunk);
}
ready_read_condition_.notify_one();
return result;
}
void Reset() {
// std::lock_guard<std::mutex> lock(mutex_);
return;
}
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
namespace ppspeech {
class FeatureCache : public FeatureExtractorInterface {
public:
explicit FeatureCache(
int32 max_size = kint16max,
std::unique_ptr<FeatureExtractorInterface> base_extractor = NULL);
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
// feats dim = num_frames * feature_dim
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// feature cache only cache feature which from base extractor
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual void SetFinished() {
base_extractor_->SetFinished();
// read the last chunk data
Compute();
}
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
while (!cache_.empty()) {
cache_.pop();
}
}
private:
bool Compute();
std::mutex mutex_;
size_t max_size_;
std::queue<kaldi::Vector<BaseFloat>> cache_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
std::condition_variable ready_feed_condition_;
std::condition_variable ready_read_condition_;
// DISALLOW_COPY_AND_ASSGIN(FeatureCache);
};
} // namespace ppspeech
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/basic_types.h"
#include "kaldi/matrix/kaldi-vector.h"
namespace ppspeech {
class FeatureExtractorInterface {
public:
// accept input data, accept feature or raw waves which decided
// by the base_extractor
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs) = 0;
// get the processed result
// the length of output = feature_row * feature_dim,
// the Matrix is squashed into Vector
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* outputs) = 0;
// the Dim is the feature dim
virtual size_t Dim() const = 0;
virtual void SetFinished() = 0;
virtual bool IsFinished() const = 0;
virtual void Reset() = 0;
};
} // namespace ppspeech
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/linear_spectrogram.h"
#include "kaldi/base/kaldi-math.h"
#include "kaldi/matrix/matrix-functions.h"
namespace ppspeech {
using kaldi::int32;
using kaldi::BaseFloat;
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector;
LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
opts_ = opts;
base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize();
int32 window_shift = opts.frame_opts.WindowShift();
fft_points_ = window_size;
chunk_sample_size_ =
static_cast<int32>(opts.streaming_chunk * opts.frame_opts.samp_freq);
hanning_window_.resize(window_size);
double a = M_2PI / (window_size - 1);
hanning_window_energy_ = 0;
for (int i = 0; i < window_size; ++i) {
hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
}
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
}
void LinearSpectrogram::Accept(const VectorBase<BaseFloat>& inputs) {
base_extractor_->Accept(inputs);
}
bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
Vector<BaseFloat> input_feats(chunk_sample_size_);
bool flag = base_extractor_->Read(&input_feats);
if (flag == false || input_feats.Dim() == 0) return false;
vector<BaseFloat> input_feats_vec(input_feats.Dim());
std::memcpy(input_feats_vec.data(), input_feats.Data(),
input_feats.Dim()*sizeof(BaseFloat));
vector<vector<BaseFloat>> result;
Compute(input_feats_vec, result);
int32 feat_size = 0;
if (result.size() != 0) {
feat_size = result.size() * result[0].size();
}
feats->Resize(feat_size);
// todo refactor (SimleGoat)
for (size_t idx = 0; idx < feat_size; ++idx) {
(*feats)(idx) = result[idx / dim_][idx % dim_];
}
return true;
}
void LinearSpectrogram::Hanning(vector<float>* data) const {
CHECK_GE(data->size(), hanning_window_.size());
for (size_t i = 0; i < hanning_window_.size(); ++i) {
data->at(i) *= hanning_window_[i];
}
}
bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
vector<BaseFloat>* real,
vector<BaseFloat>* img) const {
Vector<BaseFloat> v_tmp;
v_tmp.Resize(v->size());
std::memcpy(v_tmp.Data(), v->data(), sizeof(BaseFloat)*(v->size()));
RealFft(&v_tmp, true);
v->resize(v_tmp.Dim());
std::memcpy(v->data(), v_tmp.Data(), sizeof(BaseFloat)*(v->size()));
real->push_back(v->at(0));
img->push_back(0);
for (int i = 1; i < v->size() / 2; i++) {
real->push_back(v->at(2 * i));
img->push_back(v->at(2 * i + 1));
}
real->push_back(v->at(1));
img->push_back(0);
return true;
}
// Compute spectrogram feat
// todo: refactor later (SmileGoat)
bool LinearSpectrogram::Compute(const vector<float>& waves,
vector<vector<float>>& feats) {
int num_samples = waves.size();
const int& frame_length = opts_.frame_opts.WindowSize();
const int& sample_rate = opts_.frame_opts.samp_freq;
const int& frame_shift = opts_.frame_opts.WindowShift();
const int& fft_points = fft_points_;
const float scale = hanning_window_energy_ * sample_rate;
if (num_samples < frame_length) {
return true;
}
int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
feats.resize(num_frames);
vector<float> fft_real((fft_points_ / 2 + 1), 0);
vector<float> fft_img((fft_points_ / 2 + 1), 0);
vector<float> v(frame_length, 0);
vector<float> power((fft_points / 2 + 1));
for (int i = 0; i < num_frames; ++i) {
vector<float> data(waves.data() + i * frame_shift,
waves.data() + i * frame_shift + frame_length);
Hanning(&data);
fft_img.clear();
fft_real.clear();
v.assign(data.begin(), data.end());
NumpyFft(&v, &fft_real, &fft_img);
feats[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
for (int j = 0; j < (fft_points / 2 + 1); ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
feats[i][j] = power[j];
if (j == 0 || j == feats[0].size() - 1) {
feats[i][j] /= scale;
} else {
feats[i][j] *= (2.0 / scale);
}
// log added eps=1e-14
feats[i][j] = std::log(feats[i][j] + 1e-14);
}
}
return true;
}
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/feat/feature-window.h"
namespace ppspeech {
struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts;
kaldi::BaseFloat streaming_chunk;
LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {}
void Register(kaldi::OptionsItf* opts) {
opts->Register(
"streaming-chunk", &streaming_chunk, "streaming chunk size");
frame_opts.Register(opts);
}
};
class LinearSpectrogram : public FeatureExtractorInterface {
public:
explicit LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// the dim_ is the dim of single frame feature
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
}
private:
void Hanning(std::vector<kaldi::BaseFloat>* data) const;
bool Compute(const std::vector<kaldi::BaseFloat>& waves,
std::vector<std::vector<kaldi::BaseFloat>>& feats);
bool NumpyFft(std::vector<kaldi::BaseFloat>* v,
std::vector<kaldi::BaseFloat>* real,
std::vector<kaldi::BaseFloat>* img) const;
kaldi::int32 fft_points_;
size_t dim_;
std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
int chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// wrap the mfcc feat of kaldi, todo (SmileGoat)
#include "kaldi/feat/feature-mfcc.h"
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/normalizer.h"
#include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
DecibelNormalizer::DecibelNormalizer(
const DecibelNormalizerOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
base_extractor_ = std::move(base_extractor);
opts_ = opts;
dim_ = 1;
}
void DecibelNormalizer::Accept(
const kaldi::VectorBase<BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
bool DecibelNormalizer::Read(kaldi::Vector<BaseFloat>* waves) {
if (base_extractor_->Read(waves) == false ||
waves->Dim() == 0) {
return false;
}
Compute(waves);
return true;
}
bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
// calculate db rms
BaseFloat rms_db = 0.0;
BaseFloat mean_square = 0.0;
BaseFloat gain = 0.0;
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
vector<BaseFloat> samples;
samples.resize(waves->Dim());
for (size_t i = 0; i < samples.size(); ++i) {
samples[i] = (*waves)(i);
}
// square
for (auto& d : samples) {
if (opts_.convert_int_float) {
d = d * wave_float_normlization;
}
mean_square += d * d;
}
// mean
mean_square /= samples.size();
rms_db = 10 * std::log10(mean_square);
gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) {
LOG(ERROR)
<< "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB.";
return false;
}
// Note that this is an in-place transformation.
for (auto& item : samples) {
// python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0);
}
std::memcpy(waves->Data(), samples.data(), sizeof(BaseFloat)*samples.size());
return true;
}
CMVN::CMVN(std::string cmvn_file,
unique_ptr<FeatureExtractorInterface> base_extractor)
: var_norm_(true) {
base_extractor_ = std::move(base_extractor);
bool binary;
kaldi::Input ki(cmvn_file, &binary);
stats_.Read(ki.Stream(), binary);
dim_ = stats_.NumCols() - 1;
}
void CMVN::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
base_extractor_->Accept(inputs);
return;
}
bool CMVN::Read(kaldi::Vector<BaseFloat>* feats) {
if (base_extractor_->Read(feats) == false) {
return false;
}
Compute(feats);
return true;
}
// feats contain num_frames feature.
void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
feats->Dim() % dim != 0) {
KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
<< stats_.NumCols() << ", feats " << feats->Dim() << 'x';
}
if (stats_.NumRows() == 1 && var_norm_) {
KALDI_ERR
<< "You requested variance normalization but no variance stats_ "
<< "are supplied.";
}
double count = stats_(0, dim);
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0)
KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
"normalization: "
<< "count = " << count;
if (!var_norm_) {
Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim);
Vector<double> mean_stats_apply(feats->Dim());
// fill the datat of mean_stats in mean_stats_appy whose dim is equal
// with the dim of feature.
// the dim of feats = dim * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx,
dim);
stats_tmp.CopyFromVec(mean_stats);
}
offset.AddVec(-1.0 / count, mean_stats_apply);
feats->AddVec(1.0, offset);
return;
}
// norm(0, d) = mean offset;
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
for (int32 d = 0; d < dim; d++) {
double mean, offset, scale;
mean = stats_(0, d) / count;
double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
if (var < floor) {
KALDI_WARN << "Flooring cepstral variance from " << var << " to "
<< floor;
var = floor;
}
scale = 1.0 / sqrt(var);
if (scale != scale || 1 / scale == 0.0)
KALDI_ERR
<< "NaN or infinity in cepstral mean/variance computation";
offset = -(mean * scale);
for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset;
norm(1, d_skip) = scale;
d_skip = d_skip + dim;
}
}
// Apply the normalization.
feats->MulElements(norm.Row(1));
feats->AddVec(1.0, norm.Row(0));
}
void CMVN::ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats) {
ApplyCmvn(stats_, var_norm_, feats);
}
} // namespace ppspeech
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech {
struct DecibelNormalizerOptions {
float target_db;
float max_gain_db;
bool convert_int_float;
DecibelNormalizerOptions()
: target_db(-20), max_gain_db(300.0), convert_int_float(false) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register(
"target-db", &target_db, "target db for db normalization");
opts->Register(
"max-gain-db", &max_gain_db, "max gain db for db normalization");
opts->Register("convert-int-float",
&convert_int_float,
"if convert int samples to float");
}
};
class DecibelNormalizer : public FeatureExtractorInterface {
public:
explicit DecibelNormalizer(
const DecibelNormalizerOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// noramlize audio, the dim is 1.
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
}
private:
bool Compute(kaldi::VectorBase<kaldi::BaseFloat>* waves) const;
DecibelNormalizerOptions opts_;
size_t dim_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> waveform_;
};
class CMVN : public FeatureExtractorInterface {
public:
explicit CMVN(std::string cmvn_file,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
// the length of feats = feature_row * feature_dim,
// the Matrix is squashed into Vector
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// the dim_ is the feautre dim.
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
}
private:
void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;
void ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats);
kaldi::Matrix<double> stats_;
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
size_t dim_;
bool var_norm_;
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/raw_audio.h"
#include "kaldi/base/timer.h"
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::VectorBase;
using kaldi::Vector;
RawAudioCache::RawAudioCache(int buffer_size)
: finished_(false), data_length_(0), start_(0), timeout_(1) {
ring_buffer_.resize(buffer_size);
}
void RawAudioCache::Accept(const VectorBase<BaseFloat>& waves) {
std::unique_lock<std::mutex> lock(mutex_);
while (data_length_ + waves.Dim() > ring_buffer_.size()) {
ready_feed_condition_.wait(lock);
}
for (size_t idx = 0; idx < waves.Dim(); ++idx) {
int32 buffer_idx = (idx + start_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx);
}
data_length_ += waves.Dim();
}
bool RawAudioCache::Read(Vector<BaseFloat>* waves) {
size_t chunk_size = waves->Dim();
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
while (chunk_size > data_length_) {
// when audio is empty and no more data feed
// ready_read_condition will block in dead lock. so replace with timeout_
// ready_read_condition_.wait(lock);
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
if (elapsed > timeout_) {
if (finished_ == true) { // read last chunk data
break;
}
if (chunk_size > data_length_) {
return false;
}
}
usleep(100); // sleep 0.1 ms
}
// read last chunk data
if (chunk_size > data_length_) {
chunk_size = data_length_;
waves->Resize(chunk_size);
}
for (size_t idx = 0; idx < chunk_size; ++idx) {
int buff_idx = (start_ + idx) % ring_buffer_.size();
waves->Data()[idx] = ring_buffer_[buff_idx];
}
data_length_ -= chunk_size;
start_ = (start_ + chunk_size) % ring_buffer_.size();
ready_feed_condition_.notify_one();
return true;
}
} // namespace ppspeech
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#pragma once
namespace ppspeech {
class RawAudioCache : public FeatureExtractorInterface {
public:
explicit RawAudioCache(int buffer_size = kint16max);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// the audio dim is 1
virtual size_t Dim() const { return 1; }
virtual void SetFinished() {
std::lock_guard<std::mutex> lock(mutex_);
finished_ = true;
}
virtual bool IsFinished() const { return finished_; }
virtual void Reset() {
start_ = 0;
data_length_ = 0;
finished_ = false;
}
private:
std::vector<kaldi::BaseFloat> ring_buffer_;
size_t start_;
size_t data_length_;
bool finished_;
mutable std::mutex mutex_;
std::condition_variable ready_feed_condition_;
kaldi::int32 timeout_;
DISALLOW_COPY_AND_ASSIGN(RawAudioCache);
};
// it is a datasource for testing different frontend module.
// it accepts waves or feats.
class RawDataCache : public FeatureExtractorInterface {
public:
explicit RawDataCache() { finished_ = false; }
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
data_ = inputs;
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
if (data_.Dim() == 0) {
return false;
}
(*feats) = data_;
data_.Resize(0);
return true;
}
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { finished_ = true; }
virtual bool IsFinished() const { return finished_; }
void SetDim(int32 dim) { dim_ = dim; }
virtual void Reset() {
finished_ = true;
}
private:
kaldi::Vector<kaldi::BaseFloat> data_;
bool finished_;
int32 dim_;
DISALLOW_COPY_AND_ASSIGN(RawDataCache);
};
} // namespace ppspeech
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// extract the window of kaldi feat.
此差异已折叠。
此差异已折叠。
......@@ -15,5 +15,6 @@ add_library(kaldi-feat-common
feature-window.cc
resample.cc
mel-computations.cc
cmvn.cc
)
target_link_libraries(kaldi-feat-common PUBLIC kaldi-base kaldi-matrix kaldi-util)
\ No newline at end of file
target_link_libraries(kaldi-feat-common PUBLIC kaldi-base kaldi-matrix kaldi-util)
此差异已折叠。
// transform/cmvn.h
// Copyright 2009-2013 Microsoft Corporation
// Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_TRANSFORM_CMVN_H_
#define KALDI_TRANSFORM_CMVN_H_
#include "base/kaldi-common.h"
#include "matrix/matrix-lib.h"
namespace kaldi {
/// This function initializes the matrix to dimension 2 by (dim+1);
/// 1st "dim" elements of 1st row are mean stats, 1st "dim" elements
/// of 2nd row are var stats, last element of 1st row is count,
/// last element of 2nd row is zero.
void InitCmvnStats(int32 dim, Matrix<double> *stats);
/// Accumulation from a single frame (weighted).
void AccCmvnStats(const VectorBase<BaseFloat> &feat,
BaseFloat weight,
MatrixBase<double> *stats);
/// Accumulation from a feature file (possibly weighted-- useful in excluding silence).
void AccCmvnStats(const MatrixBase<BaseFloat> &feats,
const VectorBase<BaseFloat> *weights, // or NULL
MatrixBase<double> *stats);
/// Apply cepstral mean and variance normalization to a matrix of features.
/// If norm_vars == true, expects stats to be of dimension 2 by (dim+1), but
/// if norm_vars == false, will accept stats of dimension 1 by (dim+1); these
/// are produced by the balanced-cmvn code when it computes an offset and
/// represents it as "fake stats".
void ApplyCmvn(const MatrixBase<double> &stats,
bool norm_vars,
MatrixBase<BaseFloat> *feats);
/// This is as ApplyCmvn, but does so in the reverse sense, i.e. applies a transform
/// that would take zero-mean, unit-variance input and turn it into output with the
/// stats of "stats". This can be useful if you trained without CMVN but later want
/// to correct a mismatch, so you would first apply CMVN and then do the "reverse"
/// CMVN with the summed stats of your training data.
void ApplyCmvnReverse(const MatrixBase<double> &stats,
bool norm_vars,
MatrixBase<BaseFloat> *feats);
/// Modify the stats so that for some dimensions (specified in "dims"), we
/// replace them with "fake" stats that have zero mean and unit variance; this
/// is done to disable CMVN for those dimensions.
void FakeStatsForSomeDims(const std::vector<int32> &dims,
MatrixBase<double> *stats);
} // namespace kaldi
#endif // KALDI_TRANSFORM_CMVN_H_
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册