提交 290c23b9 编写于 作者: H Hui Zhang

add u2 nnet, u2 nnet main, codelab, and can compile

上级 e1fc57de
...@@ -42,6 +42,7 @@ for type in attention_rescoring; do ...@@ -42,6 +42,7 @@ for type in attention_rescoring; do
output_dir=${ckpt_prefix} output_dir=${ckpt_prefix}
mkdir -p ${output_dir} mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test_wav.py \ python3 -u ${BIN_DIR}/test_wav.py \
--debug True \
--ngpu ${ngpu} \ --ngpu ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--decode_cfg ${decode_config_path} \ --decode_cfg ${decode_config_path} \
......
...@@ -16,6 +16,8 @@ import os ...@@ -16,6 +16,8 @@ import os
import sys import sys
from pathlib import Path from pathlib import Path
import distutils
import numpy as np
import paddle import paddle
import soundfile import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
...@@ -74,6 +76,8 @@ class U2Infer(): ...@@ -74,6 +76,8 @@ class U2Infer():
# fbank # fbank
feat = self.preprocessing(audio, **self.preprocess_args) feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}") logger.info(f"feat shape: {feat.shape}")
if self.args.debug:
np.savetxt("feat.transform.txt", feat)
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
...@@ -125,6 +129,11 @@ if __name__ == "__main__": ...@@ -125,6 +129,11 @@ if __name__ == "__main__":
"--result_file", type=str, help="path of save the asr result") "--result_file", type=str, help="path of save the asr result")
parser.add_argument( parser.add_argument(
"--audio_file", type=str, help="path of the input audio file") "--audio_file", type=str, help="path of the input audio file")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="for debug.")
args = parser.parse_args() args = parser.parse_args()
config = CfgNode(new_allowed=True) config = CfgNode(new_allowed=True)
......
# This file is used by clang-format to autoformat paddle source code
#
# The clang-format is part of llvm toolchain.
# It need to install llvm and clang to format source code style.
#
# The basic usage is,
# clang-format -i -style=file PATH/TO/SOURCE/CODE
#
# The -style=file implicit use ".clang-format" file located in one of
# parent directory.
# The -i means inplace change.
#
# The document of clang-format is
# http://clang.llvm.org/docs/ClangFormat.html
# http://clang.llvm.org/docs/ClangFormatStyleOptions.html
---
Language: Cpp
BasedOnStyle: Google
IndentWidth: 4
TabWidth: 4
ContinuationIndentWidth: 4
MaxEmptyLinesToKeep: 2
AccessModifierOffset: -2 # The private/protected/public has no indent in class
Standard: Cpp11
AllowAllParametersOfDeclarationOnNextLine: true
BinPackParameters: false
BinPackArguments: false
...
...@@ -31,9 +31,13 @@ SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall ...@@ -31,9 +31,13 @@ SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall
############################################################################### ###############################################################################
# Option Configurations # Option Configurations
############################################################################### ###############################################################################
# option configurations
option(TEST_DEBUG "option for debug" OFF) option(TEST_DEBUG "option for debug" OFF)
option(USE_PROFILING "enable c++ profling" OFF)
option(USING_U2 "compile u2 model." ON)
option(USING_DS2 "compile with ds2 model." ON)
option(USING_GPU "u2 compute on GPU." OFF)
############################################################################### ###############################################################################
# Include third party # Include third party
...@@ -85,6 +89,41 @@ add_dependencies(openfst gflags glog) ...@@ -85,6 +89,41 @@ add_dependencies(openfst gflags glog)
include(paddleinference) include(paddleinference)
# paddle core.so
find_package(Threads REQUIRED)
find_package(PythonLibs REQUIRED)
find_package(Python3 REQUIRED)
find_package(pybind11 CONFIG)
message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}")
message(STATUS "Pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}, pybind11_LIBRARIES=${pybind11_LIBRARIES}, pybind11_DEFINITIONS=${pybind11_DEFINITIONS}")
# paddle include and link option
execute_process(
COMMAND python -c "import paddle ; print(' '.join(paddle.sysconfig.get_link_flags()), end='')"
OUTPUT_VARIABLE PADDLE_LINK_FLAGS
RESULT_VARIABLE SUCESS)
message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS})
string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS)
# paddle compile option
execute_process(
COMMAND python -c "import paddle ; print(' '.join(paddle.sysconfig.get_compile_flags()), end='')"
OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS)
message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS})
string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
# for LD_LIBRARY_PATH
# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
execute_process(
COMMAND python -c "import paddle ; print(':'.join(paddle.sysconfig.get_lib()), end='')"
OUTPUT_VARIABLE PADDLE_LIB_DIRS)
message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS})
############################################################################### ###############################################################################
# Add local library # Add local library
############################################################################### ###############################################################################
......
...@@ -3,11 +3,14 @@ ...@@ -3,11 +3,14 @@
## Environment ## Environment
We develop under: We develop under:
* python - 3.7
* docker - `registry.baidubce.com/paddlepaddle/paddle:2.2.2-gpu-cuda10.2-cudnn7` * docker - `registry.baidubce.com/paddlepaddle/paddle:2.2.2-gpu-cuda10.2-cudnn7`
* os - Ubuntu 16.04.7 LTS * os - Ubuntu 16.04.7 LTS
* gcc/g++/gfortran - 8.2.0 * gcc/g++/gfortran - 8.2.0
* cmake - 3.16.0 * cmake - 3.16.0
> Please using `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build speechx.
> We make sure all things work fun under docker, and recommend using it to develop and deploy. > We make sure all things work fun under docker, and recommend using it to develop and deploy.
* [How to Install Docker](https://docs.docker.com/engine/install/) * [How to Install Docker](https://docs.docker.com/engine/install/)
...@@ -24,13 +27,16 @@ docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --nam ...@@ -24,13 +27,16 @@ docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --nam
* More `Paddle` docker images you can see [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html). * More `Paddle` docker images you can see [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html).
2. Create python environment.
2. Build `speechx` and `examples`. ```
bash tools/venv.sh
```
> Do not source venv. 2. Build `speechx` and `examples`.
``` ```
pushd /path/to/speechx source venv/bin/activate
./build.sh ./build.sh
``` ```
......
...@@ -2,10 +2,9 @@ include(FetchContent) ...@@ -2,10 +2,9 @@ include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
gflags gflags
URL https://github.com/gflags/gflags/archive/v2.2.1.zip URL https://github.com/gflags/gflags/archive/v2.2.2.zip
URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
) )
FetchContent_MakeAvailable(gflags) FetchContent_MakeAvailable(gflags)
# openfst need # openfst need
......
include(FetchContent) include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
gtest gtest
URL https://github.com/google/googletest/archive/release-1.10.0.zip URL https://github.com/google/googletest/archive/release-1.11.0.zip
URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91 URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
) )
FetchContent_MakeAvailable(gtest) FetchContent_MakeAvailable(gtest)
......
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../../../ SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } [ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_AL=C export LC_AL=C
......
...@@ -54,4 +54,10 @@ compute_linear_spectrogram_main \ ...@@ -54,4 +54,10 @@ compute_linear_spectrogram_main \
--cmvn_file=$exp_dir/cmvn.ark --cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature." echo "compute linear spectrogram feature."
compute_fbank_main \
--num_bins 161 \
--wav_rspecifier=scp:$data_dir/wav.scp \
--feature_wspecifier=ark,t:$exp_dir/fbank.ark \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute fbank feature."
...@@ -6,7 +6,7 @@ SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx ...@@ -6,7 +6,7 @@ SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } [ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_AL=C export LC_AL=C
......
# Deepspeech2 Streaming NNet Test
Using for ds2 streaming nnet inference test.
# This contains the locations of binarys build required for running the examples.
unset GREP_OPTIONS
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_AL=C
SPEECHX_BIN=$SPEECHX_BUILD/nnet
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
PADDLE_LIB_PATH=$(python -c "import paddle ; print(':'.join(paddle.sysconfig.get_lib()), end='')")
export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH
#!/bin/bash
set -x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -f data/model/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
compute_fbank_main \
--num_bins 80 \
--wav_rspecifier=scp:$data/wav.scp \
--cmvn_file=$exp/cmvn.ark \
--feature_wspecifier=ark,t:$exp/fbank.ark
echo "compute fbank feature."
u2_nnet_main \
--model_path=$model_dir/export.jit \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--downsampling_rate=4 \
--acoustic_scale=1.0 \
--nnet_prob_wspecifier=ark,t:$exp/probs.ark
#!/bin/bash
# this script is for memory check, so please run ./run.sh first.
set +x
set -e
. ./path.sh
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
echo "please install valgrind in the speechx tools dir.\n"
exit 1
fi
ckpt_dir=./data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
ds2_model_test_main \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams
# U2/U2++ Streaming ASR
## Examples
* `wenetspeech` - Streaming Decoding using wenetspeech u2/u2++ model. Using aishell test data for testing.
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <thread> #include <thread>
#include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
// Initialize Google’s logging library. // Initialize Google’s logging library.
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1; FLAGS_logtostderr = 1;
LOG(INFO) << "Found " << 10 << " cookies"; LOG(INFO) << "Found " << 10 << " cookies";
......
...@@ -195,8 +195,11 @@ void model_forward_test() { ...@@ -195,8 +195,11 @@ void model_forward_test() {
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
model_forward_test(); model_forward_test();
return 0; return 0;
......
...@@ -18,7 +18,6 @@ set(BINS ...@@ -18,7 +18,6 @@ set(BINS
tlg_decoder_main tlg_decoder_main
) )
message(STATUS "xxxxxxxxxx: " ${DEPS})
foreach(bin_name IN LISTS BINS) foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
......
...@@ -53,8 +53,11 @@ using std::vector; ...@@ -53,8 +53,11 @@ using std::vector;
// test ds2 online decoder by feeding speech feature // test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK(FLAGS_result_wspecifier != ""); CHECK(FLAGS_result_wspecifier != "");
CHECK(FLAGS_feature_rspecifier != ""); CHECK(FLAGS_feature_rspecifier != "");
......
...@@ -30,8 +30,11 @@ using std::vector; ...@@ -30,8 +30,11 @@ using std::vector;
// test decoder by feeding nnet posterior probability // test decoder by feeding nnet posterior probability
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader likelihood_reader( kaldi::SequentialBaseFloatMatrixReader likelihood_reader(
FLAGS_nnet_prob_respecifier); FLAGS_nnet_prob_respecifier);
......
...@@ -23,8 +23,11 @@ DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); ...@@ -23,8 +23,11 @@ DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate"); DEFINE_int32(sample_rate, 16000, "sample rate");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure(); ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::Recognizer recognizer(resource); ppspeech::Recognizer recognizer(resource);
......
...@@ -55,8 +55,11 @@ using std::vector; ...@@ -55,8 +55,11 @@ using std::vector;
// test TLG decoder by feeding speech feature. // test TLG decoder by feeding speech feature.
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
......
project(frontend)
add_library(frontend STATIC add_library(frontend STATIC
cmvn.cc cmvn.cc
db_norm.cc db_norm.cc
......
...@@ -30,8 +30,11 @@ DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)"); ...@@ -30,8 +30,11 @@ DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
using namespace boost::json; // from <boost/json.hpp> using namespace boost::json; // from <boost/json.hpp>
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file; LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
......
...@@ -32,13 +32,21 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); ...@@ -32,13 +32,21 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "", "read cmvn"); DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(num_bins, 161, "fbank num bins"); DEFINE_int32(num_bins, 161, "fbank num bins");
DEFINE_int32(sample_rate, 16000, "sampe rate: 16k, 8k.");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK(FLAGS_wav_rspecifier.size() > 0);
CHECK(FLAGS_feature_wspecifier.size() > 0);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader( kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier); FLAGS_wav_rspecifier);
kaldi::SequentialTableReader<kaldi::WaveInfoHolder> wav_info_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier); kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
...@@ -54,6 +62,10 @@ int main(int argc, char* argv[]) { ...@@ -54,6 +62,10 @@ int main(int argc, char* argv[]) {
opt.frame_opts.frame_shift_ms = 10; opt.frame_opts.frame_shift_ms = 10;
opt.mel_opts.num_bins = FLAGS_num_bins; opt.mel_opts.num_bins = FLAGS_num_bins;
opt.frame_opts.dither = 0.0; opt.frame_opts.dither = 0.0;
LOG(INFO) << "frame_length_ms: " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame_shift_ms: " << opt.frame_opts.frame_shift_ms;
LOG(INFO) << "num_bins: " << opt.mel_opts.num_bins;
LOG(INFO) << "dither: " << opt.frame_opts.dither;
std::unique_ptr<ppspeech::FrontendInterface> fbank( std::unique_ptr<ppspeech::FrontendInterface> fbank(
new ppspeech::Fbank(opt, std::move(data_source))); new ppspeech::Fbank(opt, std::move(data_source)));
...@@ -61,53 +73,73 @@ int main(int argc, char* argv[]) { ...@@ -61,53 +73,73 @@ int main(int argc, char* argv[]) {
std::unique_ptr<ppspeech::FrontendInterface> cmvn( std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank))); new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank)));
ppspeech::FeatureCacheOptions feat_cache_opts;
// the feature cache output feature chunk by chunk. // the feature cache output feature chunk by chunk.
ppspeech::FeatureCacheOptions feat_cache_opts;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "fbank: " << true; LOG(INFO) << "fbank: " << true;
LOG(INFO) << "feat dim: " << feature_cache.Dim(); LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk; float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate; int chunk_sample_size = streaming_chunk * FLAGS_sample_rate;
LOG(INFO) << "sr: " << sample_rate; LOG(INFO) << "sr: " << FLAGS_sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (sec): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size; LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
for (; !wav_reader.Done(); wav_reader.Next()) { for (; !wav_reader.Done() && !wav_info_reader.Done(); wav_reader.Next(), wav_info_reader.Next()) {
std::string utt = wav_reader.Key(); const std::string& utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value(); const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "process utt: " << utt;
const std::string& utt2 = wav_info_reader.Key();
const kaldi::WaveInfo& wave_info = wav_info_reader.Value();
CHECK(utt == utt2) << "wav reader and wav info reader using diff rspecifier!!!";
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "samples: " << wave_info.SampleCount();
LOG(INFO) << "dur: " << wave_info.Duration() << " sec";
CHECK(wave_info.SampFreq() == FLAGS_sample_rate) << "need " << FLAGS_sample_rate << " get " << wave_info.SampFreq();
// load first channel wav
int32 this_channel = 0; int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel); this_channel);
// compute feat chunk by chunk
int tot_samples = waveform.Dim(); int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0; int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats; std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0; int feature_rows = 0;
while (sample_offset < tot_samples) { while (sample_offset < tot_samples) {
// cur chunk size
int cur_chunk_size = int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset); std::min(chunk_sample_size, tot_samples - sample_offset);
// get chunk wav
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size); kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) { for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i); wav_chunk(i) = waveform(sample_offset + i);
} }
kaldi::Vector<BaseFloat> features; // compute feat
feature_cache.Accept(wav_chunk); feature_cache.Accept(wav_chunk);
// send finish signal
if (cur_chunk_size < chunk_sample_size) { if (cur_chunk_size < chunk_sample_size) {
feature_cache.SetFinished(); feature_cache.SetFinished();
} }
// read feat
kaldi::Vector<BaseFloat> features;
bool flag = true; bool flag = true;
do { do {
flag = feature_cache.Read(&features); flag = feature_cache.Read(&features);
feats.push_back(features); if (flag && features.Dim() != 0) {
feature_rows += features.Dim() / feature_cache.Dim(); feats.push_back(features);
feature_rows += features.Dim() / feature_cache.Dim();
}
} while (flag == true && features.Dim() != 0); } while (flag == true && features.Dim() != 0);
// forward offset
sample_offset += cur_chunk_size; sample_offset += cur_chunk_size;
} }
...@@ -125,14 +157,19 @@ int main(int argc, char* argv[]) { ...@@ -125,14 +157,19 @@ int main(int argc, char* argv[]) {
++cur_idx; ++cur_idx;
} }
} }
LOG(INFO) << "feat shape: " << features.NumRows() << " , " << features.NumCols();
feat_writer.Write(utt, features); feat_writer.Write(utt, features);
// reset frontend pipeline state
feature_cache.Reset(); feature_cache.Reset();
if (num_done % 50 == 0 && num_done != 0) if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances"; VLOG(2) << "Processed " << num_done << " utterances";
num_done++; num_done++;
} }
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors."; << " with errors.";
return (num_done != 0 ? 0 : 1); return (num_done != 0 ? 0 : 1);
} }
...@@ -31,8 +31,11 @@ DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn"); ...@@ -31,8 +31,11 @@ DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader( kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier); FLAGS_wav_rspecifier);
......
project(nnet) set(srcs decodable.cc)
add_library(nnet STATIC if(USING_DS2)
decodable.cc list(APPEND srcs ds2_nnet.cc)
ds2_nnet.cc endif()
)
if(USING_U2)
list(APPEND srcs u2_nnet.cc)
endif()
add_library(nnet STATIC ${srcs})
target_link_libraries(nnet absl::strings) target_link_libraries(nnet absl::strings)
set(bin_name ds2_nnet_main) if(USING_U2)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet ${DEPS}) # target_link_libraries(nnet ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
\ No newline at end of file endif()
if(USING_DS2)
set(bin_name ds2_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
target_link_libraries(${bin_name} ${DEPS})
endif()
# test bin
if(USING_U2)
set(bin_name u2_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endif()
...@@ -30,6 +30,7 @@ Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet, ...@@ -30,6 +30,7 @@ Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
frames_ready_(0), frames_ready_(0),
acoustic_scale_(acoustic_scale) {} acoustic_scale_(acoustic_scale) {}
// for debug
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_cache_ = likelihood; nnet_cache_ = likelihood;
frames_ready_ += likelihood.NumRows(); frames_ready_ += likelihood.NumRows();
...@@ -41,6 +42,7 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { ...@@ -41,6 +42,7 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
// return the size of frame have computed. // return the size of frame have computed.
int32 Decodable::NumFramesReady() const { return frames_ready_; } int32 Decodable::NumFramesReady() const { return frames_ready_; }
// frame idx is from 0 to frame_ready_ -1; // frame idx is from 0 to frame_ready_ -1;
bool Decodable::IsLastFrame(int32 frame) { bool Decodable::IsLastFrame(int32 frame) {
bool flag = EnsureFrameHaveComputed(frame); bool flag = EnsureFrameHaveComputed(frame);
...@@ -72,26 +74,38 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) { ...@@ -72,26 +74,38 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) {
} }
bool Decodable::AdvanceChunk() { bool Decodable::AdvanceChunk() {
// read feats
Vector<BaseFloat> features; Vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) { if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
return false; return false;
} }
int32 nnet_dim = 0;
Vector<BaseFloat> inferences;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences);
// forward feats
int32 vocab_dim = 0;
Vector<BaseFloat> probs;
nnet_->FeedForward(features, frontend_->Dim(), &probs, &vocab_dim);
// cache nnet outupts
nnet_cache_.Resize(probs.Dim() / vocab_dim, vocab_dim);
nnet_cache_.CopyRowsFromVec(probs);
// update state
frame_offset_ = frames_ready_; frame_offset_ = frames_ready_;
frames_ready_ += nnet_cache_.NumRows(); frames_ready_ += nnet_cache_.NumRows();
return true; return true;
} }
// read one frame likelihood
bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) { bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
std::vector<BaseFloat> result; if (EnsureFrameHaveComputed(frame) == false) {
if (EnsureFrameHaveComputed(frame) == false) return false; return false;
likelihood->resize(nnet_cache_.NumCols()); }
for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
int vocab_size = nnet_cache_.NumCols();
likelihood->resize(vocab_size);
for (int32 idx = 0; idx < vocab_size; ++idx) {
(*likelihood)[idx] = (*likelihood)[idx] =
nnet_cache_(frame - frame_offset_, idx) * acoustic_scale_; nnet_cache_(frame - frame_offset_, idx) * acoustic_scale_;
} }
......
...@@ -27,35 +27,54 @@ class Decodable : public kaldi::DecodableInterface { ...@@ -27,35 +27,54 @@ class Decodable : public kaldi::DecodableInterface {
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet, explicit Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FrontendInterface>& frontend, const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale = 1.0); kaldi::BaseFloat acoustic_scale = 1.0);
// void Init(DecodableOpts config); // void Init(DecodableOpts config);
// nnet logprob output
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame); virtual bool IsLastFrame(int32 frame);
// nnet output dim, e.g. vocab size
virtual int32 NumIndices() const; virtual int32 NumIndices() const;
// not logprob
// nnet prob output
virtual bool FrameLikelihood(int32 frame, virtual bool FrameLikelihood(int32 frame,
std::vector<kaldi::BaseFloat>* likelihood); std::vector<kaldi::BaseFloat>* likelihood);
virtual int32 NumFramesReady() const; virtual int32 NumFramesReady() const;
// for offline test // for offline test
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
void Reset(); void Reset();
bool IsInputFinished() const { return frontend_->IsFinished(); } bool IsInputFinished() const { return frontend_->IsFinished(); }
bool EnsureFrameHaveComputed(int32 frame); bool EnsureFrameHaveComputed(int32 frame);
int32 TokenId2NnetId(int32 token_id); int32 TokenId2NnetId(int32 token_id);
private: private:
bool AdvanceChunk(); bool AdvanceChunk();
std::shared_ptr<FrontendInterface> frontend_; std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetInterface> nnet_;
// nnet outputs' cache
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_; kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// the frame is nnet prob frame rather than audio feature frame // the frame is nnet prob frame rather than audio feature frame
// nnet frame subsample the feature frame // nnet frame subsample the feature frame
// eg: 35 frame features output 8 frame inferences // eg: 35 frame features output 8 frame inferences
int32 frame_offset_; int32 frame_offset_;
int32 frames_ready_; int32 frames_ready_;
// todo: feature frame mismatch with nnet inference frame // todo: feature frame mismatch with nnet inference frame
// so use subsampled_frame // so use subsampled_frame
int32 current_log_post_subsampled_offset_; int32 current_log_post_subsampled_offset_;
int32 num_chunk_computed_; int32 num_chunk_computed_;
kaldi::BaseFloat acoustic_scale_; kaldi::BaseFloat acoustic_scale_;
}; };
......
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
// limitations under the License. // limitations under the License.
#include "nnet/ds2_nnet.h" #include "nnet/ds2_nnet.h"
#include "base/flags.h" #include "base/common.h"
#include "base/log.h"
#include "frontend/audio/assembler.h" #include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
...@@ -49,8 +48,11 @@ using kaldi::Matrix; ...@@ -49,8 +48,11 @@ using kaldi::Matrix;
using std::vector; using std::vector;
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
...@@ -146,7 +148,7 @@ int main(int argc, char* argv[]) { ...@@ -146,7 +148,7 @@ int main(int argc, char* argv[]) {
} }
kaldi::Matrix<kaldi::BaseFloat> result(prob_vec.size(), kaldi::Matrix<kaldi::BaseFloat> result(prob_vec.size(),
prob_vec[0].Dim()); prob_vec[0].Dim());
for (int32 row_idx = 0; row_idx < prob_vec.size(); ++row_idx) { for (int row_idx = 0; row_idx < prob_vec.size(); ++row_idx) {
for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) { for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) {
result(row_idx, col_idx) = prob_vec[row_idx](col_idx); result(row_idx, col_idx) = prob_vec[row_idx](col_idx);
} }
......
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
#include "nnet/nnet_itf.h"
#include "paddle/extension.h"
#include "paddle/jit/all.h"
#include "paddle/phi/api/all.h"
namespace ppspeech {
struct U2ModelOptions {
std::string model_path;
int thread_num;
bool use_gpu;
U2ModelOptions() : model_path(""), thread_num(1), use_gpu(false) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path");
opts->Register("thread-num", &thread_num, "thread num");
opts->Register("use-gpu", &use_gpu, "if use gpu");
}
};
class U2NnetBase : public NnetInterface {
public:
virtual int context() const { return right_context_ + 1; }
virtual int right_context() const { return right_context_; }
virtual int subsampling_rate() const { return subsampling_rate_; }
virtual int eos() const { return eos_; }
virtual int sos() const { return sos_; }
virtual int is_bidecoder() const { return is_bidecoder_; }
// current offset in decoder frame
virtual int offset() const { return offset_; }
virtual void set_chunk_size(int chunk_size) { chunk_size_ = chunk_size; }
virtual void set_num_left_chunks(int num_left_chunks) {
num_left_chunks_ = num_left_chunks;
}
// start: false, it is the start chunk of one sentence, else true
virtual int num_frames_for_chunk(bool start) const;
virtual std::shared_ptr<NnetInterface> Copy() const = 0;
virtual void ForwardEncoderChunk(
const std::vector<kaldi::BaseFloat>& chunk_feats,
int32 feat_dim,
std::vector<kaldi::BaseFloat>* ctc_probs,
int32* vocab_dim);
virtual void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) = 0;
protected:
virtual void ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats,
int32 feat_dim,
std::vector<kaldi::BaseFloat>* ctc_probs,
int32* vocab_dim) = 0;
virtual void CacheFeature(const std::vector<kaldi::BaseFloat>& chunk_feats,
int32 feat_dim);
protected:
// model specification
int right_context_{0};
int subsampling_rate_{1};
int sos_{0};
int eos_{0};
bool is_bidecoder_{false};
int chunk_size_{16}; // num of decoder frames. If chunk_size > 0, streaming
// case. Otherwise, none streaming case
int num_left_chunks_{-1}; // -1 means all left chunks
// asr decoder state
int offset_{0}; // current offset in encoder output time stamp. Used by
// position embedding.
std::vector<std::vector<float>> cached_feats_{}; // features cache
};
class U2Nnet : public U2NnetBase {
public:
U2Nnet(const U2ModelOptions& opts);
U2Nnet(const U2Nnet& other);
void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
int32 feature_dim,
kaldi::Vector<kaldi::BaseFloat>* inferences,
int32* inference_dim) override;
void Reset() override;
void Dim();
void LoadModel(const std::string& model_path_w_prefix);
void Warmup();
std::shared_ptr<paddle::jit::Layer> model() const { return model_; }
std::shared_ptr<NnetInterface> Copy() const override;
void ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats,
int32 feat_dim,
std::vector<kaldi::BaseFloat>* ctc_probs,
int32* vocab_dim) override;
float ComputePathScore(const paddle::Tensor& prob,
const std::vector<int>& hyp,
int eos);
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) override;
// debug
void FeedEncoderOuts(paddle::Tensor& encoder_out);
private:
U2ModelOptions opts_;
phi::Place dev_;
std::shared_ptr<paddle::jit::Layer> model_{nullptr};
std::vector<paddle::Tensor> encoder_outs_;
// transformer/conformer attention cache
paddle::Tensor att_cache_ = paddle::full({0, 0, 0, 0}, 0.0);
// conformer-only conv_module cache
paddle::Tensor cnn_cache_ = paddle::full({0, 0, 0, 0}, 0.0);
paddle::jit::Function forward_encoder_chunk_;
paddle::jit::Function forward_attention_decoder_;
paddle::jit::Function ctc_activation_;
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/u2_nnet.h"
#include "base/common.h"
#include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
DEFINE_string(model_path, "", "paddle nnet model");
DEFINE_int32(nnet_decoder_chunk, 16, "nnet forward chunk");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int32 num_done = 0, num_err = 0;
CHECK(FLAGS_feature_rspecifier.size() > 0);
CHECK(FLAGS_nnet_prob_wspecifier.size() > 0);
CHECK(FLAGS_model_path.size() > 0);
LOG(INFO) << "input rspecifier: " << FLAGS_feature_rspecifier;
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
LOG(INFO) << "model path: " << FLAGS_model_path;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
ppspeech::U2ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
int32 chunk_size =
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate +
FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
int nframes = feature.NumRows();
int feat_dim = feature.NumCols();
raw_data->SetDim(feat_dim);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "feat shape: " << nframes << ", " << feat_dim;
// // pad feats
// int32 padding_len = 0;
// if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
// padding_len =
// chunk_stride - (feature.NumRows() - chunk_size) %
// chunk_stride;
// feature.Resize(feature.NumRows() + padding_len,
// feature.NumCols(),
// kaldi::kCopyData);
// }
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
int32 frame_idx = 0;
std::vector<kaldi::Vector<kaldi::BaseFloat>> prob_vec;
int32 ori_feature_len = feature.NumRows();
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feat_dim);
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) {
LOG(WARNING) << "utt: " << utt << " skip last "
<< feature_chunk_size << " frames, expect is "
<< receptive_field_length;
break;
}
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start);
kaldi::SubVector<kaldi::BaseFloat> feature_chunk_row(
feature_chunk.Data() + row_id * feat_dim, feat_dim);
feature_chunk_row.CopyFromVec(feat_row);
++start;
}
// feat to frontend pipeline cache
raw_data->Accept(feature_chunk);
// send data finish signal
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
// get nnet outputs
vector<kaldi::BaseFloat> prob;
while (decodable->FrameLikelihood(frame_idx, &prob)) {
kaldi::Vector<kaldi::BaseFloat> vec_tmp(prob.size());
std::memcpy(vec_tmp.Data(),
prob.data(),
sizeof(kaldi::BaseFloat) * prob.size());
prob_vec.push_back(vec_tmp);
frame_idx++;
}
}
// after process one utt, then reset decoder state.
decodable->Reset();
if (prob_vec.size() == 0) {
// the TokenWriter can not write empty string.
++num_err;
LOG(WARNING) << " the nnet prob of " << utt << " is empty";
continue;
}
// writer nnet output
kaldi::MatrixIndexT nrow = prob_vec.size();
kaldi::MatrixIndexT ncol = prob_vec[0].Dim();
LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
kaldi::Matrix<kaldi::BaseFloat> result(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
result(row_idx, col_idx) = prob_vec[row_idx](col_idx);
}
}
nnet_out_writer.Write(utt, result);
++num_done;
}
double elapsed = timer.Elapsed();
LOG(INFO) << " cost:" << elapsed << " sec";
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(websocket) add_subdirectory(websocket)
project(websocket) # project(websocket)
add_library(websocket STATIC add_library(websocket STATIC
websocket_server.cc websocket_server.cc
......
add_library(utils add_library(utils
file_utils.cc file_utils.cc
math.cc
) )
\ No newline at end of file
...@@ -38,11 +38,11 @@ float LogSumExp(float x, float y) { ...@@ -38,11 +38,11 @@ float LogSumExp(float x, float y) {
template <typename T> template <typename T>
struct ValGreaterComp { struct ValGreaterComp {
bool operator()(const std::pair<T, int32_t>& lhs, bool operator()(const std::pair<T, int32_t>& lhs,
const std::pair<T, int32_>& rhs) const { const std::pair<T, int32_t>& rhs) const {
return lhs.first > rhs.first || return lhs.first > rhs.first ||
(lhs.first == rhs.first && lhs.second < rhs.second); (lhs.first == rhs.first && lhs.second < rhs.second);
} }
} };
template <typename T> template <typename T>
void TopK(const std::vector<T>& data, void TopK(const std::vector<T>& data,
......
#!/bin/bash
set -ex
PYTHON=python3.7
test -d venv || virtualenv -p ${PYTHON} venv
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册