Unverified commit acf1d272, authored by Y YangZhou, committed by GitHub

[speechx] rm ds2 && rm boost (#2786)

* fix openfst download error

* add acknowledgments of openfst

* refactor directory

* clean ctc_decoders dir

* add nnet cache && make 2 thread work

* do not compile websocket

* rm ds2 && rm boost

* rm ds2 example
Parent 5046d8ee
...@@ -57,13 +57,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders|speechx/speechx/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
- id: cpplint
name: cpplint
description: Static code analysis of C/C++ files
language: python
files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders|speechx/speechx/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
...
...@@ -44,9 +44,6 @@ option(TEST_DEBUG "option for debug" OFF)
option(USE_PROFILING "enable c++ profling" OFF)
option(WITH_TESTING "unit test" ON)
option(USING_U2 "compile u2 model." ON)
option(USING_DS2 "compile with ds2 model." OFF)
option(USING_GPU "u2 compute on GPU." OFF)
###############################################################################
...@@ -56,21 +53,6 @@ include(gflags)
include(glog)
# boost
# include(boost) # not work
set(boost_SOURCE_DIR ${fc_patch}/boost-src)
set(BOOST_ROOT ${boost_SOURCE_DIR})
include_directories(${boost_SOURCE_DIR})
link_directories(${boost_SOURCE_DIR}/stage/lib)
# Eigen
include(eigen)
find_package(Eigen3 REQUIRED)
# Kenlm
include(kenlm)
add_dependencies(kenlm eigen boost)
#openblas
include(openblas)
...
...@@ -4,20 +4,5 @@ set -xe
# the build script has been verified in the paddlepaddle docker image.
# please follow the instruction below to install the PaddlePaddle image.
# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
boost_SOURCE_DIR=$PWD/fc_patch/boost-src
if [ ! -d ${boost_SOURCE_DIR} ]; then
    wget -c https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
    tar xzfv boost_1_75_0.tar.gz
    mkdir -p $PWD/fc_patch
    mv boost_1_75_0 ${boost_SOURCE_DIR}
    cd ${boost_SOURCE_DIR}
    bash ./bootstrap.sh
    ./b2
    cd -
    echo -e "\n"
fi
#rm -rf build
mkdir -p build
cmake -B build -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
cmake -B build
cmake --build build -j
# Deepspeech2 Streaming ASR
## Examples
* `websocket` - Streaming ASR with websocket for deepspeech2_aishell.
* `aishell` - Streaming decoding on the aishell dataset, for local WER testing.
* `onnx` - Example of converting deepspeech2 to ONNX format.
# Aishell - Deepspeech2 Streaming
> We recommend using the U2/U2++ model instead of DS2; please see [here](../../u2pp_ol/wenetspeech/).
A C++ deployment example that uses the deepspeech2 model to recognize `wav` files and compute `CER`. We use AISHELL-1 as the test data.
## Source path.sh
```bash
. path.sh
```
The SpeechX binaries are under `echo $SPEECHX_BUILD`; for more info please see `path.sh`.
## Recognize with linear feature
```bash
bash run.sh
```
`run.sh` has multiple stages; for details please see `run.sh` and the example after this list:
1. download the dataset, model and LM
2. convert cmvn format and compute feature
3. decode w/o lm by feature
4. decode w/ ngram lm by feature
5. decode w/ TLG graph by feature
6. recognize w/ TLG graph by wav input
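Each stage can also be run on its own. A minimal sketch, assuming the standard `--stage`/`--stop_stage` flags that `run.sh` parses via `utils/parse_options.sh`:
```bash
# run only the "decode w/o lm by feature" stage (stage 2),
# skipping download and feature extraction
bash run.sh --stage 2 --stop_stage 2
```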
### Recognize with `.scp` file for wav
This script uses `recognizer_main` to recognize wav files.
The input is an `scp` file, which looks like this:
```text
# head data/split1/1/aishell_test.scp
BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
...
BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
```
If you want to recognize a single wav, you can make an `scp` file like this:
```text
key path/to/wav/file
```
Then specify the `--wav_rspecifier=` parameter for the `recognizer_main` binary. For the meaning of the other flags, please see the help:
```bash
recognizer_main --help
```
For an example of using `recognizer_main`, please see `run.sh`; a minimal single-utterance sketch is shown below.
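The sketch below follows the flag set used in `run.sh` (the wav path is a placeholder; the model, cmvn and TLG graph are assumed to have been prepared by the earlier stages):
```bash
# recognize a single wav: build a one-line scp file and pass it to recognizer_main
echo "BAC009S0764W0121 /path/to/BAC009S0764W0121.wav" > one_utt.scp
recognizer_main \
    --wav_rspecifier=scp:one_utt.scp \
    --cmvn_file=data/cmvn.ark \
    --model_path=data/model/exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel \
    --param_path=data/model/exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams \
    --word_symbol_table=data/wfst/words.txt \
    --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
    --graph_path=data/wfst/TLG.fst --max_active=7500 \
    --nnet_decoder_chunk=8 \
    --acoustic_scale=1.2 \
    --result_wspecifier=ark,t:one_utt_result.txt
```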
### CTC Prefix Beam Search w/o LM
```
Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
```
### CTC Prefix Beam Search w/ LM
LM: zh_giga.no_cna_cmn.prune01244.klm
```
Overall -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327
Mandarin -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
```
### CTC TLG WFST
LM: [aishell train](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/aishell_graph.zip)
--acoustic_scale=1.2
```
Overall -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1819
Mandarin -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1818
Other -> 0.00 % N=0 C=0 S=0 D=0 I=1
```
LM: [wenetspeech](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/wenetspeech_graph.zip)
--acoustic_scale=1.5
```
Overall -> 10.93 % N=104765 C=93410 S=9780 D=1575 I=95
Mandarin -> 10.93 % N=104762 C=93410 S=9779 D=1573 I=95
Other -> 100.00 % N=3 C=0 S=1 D=2 I=0
```
## Recognize with fbank feature
This script is the same as `run.sh`, but uses fbank features.
```bash
bash run_fbank.sh
```
### CTC Prefix Beam Search w/o LM
```
Overall -> 10.44 % N=104765 C=94194 S=10174 D=397 I=369
Mandarin -> 10.44 % N=104762 C=94194 S=10171 D=397 I=369
Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
```
### CTC Prefix Beam Search w/ LM
LM: zh_giga.no_cna_cmn.prune01244.klm
```
Overall -> 5.82 % N=104765 C=99386 S=4944 D=435 I=720
Mandarin -> 5.82 % N=104762 C=99386 S=4941 D=435 I=720
English -> 0.00 % N=0 C=0 S=0 D=0 I=0
```
### CTC TLG WFST
LM: [aishell train](https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip)
```
Overall -> 9.58 % N=104765 C=94817 S=4326 D=5622 I=84
Mandarin -> 9.57 % N=104762 C=94817 S=4325 D=5620 I=84
Other -> 100.00 % N=3 C=0 S=1 D=2 I=0
```
## Build TLG WFST graph
This script builds the TLG WFST graph. It depends on `srilm`, so please make sure it is installed.
For more information please see the script below.
```bash
bash ./local/run_build_tlg.sh
```
#!/bin/bash
# To be run from one directory above this script.
. ./path.sh
nj=40
text=data/local/lm/text
lexicon=data/local/dict/lexicon.txt
for f in "$text" "$lexicon"; do
[ ! -f $f ] && echo "$0: No such file $f" && exit 1;
done
# Check SRILM tools
if ! which ngram-count > /dev/null; then
echo "srilm tools are not found, please download it and install it from: "
echo "http://www.speech.sri.com/projects/srilm/download.html"
echo "Then add the tools to your PATH"
exit 1
fi
# This script takes no arguments. It assumes you have already run
# aishell_data_prep.sh.
# It takes as input the files
# data/local/lm/text
# data/local/dict/lexicon.txt
dir=data/local/lm
mkdir -p $dir
cleantext=$dir/text.no_oov
# oov to <SPOKEN_NOISE>
# lexicon line: word char0 ... charn
# text line: utt word0 ... wordn -> line: <SPOKEN_NOISE> word0 ... wordn
text_dir=$(dirname $text)
split_name=$(basename $text)
./local/split_data.sh $text_dir $text $split_name $nj
utils/run.pl JOB=1:$nj $text_dir/split${nj}/JOB/${split_name}.no_oov.log \
cat ${text_dir}/split${nj}/JOB/${split_name} \| awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
{for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
\> ${text_dir}/split${nj}/JOB/${split_name}.no_oov || exit 1;
cat ${text_dir}/split${nj}/*/${split_name}.no_oov > $cleantext
# compute word counts, sort in descending order
# line: count word
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort --parallel=`nproc` | uniq -c | \
sort --parallel=`nproc` -nr > $dir/word.counts || exit 1;
# Get counts from acoustic training transcripts, and add one-count
# for each word in the lexicon (but not silence, we don't want it
# in the LM-- we'll add it optionally later).
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
sort --parallel=`nproc` | uniq -c | sort --parallel=`nproc` -nr > $dir/unigram.counts || exit 1;
# word with <s> </s>
cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
# hold out to compute ppl
heldout_sent=10000 # Don't change this if you want result to be comparable with kaldi_lm results
mkdir -p $dir
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
head -$heldout_sent > $dir/heldout
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
tail -n +$heldout_sent > $dir/train
ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
-map-unk "<UNK>" -kndiscount -interpolate -lm $dir/lm.arpa
ngram -lm $dir/lm.arpa -ppl $dir/heldout
\ No newline at end of file
#!/bin/bash
set -eo pipefail
. path.sh
# Attention: please replace the vocab below; it is only for this script.
# Different acoustic models have different vocabs.
ckpt_dir=data/fbank_model
unit=$ckpt_dir/data/lang_char/vocab.txt # vocab file, line: char/spm_piece
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
stage=-1
stop_stage=100
corpus=aishell
lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt
text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
. utils/parse_options.sh
data=$PWD/data
mkdir -p $data
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
if [ ! -f $data/speech.ngram.zh.tar.gz ];then
# download ngram
pushd $data
wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/speech.ngram.zh.tar.gz
tar xvzf speech.ngram.zh.tar.gz
popd
fi
if [ ! -f $ckpt_dir/data/mean_std.json ]; then
# download model
mkdir -p $ckpt_dir
pushd $ckpt_dir
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
popd
fi
fi
if [ ! -f $unit ]; then
echo "$0: No such file $unit"
exit 1;
fi
if ! which ngram-count; then
# need srilm install
pushd $MAIN_ROOT/tools
make srilm.done
popd
fi
mkdir -p data/local/dict
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# Prepare dict
# line: char/spm_pieces
cp $unit data/local/dict/units.txt
if [ ! -f $lexicon ];then
utils/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon
echo "Generate $lexicon from $text"
fi
# filter by vocab
# line: word ph0 ... phn -> line: word char0 ... charn
utils/fst/prepare_dict.py \
--unit_file $unit \
--in_lexicon ${lexicon} \
--out_lexicon data/local/dict/lexicon.txt
fi
lm=data/local/lm
mkdir -p $lm
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# Train ngram lm
cp $text $lm/text
local/aishell_train_lms.sh
echo "build LM done."
fi
# build TLG
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# build T & L
utils/fst/compile_lexicon_token_fst.sh \
data/local/dict data/local/tmp data/local/lang
# build G & TLG
utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;
fi
aishell_wav_scp=aishell_test.scp
nj=40
cmvn=$data/cmvn_fbank.ark
wfst=$data/lang_test
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
if [ ! -d $data/test ]; then
# download test dataset
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
# convert cmvn format
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
fi
# output dir for results
exp=$PWD/exp
mkdir -p $exp
wer=aishell_wer
label_file=aishell_result
export GLOG_logtostderr=1
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# recognize w/ TLG graph
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/check_tlg.log \
recognizer_main \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_5.jit.pdmodel \
--streaming_chunk=30 \
--use_fbank=true \
--param_path=$model_dir/avg_5.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--model_cache_shapes="5-1-2048,5-1-2048" \
--graph_path=$wfst/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_check_tlg
cat $data/split${nj}/*/result_check_tlg > $exp/${label_file}_check_tlg
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_check_tlg > $exp/${wer}.check_tlg
echo "recognizer test have finished!!!"
echo "please checkout in ${exp}/${wer}.check_tlg"
fi
exit 0
#!/usr/bin/env bash
set -eo pipefail
data=$1
scp=$2
split_name=$3
numsplit=$4
# save in $data/split{n}
# $scp to split
#
if [[ ! $numsplit -gt 0 ]]; then
echo "Invalid num-split argument";
exit 1;
fi
directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done)
# if this mkdir fails due to argument-list being too long, iterate.
if ! mkdir -p $directories >&/dev/null; then
for n in `seq $numsplit`; do
mkdir -p $data/split${numsplit}/$n
done
fi
echo "utils/split_scp.pl $scp $scp_splits"
utils/split_scp.pl $scp $scp_splits
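# Example: ./local/split_data.sh data data/aishell_test.scp aishell_test.scp 40
# writes data/split40/1/aishell_test.scp ... data/split40/40/aishell_test.scp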
# This contains the locations of the binaries required for running the examples.
MAIN_ROOT=`realpath $PWD/../../../../`
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project built successfully"; }
export LC_ALL=C
# openfst bin & kaldi bin
KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/
OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src
# srilm
export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
export SRILM=${MAIN_ROOT}/tools/srilm
SPEECHX_BIN=$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin
#!/bin/bash
set -x
set -e
. path.sh
nj=40
stage=0
stop_stage=100
. utils/parse_options.sh
# 1. compile
if [ ! -d ${SPEECHX_BUILD} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
# output
mkdir -p exp
exp=$PWD/exp
aishell_wav_scp=aishell_test.scp
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
if [ ! -d $data/test ]; then
# download dataset
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
if [ ! -f $ckpt_dir/data/mean_std.json ]; then
# download model
mkdir -p $ckpt_dir
pushd $ckpt_dir
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
if [ ! -f $lm ]; then
# download kenlm bin
pushd $data
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
popd
fi
fi
# 3. make feature
text=$data/test/text
label_file=./aishell_result
wer=./aishell_wer
export GLOG_logtostderr=1
cmvn=$data/cmvn.ark
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# 3. convert cmvn format and compute linear feat
cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
compute_linear_spectrogram_main \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \
--cmvn_file=$cmvn
echo "feature make has finished!!!"
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# decode w/o lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--nnet_decoder_chunk=8 \
--dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result
cat $data/split${nj}/*/result > $exp/${label_file}
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer}
echo "ctc-prefix-beam-search-decoder-ol without lm has finished!!!"
echo "please checkout in ${exp}/${wer}"
tail -n 7 $exp/${wer}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode w/ ngram lm with feature input
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--nnet_decoder_chunk=8 \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm
cat $data/split${nj}/*/result_lm > $exp/${label_file}_lm
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm
echo "ctc-prefix-beam-search-decoder-ol with lm test has finished!!!"
echo "please checkout in ${exp}/${wer}.lm"
tail -n 7 $exp/${wer}.lm
fi
wfst=$data/wfst/
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
mkdir -p $wfst
if [ ! -f $wfst/aishell_graph.zip ]; then
# download TLG graph
pushd $wfst
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
unzip aishell_graph.zip
mv aishell_graph/* $wfst
popd
fi
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# decoder w/ TLG graph with feature input
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$wfst/TLG.fst --max_active=7500 \
--nnet_decoder_chunk=8 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg
echo "wfst-decoder-ol have finished!!!"
echo "please checkout in ${exp}/${wer}.tlg"
tail -n 7 $exp/${wer}.tlg
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# recognize from wav file w/ TLG graph
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
recognizer_main \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \
--nnet_decoder_chunk=8 \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$wfst/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer
cat $data/split${nj}/*/result_recognizer > $exp/${label_file}_recognizer
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer
echo "recognizer test have finished!!!"
echo "please checkout in ${exp}/${wer}.recognizer"
tail -n 7 $exp/${wer}.recognizer
fi
\ No newline at end of file
#!/bin/bash
set +x
set -e
. path.sh
nj=40
stage=0
stop_stage=5
. utils/parse_options.sh
# 1. compile
if [ ! -d ${SPEECHX_BUILD} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/fbank_model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
# output
mkdir -p exp
exp=$PWD/exp
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
aishell_wav_scp=aishell_test.scp
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
if [ ! -f $ckpt_dir/data/mean_std.json ]; then
mkdir -p $ckpt_dir
pushd $ckpt_dir
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
popd
fi
if [ ! -f $lm ]; then
pushd $data
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
popd
fi
fi
# 3. make feature
text=$data/test/text
label_file=./aishell_result_fbank
wer=./aishell_wer_fbank
export GLOG_logtostderr=1
cmvn=$data/cmvn_fbank.ark
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# 3. convert cmvn format and compute fbank feat
cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn --binary=false
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
compute_fbank_main \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank_feat.ark,$data/split${nj}/JOB/fbank_feat.scp \
--cmvn_file=$cmvn \
--streaming_chunk=36
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# decode w/o lm by feature
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \
ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--model_cache_shapes="5-1-2048,5-1-2048" \
--nnet_decoder_chunk=8 \
--dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank
cat $data/split${nj}/*/result_fbank > $exp/${label_file}
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer}
tail -n 7 $exp/${wer}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with ngram lm by feature
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \
ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--model_cache_shapes="5-1-2048,5-1-2048" \
--nnet_decoder_chunk=8 \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/fbank_result_lm
cat $data/split${nj}/*/fbank_result_lm > $exp/${label_file}_lm
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm
tail -n 7 $exp/${wer}.lm
fi
wfst=$data/wfst_fbank/
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
mkdir -p $wfst
if [ ! -f $wfst/aishell_graph2.zip ]; then
pushd $wfst
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip
unzip aishell_graph2.zip
mv aishell_graph2/* $wfst
popd
fi
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# decode w/ TLG graph by feature
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \
ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--model_cache_shapes="5-1-2048,5-1-2048" \
--nnet_decoder_chunk=8 \
--graph_path=$wfst/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg
echo "wfst-decoder-ol have finished!!!"
echo "please checkout in ${exp}/${wer}.tlg"
tail -n 7 $exp/${wer}.tlg
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# recognize w/ TLG graph by wav
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/fbank_recognizer.log \
recognizer_main \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_5.jit.pdmodel \
--use_fbank=true \
--param_path=$model_dir/avg_5.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--model_cache_shapes="5-1-2048,5-1-2048" \
--nnet_decoder_chunk=8 \
--graph_path=$wfst/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank_recognizer
cat $data/split${nj}/*/result_fbank_recognizer > $exp/${label_file}_recognizer
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer
echo "recognizer test have finished!!!"
echo "please checkout in ${exp}/${wer}.recognizer"
tail -n 7 $exp/${wer}.recognizer
fi
../../../../utils/
\ No newline at end of file
# Convert DeepSpeech2 model to ONNX format
> We recommend using U2/U2++ model instead of DS2, please see [here](../../u2pp_ol/wenetspeech/).
This example demonstrates converting the ds2 model to ONNX format.
Please make sure the [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) versions are correct.
The example was tested with these packages installed:
```
paddle2onnx 0.9.8 # develop 62c5424e22cd93968dc831216fc9e0f0fce3d819
paddleaudio 0.2.1
paddlefsl 1.1.0
paddlenlp 2.2.6
paddlepaddle-gpu 2.2.2
paddlespeech 0.0.0 # develop
paddlespeech-ctcdecoders 0.2.0
paddlespeech-feat 0.1.0
onnx 1.11.0
onnx-simplifier 0.0.0 # https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape
onnxoptimizer 0.2.7
onnxruntime 1.11.0
```
## Using
```
bash run.sh --stage 0 --stop_stage 5
```
1. convert deepspeech2 model to ONNX, using Paddle2ONNX.
2. check paddleinference and onnxruntime output equal.
3. optimize onnx model
4. check paddleinference and optimized onnxruntime output equal.
5. quantize onnx model
6. check paddleinference and optimized onnxruntime output equal.
For more details please see `run.sh`.
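As a rough sketch of what the conversion step (1. above) does, assuming the Paddle2ONNX command-line tool; the exact invocation (model paths, input shapes, opset) lives in `run.sh` and `local/`:
```bash
# convert the exported DS2 inference model (*.pdmodel/*.pdiparams) to ONNX
# (the model directory and file prefix below are placeholders)
paddle2onnx \
    --model_dir exp/model \
    --model_filename avg_1.jit.pdmodel \
    --params_filename avg_1.jit.pdiparams \
    --save_file exp/model.onnx \
    --opset_version 11
```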
## Outputs
The optimized onnx model is `exp/model.opt.onnx`; the quantized model is `exp/model.optset11.quant.onnx`.
## [Results](https://github.com/PaddlePaddle/PaddleSpeech/wiki/ASR-Benchmark#streaming-asr)
Test machine: `CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz`
Test script: `Streaming Server`
| Acoustic Model | Model Size | engine | decoding_method | ctc_weight | decoding_chunk_size | num_decoding_left_chunk | RTF |
|:-------------:| :-----: | :-----: | :------------:| :-----: | :-----: | :-----: |:-----:|
| deepspeech2online_wenetspeech | 659MB | inference | ctc_prefix_beam_search | - | 1 | - | 1.9108175171428279(utts=80) |
| deepspeech2online_wenetspeech | 659MB | onnx | ctc_prefix_beam_search | - | 1 | - | 0.5617182449999291 (utts=80) |
| deepspeech2online_wenetspeech | 166MB | onnx quant | ctc_prefix_beam_search | - | 1 | - | 0.44507715475808385 (utts=80) |
> Quantization is machine-dependent; not all machines support it. The instruction-set flags of the machine used for the ONNX quant test:
> Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology eagerfpu pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 arat umip pku ospke avx512_vnni spec_ctrl
#!/usr/bin/env python3
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import pickle
import numpy as np
import onnxruntime
import paddle
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--input_file',
type=str,
default="static_ds2online_inputs.pickle",
help="aishell ds2 input data file. For wenetspeech, we only feed for infer model",
)
parser.add_argument(
'--model_type',
type=str,
default="aishell",
help="aishell(1024) or wenetspeech(2048)", )
parser.add_argument(
'--model_dir', type=str, default=".", help="paddle model dir.")
parser.add_argument(
'--model_prefix',
type=str,
default="avg_1.jit",
help="paddle model prefix.")
parser.add_argument(
'--onnx_model',
type=str,
default='./model.old.onnx',
help="onnx model.")
return parser.parse_args()
if __name__ == '__main__':
FLAGS = parse_args()
# input and output
with open(FLAGS.input_file, 'rb') as f:
iodict = pickle.load(f)
print(iodict.keys())
audio_chunk = iodict['audio_chunk']
audio_chunk_lens = iodict['audio_chunk_lens']
chunk_state_h_box = iodict['chunk_state_h_box']
chunk_state_c_box = iodict['chunk_state_c_bos']
print("raw state shape: ", chunk_state_c_box.shape)
if FLAGS.model_type == 'wenetspeech':
chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1)
chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1)
print("state shape: ", chunk_state_c_box.shape)
# paddle
model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix))
res_chunk, res_lens, chunk_state_h, chunk_state_c = model(
paddle.to_tensor(audio_chunk),
paddle.to_tensor(audio_chunk_lens),
paddle.to_tensor(chunk_state_h_box),
paddle.to_tensor(chunk_state_c_box), )
# onnxruntime
options = onnxruntime.SessionOptions()
options.enable_profiling = True
sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
"audio_chunk": audio_chunk,
"audio_chunk_lens": audio_chunk_lens,
"chunk_state_h_box": chunk_state_h_box,
"chunk_state_c_box": chunk_state_c_box
})
print(sess.end_profiling())
# assert paddle equal ort
print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
print(np.allclose(ort_res_lens, res_lens, atol=1e-6))
if FLAGS.model_type == 'aishell':
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))
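# Example invocation, spelling out the script's default arguments
# (file names are the argparse defaults above; the real wiring is done in run.sh):
#   python3 <this_script>.py \
#       --input_file static_ds2online_inputs.pickle \
#       --model_type aishell \
#       --model_dir . --model_prefix avg_1.jit \
#       --onnx_model ./model.old.onnx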
#!/bin/bash
# show model
if [ $# != 1 ];then
echo "usage: $0 model_path"
exit 1
fi
file=$1
pip install netron
netron -p 8082 --host $(hostname -i) $file
\ No newline at end of file
#!/bin/bash
# clone onnx repos
git clone https://github.com/onnx/onnx.git
git clone https://github.com/microsoft/onnxruntime.git
git clone https://github.com/PaddlePaddle/Paddle2ONNX.git
\ No newline at end of file
#!/usr/bin/env python3
import argparse
import onnx
from onnx import version_converter
if __name__ == '__main__':
parser = argparse.ArgumentParser(prog=__doc__)
parser.add_argument(
"--model-file", type=str, required=True, help='path/to/the/model.onnx.')
parser.add_argument(
"--save-model",
type=str,
required=True,
help='path/to/saved/model.onnx.')
# Models must be opset10 or higher to be quantized.
parser.add_argument(
"--target-opset", type=int, default=11, help='path/to/the/model.onnx.')
args = parser.parse_args()
print(f"to opset: {args.target_opset}")
# Preprocessing: load the model to be converted.
model_path = args.model_file
original_model = onnx.load(model_path)
# print('The model before conversion:\n{}'.format(original_model))
# A full list of supported adapters can be found here:
# https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
# Apply the version conversion on the original model
converted_model = version_converter.convert_version(original_model,
args.target_opset)
# print('The model after conversion:\n{}'.format(converted_model))
onnx.save(converted_model, args.save_model)
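# Example (file names are hypothetical):
#   python3 <this_script>.py --model-file model.onnx --save-model model.opset11.onnx --target-opset 11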
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# flake8: noqa
import argparse
import logging
import numpy as np
import onnx
import sympy
from onnx import helper
from onnx import numpy_helper
from onnx import shape_inference
from packaging import version
assert version.parse(onnx.__version__) >= version.parse("1.8.0")
logger = logging.getLogger(__name__)
def get_attribute(node, attr_name, default_value=None):
found = [attr for attr in node.attribute if attr.name == attr_name]
if found:
return helper.get_attribute_value(found[0])
return default_value
def get_dim_from_proto(dim):
return getattr(dim, dim.WhichOneof('value')) if type(
dim.WhichOneof('value')) == str else None
def is_sequence(type_proto):
cls_type = type_proto.WhichOneof('value')
assert cls_type in ['tensor_type', 'sequence_type']
return cls_type == 'sequence_type'
def get_shape_from_type_proto(type_proto):
assert not is_sequence(type_proto)
if type_proto.tensor_type.HasField('shape'):
return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim]
else:
return None # note no shape is different from shape without dim (scalar)
def get_shape_from_value_info(vi):
cls_type = vi.type.WhichOneof('value')
if cls_type is None:
return None
if is_sequence(vi.type):
if 'tensor_type' == vi.type.sequence_type.elem_type.WhichOneof('value'):
return get_shape_from_type_proto(vi.type.sequence_type.elem_type)
else:
return None
else:
return get_shape_from_type_proto(vi.type)
def make_named_value_info(name):
vi = onnx.ValueInfoProto()
vi.name = name
return vi
def get_shape_from_sympy_shape(sympy_shape):
return [
None if i is None else (int(i) if is_literal(i) else str(i))
for i in sympy_shape
]
def is_literal(dim):
return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(
dim, 'is_number') and dim.is_number)
def handle_negative_axis(axis, rank):
assert axis < rank and axis >= -rank
return axis if axis >= 0 else rank + axis
def get_opset(mp, domain=None):
domain = domain or ['', 'onnx', 'ai.onnx']
if type(domain) != list:
domain = [domain]
for opset in mp.opset_import:
if opset.domain in domain:
return opset.version
return None
def as_scalar(x):
if type(x) == list:
assert len(x) == 1
return x[0]
elif type(x) == np.ndarray:
return x.item()
else:
return x
def as_list(x, keep_none):
if type(x) == list:
return x
elif type(x) == np.ndarray:
return list(x)
elif keep_none and x is None:
return None
else:
return [x]
def sympy_reduce_product(x):
if type(x) == list:
value = sympy.Integer(1)
for v in x:
value = value * v
else:
value = x
return value
class SymbolicShapeInference:
def __init__(self,
int_max,
auto_merge,
guess_output_rank,
verbose,
prefix=''):
self.dispatcher_ = {
'Add':
self._infer_symbolic_compute_ops,
'ArrayFeatureExtractor':
self._infer_ArrayFeatureExtractor,
'AveragePool':
self._infer_Pool,
'BatchNormalization':
self._infer_BatchNormalization,
'Cast':
self._infer_Cast,
'CategoryMapper':
self._infer_CategoryMapper,
'Compress':
self._infer_Compress,
'Concat':
self._infer_Concat,
'ConcatFromSequence':
self._infer_ConcatFromSequence,
'Constant':
self._infer_Constant,
'ConstantOfShape':
self._infer_ConstantOfShape,
'Conv':
self._infer_Conv,
'CumSum':
self._pass_on_shape_and_type,
'Div':
self._infer_symbolic_compute_ops,
'Einsum':
self._infer_Einsum,
'Expand':
self._infer_Expand,
'Equal':
self._infer_symbolic_compute_ops,
'Floor':
self._infer_symbolic_compute_ops,
'Gather':
self._infer_Gather,
'GatherElements':
self._infer_GatherElements,
'GatherND':
self._infer_GatherND,
'Gelu':
self._pass_on_shape_and_type,
'If':
self._infer_If,
'Loop':
self._infer_Loop,
'MatMul':
self._infer_MatMul,
'MatMulInteger16':
self._infer_MatMulInteger,
'MaxPool':
self._infer_Pool,
'Max':
self._infer_symbolic_compute_ops,
'Min':
self._infer_symbolic_compute_ops,
'Mul':
self._infer_symbolic_compute_ops,
'NonMaxSuppression':
self._infer_NonMaxSuppression,
'NonZero':
self._infer_NonZero,
'OneHot':
self._infer_OneHot,
'Pad':
self._infer_Pad,
'Range':
self._infer_Range,
'Reciprocal':
self._pass_on_shape_and_type,
'ReduceSum':
self._infer_ReduceSum,
'ReduceProd':
self._infer_ReduceProd,
'Reshape':
self._infer_Reshape,
'Resize':
self._infer_Resize,
'Round':
self._pass_on_shape_and_type,
'Scan':
self._infer_Scan,
'ScatterElements':
self._infer_ScatterElements,
'SequenceAt':
self._infer_SequenceAt,
'SequenceInsert':
self._infer_SequenceInsert,
'Shape':
self._infer_Shape,
'Size':
self._infer_Size,
'Slice':
self._infer_Slice,
'SoftmaxCrossEntropyLoss':
self._infer_SoftmaxCrossEntropyLoss,
'SoftmaxCrossEntropyLossInternal':
self._infer_SoftmaxCrossEntropyLoss,
'NegativeLogLikelihoodLossInternal':
self._infer_SoftmaxCrossEntropyLoss,
'Split':
self._infer_Split,
'SplitToSequence':
self._infer_SplitToSequence,
'Squeeze':
self._infer_Squeeze,
'Sub':
self._infer_symbolic_compute_ops,
'Tile':
self._infer_Tile,
'TopK':
self._infer_TopK,
'Transpose':
self._infer_Transpose,
'Unsqueeze':
self._infer_Unsqueeze,
'Where':
self._infer_symbolic_compute_ops,
'ZipMap':
self._infer_ZipMap,
'Neg':
self._infer_symbolic_compute_ops,
# contrib ops:
'Attention':
self._infer_Attention,
'BiasGelu':
self._infer_BiasGelu,
'EmbedLayerNormalization':
self._infer_EmbedLayerNormalization,
'FastGelu':
self._infer_FastGelu,
'Gelu':
self._infer_Gelu,
'LayerNormalization':
self._infer_LayerNormalization,
'LongformerAttention':
self._infer_LongformerAttention,
'PythonOp':
self._infer_PythonOp,
'SkipLayerNormalization':
self._infer_SkipLayerNormalization
}
self.aten_op_dispatcher_ = {
'aten::embedding': self._infer_Gather,
'aten::bitwise_or': self._infer_aten_bitwise_or,
'aten::diagonal': self._infer_aten_diagonal,
'aten::max_pool2d_with_indices': self._infer_aten_pool2d,
'aten::multinomial': self._infer_aten_multinomial,
'aten::unfold': self._infer_aten_unfold,
'aten::argmax': self._infer_aten_argmax,
'aten::avg_pool2d': self._infer_aten_pool2d,
'aten::_adaptive_avg_pool2d': self._infer_aten_pool2d,
'aten::binary_cross_entropy_with_logits': self._infer_aten_bce,
'aten::numpy_T': self._infer_Transpose,
}
self.run_ = True
self.suggested_merge_ = {}
self.symbolic_dims_ = {}
self.input_symbols_ = {}
self.auto_merge_ = auto_merge
self.guess_output_rank_ = guess_output_rank
self.verbose_ = verbose
self.int_max_ = int_max
self.subgraph_id_ = 0
self.prefix_ = prefix
def _add_suggested_merge(self, symbols, apply=False):
assert all([(type(s) == str and s in self.symbolic_dims_) or
is_literal(s) for s in symbols])
symbols = set(symbols)
for k, v in self.suggested_merge_.items():
if k in symbols:
symbols.remove(k)
symbols.add(v)
map_to = None
# if there is literal, map to it first
for s in symbols:
if is_literal(s):
map_to = s
break
# when no literals, map to input symbolic dims, then existing symbolic dims
if map_to is None:
for s in symbols:
if s in self.input_symbols_:
map_to = s
break
if map_to is None:
for s in symbols:
if type(self.symbolic_dims_[s]) == sympy.Symbol:
map_to = s
break
# when nothing to map to, use the shorter one
if map_to is None:
if self.verbose_ > 0:
logger.warning(
'Potential unsafe merge between symbolic expressions: ({})'.
format(','.join(symbols)))
symbols_list = list(symbols)
lens = [len(s) for s in symbols_list]
map_to = symbols_list[lens.index(min(lens))]
symbols.remove(map_to)
for s in symbols:
if s == map_to:
continue
if is_literal(map_to) and is_literal(s):
assert int(map_to) == int(s)
self.suggested_merge_[s] = int(map_to) if is_literal(
map_to) else map_to
for k, v in self.suggested_merge_.items():
if v == s:
self.suggested_merge_[k] = map_to
if apply and self.auto_merge_:
self._apply_suggested_merge()
def _apply_suggested_merge(self, graph_input_only=False):
if not self.suggested_merge_:
return
for i in list(self.out_mp_.graph.input) + (
[] if graph_input_only else list(self.out_mp_.graph.value_info)):
for d in i.type.tensor_type.shape.dim:
if d.dim_param in self.suggested_merge_:
v = self.suggested_merge_[d.dim_param]
if is_literal(v):
d.dim_value = int(v)
else:
d.dim_param = v
def _preprocess(self, in_mp):
self.out_mp_ = onnx.ModelProto()
self.out_mp_.CopyFrom(in_mp)
self.graph_inputs_ = dict(
[(i.name, i) for i in list(self.out_mp_.graph.input)])
self.initializers_ = dict(
[(i.name, i) for i in self.out_mp_.graph.initializer])
self.known_vi_ = dict(
[(i.name, i) for i in list(self.out_mp_.graph.input)])
self.known_vi_.update(
dict([(i.name, helper.make_tensor_value_info(i.name, i.data_type,
list(i.dims)))
for i in self.out_mp_.graph.initializer]))
def _merge_symbols(self, dims):
if not all([type(d) == str for d in dims]):
if self.auto_merge_:
unique_dims = list(set(dims))
is_int = [is_literal(d) for d in unique_dims]
assert sum(
is_int
) <= 1 # if there are more than 1 unique ints, something is wrong
if sum(is_int) == 1:
int_dim = is_int.index(1)
if self.verbose_ > 0:
logger.debug('dim {} has been merged with value {}'.
format(unique_dims[:int_dim] + unique_dims[
int_dim + 1:], unique_dims[int_dim]))
self._check_merged_dims(unique_dims, allow_broadcast=False)
return unique_dims[int_dim]
else:
if self.verbose_ > 0:
logger.debug('dim {} has been merged with dim {}'.format(
unique_dims[1:], unique_dims[0]))
return dims[0]
else:
return None
if all([d == dims[0] for d in dims]):
return dims[0]
merged = [
self.suggested_merge_[d] if d in self.suggested_merge_ else d
for d in dims
]
if all([d == merged[0] for d in merged]):
assert merged[0] in self.symbolic_dims_
return merged[0]
else:
return None
# broadcast from right to left, and merge symbolic dims if needed
def _broadcast_shapes(self, shape1, shape2):
new_shape = []
rank1 = len(shape1)
rank2 = len(shape2)
new_rank = max(rank1, rank2)
for i in range(new_rank):
dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
if dim1 == 1 or dim1 == dim2:
new_dim = dim2
elif dim2 == 1:
new_dim = dim1
else:
new_dim = self._merge_symbols([dim1, dim2])
if not new_dim:
# warning about unsupported broadcast when not auto merge
# note that auto merge has the risk of incorrectly merge symbols while one of them being 1
# for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
if self.auto_merge_:
self._add_suggested_merge([dim1, dim2], apply=True)
else:
logger.warning('unsupported broadcast between ' + str(
dim1) + ' ' + str(dim2))
new_shape = [new_dim] + new_shape
return new_shape
def _get_shape(self, node, idx):
name = node.input[idx]
if name in self.known_vi_:
vi = self.known_vi_[name]
return get_shape_from_value_info(vi)
else:
assert name in self.initializers_
return list(self.initializers_[name].dims)
def _get_shape_rank(self, node, idx):
return len(self._get_shape(node, idx))
def _get_sympy_shape(self, node, idx):
sympy_shape = []
for d in self._get_shape(node, idx):
if type(d) == str:
sympy_shape.append(self.symbolic_dims_[d] if d in
self.symbolic_dims_ else sympy.Symbol(
d, integer=True, nonnegative=True))
else:
assert None != d
sympy_shape.append(d)
return sympy_shape
def _get_value(self, node, idx):
name = node.input[idx]
assert name in self.sympy_data_ or name in self.initializers_
return self.sympy_data_[
name] if name in self.sympy_data_ else numpy_helper.to_array(
self.initializers_[name])
def _try_get_value(self, node, idx):
if idx >= len(node.input):
return None
name = node.input[idx]
if name in self.sympy_data_ or name in self.initializers_:
return self._get_value(node, idx)
return None
def _update_computed_dims(self, new_sympy_shape):
for i, new_dim in enumerate(new_sympy_shape):
if not is_literal(new_dim) and not type(new_dim) == str:
str_dim = str(new_dim)
if str_dim in self.suggested_merge_:
if is_literal(self.suggested_merge_[str_dim]):
continue # no need to create dim for literals
new_sympy_shape[i] = self.symbolic_dims_[
self.suggested_merge_[str_dim]]
else:
# add new_dim if it's a computational expression
if not str(new_dim) in self.symbolic_dims_:
self.symbolic_dims_[str(new_dim)] = new_dim
def _onnx_infer_single_node(self, node):
# skip onnx shape inference for some ops, as they are handled in _infer_*
skip_infer = node.op_type in [
'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', 'Attention',
'BiasGelu', 'EmbedLayerNormalization', 'FastGelu', 'Gelu',
'LayerNormalization', 'LongformerAttention',
'SkipLayerNormalization', 'PythonOp'
]
if not skip_infer:
# Only pass initializers that satisfy the following condition:
# (1) Operator need value of some input for shape inference.
# For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output.
# (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec.
# (3) The initializer is not in graph input. The means the node input is "constant" in inference.
initializers = []
if (get_opset(self.out_mp_) >= 9) and node.op_type in ['Unsqueeze']:
initializers = [
self.initializers_[name] for name in node.input
if (name in self.initializers_ and name not in
self.graph_inputs_)
]
# run single node inference with self.known_vi_ shapes
tmp_graph = helper.make_graph(
[node], 'tmp', [self.known_vi_[i] for i in node.input if i],
[make_named_value_info(i) for i in node.output], initializers)
self.tmp_mp_.graph.CopyFrom(tmp_graph)
self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
for i_o in range(len(node.output)):
o = node.output[i_o]
vi = self.out_mp_.graph.value_info.add()
if not skip_infer:
vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
else:
vi.name = o
self.known_vi_[o] = vi
def _onnx_infer_subgraph(self,
node,
subgraph,
use_node_input=True,
inc_subgraph_id=True):
if self.verbose_ > 2:
logger.debug(
'Inferencing subgraph of node {} with output({}...): {}'.format(
node.name, node.output[0], node.op_type))
# node inputs are not passed directly to the subgraph
# it's up to the node dispatcher to prepare subgraph input
# for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape
# besides, inputs in subgraph could shadow implicit inputs
subgraph_inputs = set(
[i.name for i in list(subgraph.initializer) + list(subgraph.input)])
subgraph_implicit_input = set([
name for name in self.known_vi_.keys()
if not name in subgraph_inputs
])
tmp_graph = helper.make_graph(
list(subgraph.node), 'tmp',
list(subgraph.input) +
[self.known_vi_[i] for i in subgraph_implicit_input],
[make_named_value_info(i.name) for i in subgraph.output])
tmp_graph.initializer.extend([
i for i in self.out_mp_.graph.initializer
if i.name in subgraph_implicit_input
])
tmp_graph.initializer.extend(subgraph.initializer)
self.tmp_mp_.graph.CopyFrom(tmp_graph)
symbolic_shape_inference = SymbolicShapeInference(
self.int_max_,
self.auto_merge_,
self.guess_output_rank_,
self.verbose_,
prefix=self.prefix_ + '_' + str(self.subgraph_id_))
if inc_subgraph_id:
self.subgraph_id_ += 1
all_shapes_inferred = False
symbolic_shape_inference._preprocess(self.tmp_mp_)
symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl(
self.sympy_data_.copy())
symbolic_shape_inference._update_output_from_vi()
if use_node_input:
# if subgraph uses node input, it needs to update to merged dims
subgraph.ClearField('input')
subgraph.input.extend(
symbolic_shape_inference.out_mp_.graph.input[:len(node.input)])
subgraph.ClearField('output')
subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output)
subgraph.ClearField('value_info')
subgraph.value_info.extend(
symbolic_shape_inference.out_mp_.graph.value_info)
subgraph.ClearField('node')
subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node)
# for new symbolic dims from subgraph output, add to main graph symbolic dims
subgraph_shapes = [
get_shape_from_value_info(o)
for o in symbolic_shape_inference.out_mp_.graph.output
]
subgraph_new_symbolic_dims = set([
d for s in subgraph_shapes
if s for d in s if type(d) == str and not d in self.symbolic_dims_
])
new_dims = {}
for d in subgraph_new_symbolic_dims:
assert d in symbolic_shape_inference.symbolic_dims_
new_dims[d] = symbolic_shape_inference.symbolic_dims_[d]
self.symbolic_dims_.update(new_dims)
return symbolic_shape_inference
def _get_int_values(self, node, broadcast=False):
values = [self._try_get_value(node, i) for i in range(len(node.input))]
if all([v is not None for v in values]):
# some shape compute is in floating point, cast to int for sympy
for i, v in enumerate(values):
if type(v) != np.ndarray:
continue
if len(v.shape) > 1:
new_v = None # ignore value for rank > 1
elif len(v.shape) == 0:
new_v = int(v.item())
else:
assert len(v.shape) == 1
new_v = [int(vv) for vv in v]
values[i] = new_v
values_len = [len(v) if type(v) == list else 0 for v in values]
max_len = max(values_len)
if max_len >= 1 and broadcast:
# broadcast
for i, v in enumerate(values):
if v is None:
continue # don't broadcast if value is unknown
if type(v) == list:
if len(v) < max_len:
values[i] = v * max_len
else:
assert len(v) == max_len
else:
values[i] = [v] * max_len
return values
def _compute_on_sympy_data(self, node, op_func):
assert len(node.output) == 1
values = self._get_int_values(node, broadcast=True)
if all([v is not None for v in values]):
is_list = [type(v) == list for v in values]
as_list = any(is_list)
if as_list:
self.sympy_data_[node.output[
0]] = [op_func(vs) for vs in zip(*values)]
else:
self.sympy_data_[node.output[0]] = op_func(values)
def _pass_on_sympy_data(self, node):
assert len(
node.
input) == 1 or node.op_type in ['Reshape', 'Unsqueeze', 'Squeeze']
self._compute_on_sympy_data(node, lambda x: x[0])
def _pass_on_shape_and_type(self, node):
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
self._get_shape(node, 0)))
def _new_symbolic_dim(self, prefix, dim):
new_dim = '{}_d{}'.format(prefix, dim)
if new_dim in self.suggested_merge_:
v = self.suggested_merge_[new_dim]
new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v
else:
new_symbolic_dim = sympy.Symbol(
new_dim, integer=True, nonnegative=True)
self.symbolic_dims_[new_dim] = new_symbolic_dim
return new_symbolic_dim
def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
return self._new_symbolic_dim('{}{}_{}_o{}_'.format(
node.op_type, self.prefix_,
list(self.out_mp_.graph.node).index(node), out_idx), dim)
def _new_symbolic_shape(self, rank, node, out_idx=0):
return [
self._new_symbolic_dim_from_output(node, out_idx, i)
for i in range(rank)
]
def _compute_conv_pool_shape(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
if len(node.input) > 1:
W_shape = self._get_sympy_shape(node, 1)
rank = len(W_shape) - 2 # number of spatial axes
kernel_shape = W_shape[-rank:]
sympy_shape[1] = W_shape[0]
else:
W_shape = None
kernel_shape = get_attribute(node, 'kernel_shape')
rank = len(kernel_shape)
assert len(sympy_shape) == rank + 2
# only need to symbolic shape inference if input has symbolic dims in spatial axes
is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]]
if not any(is_symbolic_dims):
shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
if len(shape) > 0:
assert len(sympy_shape) == len(shape)
sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
return sympy_shape
dilations = get_attribute(node, 'dilations', [1] * rank)
strides = get_attribute(node, 'strides', [1] * rank)
effective_kernel_shape = [(k - 1) * d + 1
for k, d in zip(kernel_shape, dilations)]
pads = get_attribute(node, 'pads')
if pads is None:
pads = [0] * (2 * rank)
auto_pad = get_attribute(node, 'auto_pad',
b'NOTSET').decode('utf-8')
if auto_pad != 'VALID' and auto_pad != 'NOTSET':
try:
residual = [
sympy.Mod(d, s)
for d, s in zip(sympy_shape[-rank:], strides)
]
total_pads = [
max(0, (k - s) if r == 0 else (k - r))
for k, s, r in zip(effective_kernel_shape, strides,
residual)
]
except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational
total_pads = [
max(0, (k - s))
for k, s in zip(effective_kernel_shape, strides)
] # assuming no residual if sympy throws error
elif auto_pad == 'VALID':
total_pads = []
else:
total_pads = [0] * rank
else:
assert len(pads) == 2 * rank
total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])]
ceil_mode = get_attribute(node, 'ceil_mode', 0)
for i in range(rank):
effective_input_size = sympy_shape[-rank + i]
if len(total_pads) > 0:
effective_input_size = effective_input_size + total_pads[i]
if ceil_mode:
strided_kernel_positions = sympy.ceiling(
(effective_input_size - effective_kernel_shape[i]) /
strides[i])
else:
strided_kernel_positions = (
effective_input_size - effective_kernel_shape[i]
) // strides[i]
sympy_shape[-rank + i] = strided_kernel_positions + 1
return sympy_shape
def _check_merged_dims(self, dims, allow_broadcast=True):
if allow_broadcast:
dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
if not all([d == dims[0] for d in dims]):
self._add_suggested_merge(dims, apply=True)
def _compute_matmul_shape(self, node, output_dtype=None):
lhs_shape = self._get_shape(node, 0)
rhs_shape = self._get_shape(node, 1)
lhs_rank = len(lhs_shape)
rhs_rank = len(rhs_shape)
lhs_reduce_dim = 0
rhs_reduce_dim = 0
assert lhs_rank > 0 and rhs_rank > 0
if lhs_rank == 1 and rhs_rank == 1:
new_shape = []
elif lhs_rank == 1:
rhs_reduce_dim = -2
new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]]
elif rhs_rank == 1:
lhs_reduce_dim = -1
new_shape = lhs_shape[:lhs_reduce_dim]
else:
lhs_reduce_dim = -1
rhs_reduce_dim = -2
new_shape = self._broadcast_shapes(
lhs_shape[:-2],
rhs_shape[:-2]) + [lhs_shape[-2]] + [rhs_shape[-1]]
# merge reduce dim
self._check_merged_dims(
[lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
allow_broadcast=False)
if output_dtype is None:
# infer output_dtype from input type when not specified
output_dtype = self.known_vi_[node.input[
0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], output_dtype,
new_shape))
def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
'''
update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches
'''
dst_tensor_type = dst_type.sequence_type.elem_type.tensor_type if is_sequence(
dst_type) else dst_type.tensor_type
src_tensor_type = src_type.sequence_type.elem_type.tensor_type if is_sequence(
src_type) else src_type.tensor_type
if dst_tensor_type.elem_type != src_tensor_type.elem_type:
node_id = node.name if node.name else node.op_type
raise ValueError(
f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
)
if dst_tensor_type.HasField('shape'):
for di, ds in enumerate(
zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)):
if ds[0] != ds[1]:
# create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type
# for sequence_type, clear the dimension
new_dim = onnx.TensorShapeProto.Dimension()
if not is_sequence(dst_type):
new_dim.dim_param = str(
self._new_symbolic_dim_from_output(node, out_idx,
di))
dst_tensor_type.shape.dim[di].CopyFrom(new_dim)
else:
dst_tensor_type.CopyFrom(src_tensor_type)
def _infer_ArrayFeatureExtractor(self, node):
data_shape = self._get_shape(node, 0)
indices_shape = self._get_shape(node, 1)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, data_shape[:-1] +
indices_shape))
def _infer_symbolic_compute_ops(self, node):
funcs = {
'Add':
lambda l: l[0] + l[1],
'Div':
lambda l: l[0] // l[1], # integer div in sympy
'Equal':
lambda l: l[0] == l[1],
'Floor':
lambda l: sympy.floor(l[0]),
'Max':
lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
'Min':
lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])),
'Mul':
lambda l: l[0] * l[1],
'Sub':
lambda l: l[0] - l[1],
'Where':
lambda l: l[1] if l[0] else l[2],
'Neg':
lambda l: -l[0]
}
assert node.op_type in funcs
self._compute_on_sympy_data(node, funcs[node.op_type])
def _infer_Cast(self, node):
self._pass_on_sympy_data(node)
def _infer_CategoryMapper(self, node):
input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
if input_type == onnx.TensorProto.STRING:
output_type = onnx.TensorProto.INT64
else:
output_type = onnx.TensorProto.STRING
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], output_type,
self._get_shape(node, 0)))
def _infer_Compress(self, node):
input_shape = self._get_shape(node, 0)
# create a new symbolic dimension for Compress output
compress_len = str(self._new_symbolic_dim_from_output(node))
axis = get_attribute(node, 'axis')
        if axis is None:
# when axis is not specified, input is flattened before compress so output is 1D
output_shape = [compress_len]
else:
output_shape = input_shape
output_shape[handle_negative_axis(axis, len(
input_shape))] = compress_len
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
def _infer_Concat(self, node):
if any([
i in self.sympy_data_ or i in self.initializers_
for i in node.input
]):
values = self._get_int_values(node)
            logger.debug("Concat input values: {} node: {} axis: {}".format(
                values, node.name, get_attribute(node, 'axis')))
if all([v is not None for v in values]):
axis = get_attribute(node, 'axis')
if axis < 0:
axis = axis + len(values[0])
assert 0 == axis
self.sympy_data_[node.output[0]] = []
for i in range(len(node.input)):
value = values[i]
if type(value) == list:
self.sympy_data_[node.output[0]].extend(value)
else:
self.sympy_data_[node.output[0]].append(value)
sympy_shape = self._get_sympy_shape(node, 0)
axis = handle_negative_axis(
get_attribute(node, 'axis'), len(sympy_shape))
for i_idx in range(1, len(node.input)):
input_shape = self._get_sympy_shape(node, i_idx)
if input_shape:
sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
self._update_computed_dims(sympy_shape)
# merge symbolic dims for non-concat axes
for d in range(len(sympy_shape)):
if d == axis:
continue
dims = [
self._get_shape(node, i_idx)[d]
for i_idx in range(len(node.input))
if self._get_shape(node, i_idx)
]
if all([d == dims[0] for d in dims]):
continue
merged = self._merge_symbols(dims)
if type(merged) == str:
sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
else:
sympy_shape[d] = merged
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[0]].type.tensor_type.
elem_type, get_shape_from_sympy_shape(sympy_shape)))
def _infer_ConcatFromSequence(self, node):
seq_shape = self._get_shape(node, 0)
new_axis = 1 if get_attribute(node, 'new_axis') else 0
axis = handle_negative_axis(
get_attribute(node, 'axis'), len(seq_shape) + new_axis)
concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis))
new_shape = seq_shape
if new_axis:
new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:]
else:
new_shape[axis] = concat_dim
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[0]]
.type.sequence_type.elem_type.tensor_type.elem_type, new_shape))
def _infer_Constant(self, node):
t = get_attribute(node, 'value')
self.sympy_data_[node.output[0]] = numpy_helper.to_array(t)
def _infer_ConstantOfShape(self, node):
sympy_shape = self._get_int_values(node)[0]
vi = self.known_vi_[node.output[0]]
if sympy_shape is not None:
if type(sympy_shape) != list:
sympy_shape = [sympy_shape]
self._update_computed_dims(sympy_shape)
# update sympy data if output type is int, and shape is known
if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all(
[is_literal(x) for x in sympy_shape]):
self.sympy_data_[node.output[0]] = np.ones(
[int(x) for x in sympy_shape],
dtype=np.int64) * numpy_helper.to_array(
get_attribute(node, 'value', 0))
else:
# create new dynamic shape
            # note that input 0 is a 1-D shape vector; the new symbolic shape has rank equal to that vector's length
sympy_shape = self._new_symbolic_shape(
self._get_shape(node, 0)[0], node)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(sympy_shape)))
def _infer_Conv(self, node):
sympy_shape = self._compute_conv_pool_shape(node)
self._update_computed_dims(sympy_shape)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(sympy_shape)))
def _infer_Einsum(self, node):
# ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
equation = get_attribute(node, 'equation')
equation = equation.replace(b' ', b'')
mid_index = equation.find(b'->')
left_equation = equation[:mid_index] if mid_index != -1 else equation
num_operands = 0
num_ellipsis = 0
num_ellipsis_indices = 0
letter_to_dim = {}
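        # map each subscript letter to its dimension; '...' stands for the shared leading (broadcast) dims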
terms = left_equation.split(b',')
for term in terms:
ellipsis_index = term.find(b'...')
shape = self._get_shape(node, num_operands)
rank = len(shape)
if ellipsis_index != -1:
if num_ellipsis == 0:
num_ellipsis_indices = rank - len(term) + 3
num_ellipsis = num_ellipsis + 1
for i in range(1, rank + 1):
letter = term[-i]
if letter != 46: # letter != b'.'
dim = shape[-i]
if letter not in letter_to_dim.keys():
letter_to_dim[letter] = dim
elif type(dim) != sympy.Symbol:
letter_to_dim[letter] = dim
num_operands = num_operands + 1
new_sympy_shape = []
from collections import OrderedDict
num_letter_occurrences = OrderedDict()
if mid_index != -1:
right_equation = equation[mid_index + 2:]
right_ellipsis_index = right_equation.find(b'...')
if right_ellipsis_index != -1:
for i in range(num_ellipsis_indices):
new_sympy_shape.append(shape[i])
for c in right_equation:
if c != 46: # c != b'.'
new_sympy_shape.append(letter_to_dim[c])
else:
for i in range(num_ellipsis_indices):
new_sympy_shape.append(shape[i])
for c in left_equation:
if c != 44 and c != 46: # c != b',' and c != b'.':
if c in num_letter_occurrences:
num_letter_occurrences[c] = num_letter_occurrences[
c] + 1
else:
num_letter_occurrences[c] = 1
for key, value in num_letter_occurrences.items():
if value == 1:
new_sympy_shape.append(letter_to_dim[key])
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], output_dtype,
new_sympy_shape))
def _infer_Expand(self, node):
expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True)
if expand_to_shape is not None:
# new_shape's dim can come from shape value
self._update_computed_dims(expand_to_shape)
shape = self._get_shape(node, 0)
new_shape = self._broadcast_shapes(
shape, get_shape_from_sympy_shape(expand_to_shape))
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, new_shape))
def _infer_Gather(self, node):
data_shape = self._get_shape(node, 0)
axis = handle_negative_axis(
get_attribute(node, 'axis', 0), len(data_shape))
indices_shape = self._get_shape(node, 1)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, data_shape[:axis] +
indices_shape + data_shape[axis +
1:]))
# for 1D input, do some sympy compute
if node.input[0] in self.sympy_data_ and len(
data_shape) == 1 and 0 == get_attribute(node, 'axis', 0):
idx = self._try_get_value(node, 1)
if idx is not None:
data = self.sympy_data_[node.input[0]]
if type(data) == list:
if type(idx) == np.ndarray and len(idx.shape) == 1:
self.sympy_data_[node.output[
0]] = [data[int(i)] for i in idx]
else:
self.sympy_data_[node.output[0]] = data[int(idx)]
else:
assert idx == 0 or idx == -1
self.sympy_data_[node.output[0]] = data
def _infer_GatherElements(self, node):
indices_shape = self._get_shape(node, 1)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, indices_shape))
def _infer_GatherND(self, node):
data_shape = self._get_shape(node, 0)
data_rank = len(data_shape)
indices_shape = self._get_shape(node, 1)
indices_rank = len(indices_shape)
last_index_dimension = indices_shape[-1]
assert is_literal(
last_index_dimension) and last_index_dimension <= data_rank
new_shape = indices_shape[:-1] + data_shape[last_index_dimension:]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, new_shape))
def _infer_If(self, node):
# special case for constant condition, in case there are mismatching shape from the non-executed branch
subgraphs = [
get_attribute(node, 'then_branch'), get_attribute(node,
'else_branch')
]
cond = self._try_get_value(node, 0)
if cond is not None:
if as_scalar(cond) > 0:
subgraphs[1].CopyFrom(subgraphs[0])
else:
subgraphs[0].CopyFrom(subgraphs[1])
for i_sub, subgraph in enumerate(subgraphs):
subgraph_infer = self._onnx_infer_subgraph(
node, subgraph, use_node_input=False)
for i_out in range(len(node.output)):
vi = self.known_vi_[node.output[i_out]]
if i_sub == 0:
vi.CopyFrom(subgraph.output[i_out])
vi.name = node.output[i_out]
else:
self._fuse_tensor_type(node, i_out, vi.type,
subgraph.output[i_out].type)
# pass on sympy data from subgraph, if cond is constant
if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else
1):
if subgraph.output[
i_out].name in subgraph_infer.sympy_data_:
self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[
subgraph.output[i_out].name]
def _infer_Loop(self, node):
subgraph = get_attribute(node, 'body')
assert len(subgraph.input) == len(node.input)
num_loop_carried = len(
node.input) - 2 # minus the length and initial loop condition
# when sequence_type is used as loop carried input
# needs to run subgraph infer twice if the tensor shape in sequence contains None
for i, si in enumerate(subgraph.input):
si_name = si.name
si.CopyFrom(self.known_vi_[node.input[i]])
si.name = si_name
self._onnx_infer_subgraph(node, subgraph)
# check subgraph input/output for shape changes in loop carried variables
# for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a)
# for sequence_type, propagate from output to input
need_second_infer = False
for i_out in range(1, num_loop_carried + 1):
so = subgraph.output[i_out]
so_shape = get_shape_from_value_info(so)
if is_sequence(so.type):
if so_shape and None in so_shape:
# copy shape from output to input
# note that loop input is [loop_len, cond, input_0, input_1, ...]
# while loop output is [cond, output_0, output_1, ...]
subgraph.input[i_out +
1].type.sequence_type.elem_type.CopyFrom(
so.type.sequence_type.elem_type)
need_second_infer = True
else:
si = subgraph.input[i_out + 1]
si_shape = get_shape_from_value_info(si)
for di, dims in enumerate(zip(si_shape, so_shape)):
if dims[0] != dims[1]:
new_dim = onnx.TensorShapeProto.Dimension()
new_dim.dim_param = str(
self._new_symbolic_dim_from_output(node, i_out, di))
si.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
so.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
need_second_infer = True
if need_second_infer:
if self.verbose_ > 2:
logger.debug(
"Rerun Loop: {}({}...), because of sequence in loop carried variables".
format(node.name, node.output[0]))
self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False)
# create a new symbolic dimension for iteration dependent dimension
loop_iter_dim = str(self._new_symbolic_dim_from_output(node))
for i in range(len(node.output)):
vi = self.known_vi_[node.output[i]]
vi.CopyFrom(subgraph.output[
i +
1]) # first subgraph output is condition, not in node output
if i >= num_loop_carried:
assert not is_sequence(
vi.type) # TODO: handle loop accumulation in sequence_type
subgraph_vi_dim = subgraph.output[i +
1].type.tensor_type.shape.dim
vi.type.tensor_type.shape.ClearField('dim')
vi_dim = vi.type.tensor_type.shape.dim
vi_dim.add().dim_param = loop_iter_dim
vi_dim.extend(list(subgraph_vi_dim))
vi.name = node.output[i]
def _infer_MatMul(self, node):
self._compute_matmul_shape(node)
def _infer_MatMulInteger(self, node):
self._compute_matmul_shape(node, onnx.TensorProto.INT32)
def _infer_NonMaxSuppression(self, node):
selected = str(self._new_symbolic_dim_from_output(node))
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], onnx.TensorProto.INT64, [selected, 3]))
def _infer_NonZero(self, node):
input_rank = self._get_shape_rank(node, 0)
# create a new symbolic dimension for NonZero output
nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1))
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], vi.type.tensor_type.elem_type, [input_rank, nz_len]))
def _infer_OneHot(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
depth = self._try_get_value(node, 1)
axis = get_attribute(node, 'axis', -1)
axis = handle_negative_axis(axis, len(sympy_shape) + 1)
new_shape = get_shape_from_sympy_shape(sympy_shape[:axis] + [
self._new_symbolic_dim_from_output(node)
if not is_literal(depth) else depth
] + sympy_shape[axis:])
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[2]].type.tensor_type.elem_type, new_shape))
def _infer_Pad(self, node):
if get_opset(self.out_mp_) <= 10:
pads = get_attribute(node, 'pads')
else:
pads = self._try_get_value(node, 1)
sympy_shape = self._get_sympy_shape(node, 0)
rank = len(sympy_shape)
if pads is not None:
assert len(pads) == 2 * rank
new_sympy_shape = [
d + pad_up + pad_down
for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[
rank:])
]
self._update_computed_dims(new_sympy_shape)
else:
# dynamic pads, create new symbolic dimensions
new_sympy_shape = self._new_symbolic_shape(rank, node)
output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], output_tp, get_shape_from_sympy_shape(new_sympy_shape)))
def _infer_Pool(self, node):
sympy_shape = self._compute_conv_pool_shape(node)
self._update_computed_dims(sympy_shape)
for o in node.output:
if not o:
continue
vi = self.known_vi_[o]
vi.CopyFrom(
helper.make_tensor_value_info(o, vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(
sympy_shape)))
def _infer_aten_bitwise_or(self, node):
shape0 = self._get_shape(node, 0)
shape1 = self._get_shape(node, 1)
new_shape = self._broadcast_shapes(shape0, shape1)
t0 = self.known_vi_[node.input[0]]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], t0.type.tensor_type.elem_type, new_shape))
def _infer_aten_diagonal(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
rank = len(sympy_shape)
offset = self._try_get_value(node, 1)
dim1 = self._try_get_value(node, 2)
dim2 = self._try_get_value(node, 3)
assert offset is not None and dim1 is not None and dim2 is not None
dim1 = handle_negative_axis(dim1, rank)
dim2 = handle_negative_axis(dim2, rank)
new_shape = []
for dim, val in enumerate(sympy_shape):
if dim not in [dim1, dim2]:
new_shape.append(val)
shape1 = sympy_shape[dim1]
shape2 = sympy_shape[dim2]
if offset >= 0:
diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset))
else:
diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2))
new_shape.append(diag_shape)
if node.output[0]:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
new_shape)))
def _infer_aten_multinomial(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
rank = len(sympy_shape)
assert rank in [1, 2]
num_samples = self._try_get_value(node, 1)
di = rank - 1
last_dim = num_samples if num_samples else str(
self._new_symbolic_dim_from_output(node, 0, di))
output_shape = sympy_shape[:-1] + [last_dim]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], onnx.TensorProto.INT64,
get_shape_from_sympy_shape(output_shape)))
def _infer_aten_pool2d(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
assert len(sympy_shape) == 4
sympy_shape[-2:] = [
self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]
]
self._update_computed_dims(sympy_shape)
for i, o in enumerate(node.output):
if not o:
continue
vi = self.known_vi_[o]
elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[
node.input[0]].type.tensor_type.elem_type
vi.CopyFrom(
helper.make_tensor_value_info(
o, elem_type, get_shape_from_sympy_shape(sympy_shape)))
def _infer_aten_unfold(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
dimension = self._try_get_value(node, 1)
size = self._try_get_value(node, 2)
step = self._try_get_value(node, 3)
if dimension is not None and size is not None and step is not None:
assert dimension < len(sympy_shape)
sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1
sympy_shape.append(size)
else:
rank = len(sympy_shape)
sympy_shape = self._new_symbolic_shape(rank + 1, node)
self._update_computed_dims(sympy_shape)
if node.output[0]:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
sympy_shape)))
def _infer_aten_argmax(self, node):
new_shape = None
if node.input[1] == '':
# The argmax of the flattened input is returned.
new_shape = []
else:
dim = self._try_get_value(node, 1)
keepdim = self._try_get_value(node, 2)
if keepdim is not None:
sympy_shape = self._get_sympy_shape(node, 0)
if dim is not None:
dim = handle_negative_axis(dim, len(sympy_shape))
if keepdim:
sympy_shape[dim] = 1
else:
del sympy_shape[dim]
else:
rank = len(sympy_shape)
sympy_shape = self._new_symbolic_shape(rank if keepdim else
rank - 1, node)
self._update_computed_dims(sympy_shape)
new_shape = get_shape_from_sympy_shape(sympy_shape)
if node.output[0] and new_shape is not None:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], onnx.TensorProto.INT64, new_shape))
def _infer_aten_bce(self, node):
reduction = self._try_get_value(node, 4)
if reduction is None:
reduction = 1
elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
if reduction == 0:
vi.type.tensor_type.elem_type = elem_type
vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
else:
vi.CopyFrom(
helper.make_tensor_value_info(vi.name, elem_type,
self._get_shape(node, 0)))
def _infer_BatchNormalization(self, node):
self._propagate_shape_and_type(node)
        # this works both for opsets below 14 and for opset 14, since we check i < len(node.output) in the loop
for i in [1, 2, 3, 4]:
if i < len(node.output) and node.output[i] != "":
# all of these parameters have the same shape as the 1st input
self._propagate_shape_and_type(
node, input_index=1, output_index=i)
def _infer_Range(self, node):
vi = self.known_vi_[node.output[0]]
input_data = self._get_int_values(node)
if all([i is not None for i in input_data]):
start = as_scalar(input_data[0])
limit = as_scalar(input_data[1])
delta = as_scalar(input_data[2])
new_sympy_shape = [
sympy.Max(sympy.ceiling((limit - start) / delta), 0)
]
else:
new_sympy_shape = [self._new_symbolic_dim_from_output(node)]
self._update_computed_dims(new_sympy_shape)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[0]].type.tensor_type.
elem_type, get_shape_from_sympy_shape(new_sympy_shape)))
def _infer_ReduceSum(self, node):
keep_dims = get_attribute(node, 'keepdims', 1)
if get_opset(self.out_mp_) >= 13 and len(node.input) > 1:
# ReduceSum changes axes to input[1] in opset 13
axes = self._try_get_value(node, 1)
vi = self.known_vi_[node.output[0]]
if axes is None:
assert keep_dims # can only handle keep_dims==True when axes is unknown, by generating new ranks
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[
0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
self._new_symbolic_shape(
self._get_shape_rank(node, 0), node))))
else:
shape = self._get_shape(node, 0)
output_shape = []
axes = [handle_negative_axis(a, len(shape)) for a in axes]
for i, d in enumerate(shape):
if i in axes:
if keep_dims:
output_shape.append(1)
else:
output_shape.append(d)
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], self.known_vi_[node.input[
0]].type.tensor_type.elem_type, output_shape))
def _infer_ReduceProd(self, node):
axes = get_attribute(node, 'axes')
keep_dims = get_attribute(node, 'keepdims', 1)
if keep_dims == 0 and axes == [0]:
data = self._get_int_values(node)[0]
if data is not None:
self.sympy_data_[node.output[0]] = sympy_reduce_product(data)
def _infer_Reshape(self, node):
shape_value = self._try_get_value(node, 1)
vi = self.known_vi_[node.output[0]]
if shape_value is None:
shape_shape = self._get_shape(node, 1)
assert len(shape_shape) == 1
shape_rank = shape_shape[0]
assert is_literal(shape_rank)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(
self._new_symbolic_shape(shape_rank, node))))
else:
input_sympy_shape = self._get_sympy_shape(node, 0)
total = int(1)
for d in input_sympy_shape:
total = total * d
new_sympy_shape = []
deferred_dim_idx = -1
non_deferred_size = int(1)
for i, d in enumerate(shape_value):
if type(d) == sympy.Symbol:
new_sympy_shape.append(d)
elif d == 0:
new_sympy_shape.append(input_sympy_shape[i])
non_deferred_size = non_deferred_size * input_sympy_shape[i]
else:
new_sympy_shape.append(d)
if d == -1:
deferred_dim_idx = i
elif d != 0:
non_deferred_size = non_deferred_size * d
assert new_sympy_shape.count(-1) < 2
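            # resolve the single deferred (-1) dim so that the total element count is preserved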
if -1 in new_sympy_shape:
new_dim = total // non_deferred_size
new_sympy_shape[deferred_dim_idx] = new_dim
self._update_computed_dims(new_sympy_shape)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
self._pass_on_sympy_data(node)
def _infer_Resize(self, node):
vi = self.known_vi_[node.output[0]]
input_sympy_shape = self._get_sympy_shape(node, 0)
if get_opset(self.out_mp_) <= 10:
scales = self._try_get_value(node, 1)
if scales is not None:
new_sympy_shape = [
sympy.simplify(sympy.floor(d * s))
for d, s in zip(input_sympy_shape, scales)
]
self._update_computed_dims(new_sympy_shape)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[
0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
else:
roi = self._try_get_value(node, 1)
scales = self._try_get_value(node, 2)
sizes = self._try_get_value(node, 3)
if sizes is not None:
new_sympy_shape = [
sympy.simplify(sympy.floor(s)) for s in sizes
]
self._update_computed_dims(new_sympy_shape)
elif scales is not None:
rank = len(scales)
if get_attribute(node, 'coordinate_transformation_mode'
) == 'tf_crop_and_resize':
assert len(roi) == 2 * rank
roi_start = list(roi)[:rank]
roi_end = list(roi)[rank:]
else:
roi_start = [0] * rank
roi_end = [1] * rank
scales = list(scales)
new_sympy_shape = [
sympy.simplify(sympy.floor(d * (end - start) * scale))
for d, start, end, scale in zip(input_sympy_shape,
roi_start, roi_end, scales)
]
self._update_computed_dims(new_sympy_shape)
else:
new_sympy_shape = self._new_symbolic_shape(
self._get_shape_rank(node, 0), node)
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
new_sympy_shape)))
def _infer_Scan(self, node):
subgraph = get_attribute(node, 'body')
num_scan_inputs = get_attribute(node, 'num_scan_inputs')
scan_input_axes = get_attribute(node, 'scan_input_axes',
[0] * num_scan_inputs)
num_scan_states = len(node.input) - num_scan_inputs
scan_input_axes = [
handle_negative_axis(
ax, self._get_shape_rank(node, i + num_scan_states))
for i, ax in enumerate(scan_input_axes)
]
        # We may have cases where the subgraph has optional inputs that appear in both subgraph's input and initializer,
# but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs.
assert len(subgraph.input) >= len(node.input)
subgraph_inputs = subgraph.input[:len(node.input)]
for i, si in enumerate(subgraph_inputs):
subgraph_name = si.name
si.CopyFrom(self.known_vi_[node.input[i]])
if i >= num_scan_states:
scan_input_dim = si.type.tensor_type.shape.dim
scan_input_dim.remove(
scan_input_dim[scan_input_axes[i - num_scan_states]])
si.name = subgraph_name
self._onnx_infer_subgraph(node, subgraph)
num_scan_outputs = len(node.output) - num_scan_states
scan_output_axes = get_attribute(node, 'scan_output_axes',
[0] * num_scan_outputs)
scan_input_dim = get_shape_from_type_proto(
self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
for i, o in enumerate(node.output):
vi = self.known_vi_[o]
if i >= num_scan_states:
shape = get_shape_from_type_proto(subgraph.output[i].type)
new_dim = handle_negative_axis(
scan_output_axes[i - num_scan_states], len(shape) + 1)
shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:]
vi.CopyFrom(
helper.make_tensor_value_info(o, subgraph.output[
i].type.tensor_type.elem_type, shape))
else:
vi.CopyFrom(subgraph.output[i])
vi.name = o
def _infer_ScatterElements(self, node):
data_shape = self._get_shape(node, 0)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, data_shape))
def _infer_SequenceAt(self, node):
# need to create new symbolic dimension if sequence shape has None:
seq_shape = self._get_shape(node, 0)
vi = self.known_vi_[node.output[0]]
if seq_shape is not None:
for di, d in enumerate(seq_shape):
if d is not None:
continue
new_dim = onnx.TensorShapeProto.Dimension()
new_dim.dim_param = str(
self._new_symbolic_dim_from_output(node, 0, di))
vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
def _infer_SequenceInsert(self, node):
# workaround bug in onnx's shape inference
vi_seq = self.known_vi_[node.input[0]]
vi_tensor = self.known_vi_[node.input[1]]
vi_out_seq = self.known_vi_[node.output[0]]
vi_out_seq.CopyFrom(vi_seq)
vi_out_seq.name = node.output[0]
self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type)
def _infer_Shape(self, node):
self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0)
def _infer_Size(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape)
self.known_vi_[node.output[0]].CopyFrom(
helper.make_tensor_value_info(node.output[0],
onnx.TensorProto.INT64, []))
def _infer_Slice(self, node):
def less_equal(x, y):
try:
return bool(x <= y)
except TypeError:
pass
try:
return bool(y >= x)
except TypeError:
pass
try:
return bool(-x >= -y)
except TypeError:
pass
try:
return bool(-y <= -x)
except TypeError:
# the last attempt; this may raise TypeError
return bool(y - x >= 0)
def handle_negative_index(index, bound):
""" normalizes a negative index to be in [0, bound) """
try:
if not less_equal(0, index):
if is_literal(index) and index <= -self.int_max_:
# this case is handled separately
return index
return bound + index
except TypeError:
logger.warning("Cannot determine if {} < 0".format(index))
return index
if get_opset(self.out_mp_) <= 9:
axes = get_attribute(node, 'axes')
starts = get_attribute(node, 'starts')
ends = get_attribute(node, 'ends')
if not axes:
axes = list(range(len(starts)))
steps = [1] * len(axes)
else:
starts = as_list(self._try_get_value(node, 1), keep_none=True)
ends = as_list(self._try_get_value(node, 2), keep_none=True)
axes = self._try_get_value(node, 3)
steps = self._try_get_value(node, 4)
if axes is None and not (starts is None and ends is None):
axes = list(
range(0, len(starts if starts is not None else ends)))
if steps is None and not (starts is None and ends is None):
steps = [1] * len(starts if starts is not None else ends)
axes = as_list(axes, keep_none=True)
steps = as_list(steps, keep_none=True)
new_sympy_shape = self._get_sympy_shape(node, 0)
if starts is None or ends is None:
if axes is None:
for i in range(len(new_sympy_shape)):
new_sympy_shape[i] = self._new_symbolic_dim_from_output(
node, 0, i)
else:
new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
for i in axes:
new_sympy_shape[i] = self._new_symbolic_dim_from_output(
node, 0, i)
else:
for i, s, e, t in zip(axes, starts, ends, steps):
e = handle_negative_index(e, new_sympy_shape[i])
if is_literal(e):
if e >= self.int_max_:
e = new_sympy_shape[i]
elif e <= -self.int_max_:
e = 0 if s > 0 else -1
elif is_literal(new_sympy_shape[i]):
if e < 0:
e = max(0, e + new_sympy_shape[i])
e = min(e, new_sympy_shape[i])
else:
if e > 0:
e = sympy.Min(
e, new_sympy_shape[i]
) if e > 1 else e #special case for slicing first to make computation easier
else:
if is_literal(new_sympy_shape[i]):
e = sympy.Min(e, new_sympy_shape[i])
else:
try:
if not less_equal(e, new_sympy_shape[i]):
e = new_sympy_shape[i]
except Exception:
logger.warning(
'Unable to determine if {} <= {}, treat as equal'.
format(e, new_sympy_shape[i]))
e = new_sympy_shape[i]
s = handle_negative_index(s, new_sympy_shape[i])
if is_literal(new_sympy_shape[i]) and is_literal(s):
s = max(0, min(s, new_sympy_shape[i]))
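                # slice length: ceil((e - s) / t), written with floor division so it works for either sign of t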
new_sympy_shape[i] = sympy.simplify(
(e - s + t + (-1 if t > 0 else 1)) // t)
self._update_computed_dims(new_sympy_shape)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
# handle sympy_data if needed, for slice in shape computation
if (node.input[0] in self.sympy_data_ and [0] == axes and
len(starts) == 1 and len(ends) == 1 and len(steps) == 1):
input_sympy_data = self.sympy_data_[node.input[0]]
if type(input_sympy_data) == list or (
                    type(input_sympy_data) == np.ndarray and
len(input_sympy_data.shape) == 1):
self.sympy_data_[node.output[0]] = input_sympy_data[starts[
0]:ends[0]:steps[0]]
def _infer_SoftmaxCrossEntropyLoss(self, node):
vi = self.known_vi_[node.output[0]]
elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi.type.tensor_type.elem_type = elem_type
vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
if len(node.output) > 1:
data_shape = self._get_shape(node, 0)
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(
helper.make_tensor_value_info(vi.name, elem_type, data_shape))
def _infer_Split_Common(self, node, make_value_info_func):
input_sympy_shape = self._get_sympy_shape(node, 0)
axis = handle_negative_axis(
get_attribute(node, 'axis', 0), len(input_sympy_shape))
split = get_attribute(node, 'split')
if not split:
num_outputs = len(node.output)
split = [input_sympy_shape[axis] /
sympy.Integer(num_outputs)] * num_outputs
self._update_computed_dims(split)
else:
split = [sympy.Integer(s) for s in split]
for i_o in range(len(split)):
vi = self.known_vi_[node.output[i_o]]
vi.CopyFrom(
make_value_info_func(node.output[i_o], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
input_sympy_shape[:axis] + [
split[i_o]
] + input_sympy_shape[axis + 1:])))
self.known_vi_[vi.name] = vi
def _infer_Split(self, node):
self._infer_Split_Common(node, helper.make_tensor_value_info)
def _infer_SplitToSequence(self, node):
self._infer_Split_Common(node, helper.make_sequence_value_info)
def _infer_Squeeze(self, node):
input_shape = self._get_shape(node, 0)
op_set = get_opset(self.out_mp_)
# Depending on op-version 'axes' are provided as attribute or via 2nd input
if op_set < 13:
axes = get_attribute(node, 'axes')
assert self._try_get_value(node, 1) is None
else:
axes = self._try_get_value(node, 1)
assert get_attribute(node, 'axes') is None
if axes is None:
# No axes have been provided (neither via attribute nor via input).
            # In this case the 'Squeeze' op should remove all axes with dimension 1.
# For symbolic dimensions we guess they are !=1.
output_shape = [s for s in input_shape if s != 1]
if self.verbose_ > 0:
symbolic_dimensions = [s for s in input_shape if type(s) != int]
if len(symbolic_dimensions) > 0:
logger.debug(
f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+
f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}"
)
else:
axes = [handle_negative_axis(a, len(input_shape)) for a in axes]
output_shape = []
for i in range(len(input_shape)):
if i not in axes:
output_shape.append(input_shape[i])
else:
assert input_shape[i] == 1 or type(input_shape[i]) != int
if self.verbose_ > 0 and type(input_shape[i]) != int:
logger.debug(
f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+
f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1."
)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
self._pass_on_sympy_data(node)
def _infer_Tile(self, node):
repeats_value = self._try_get_value(node, 1)
new_sympy_shape = []
if repeats_value is not None:
input_sympy_shape = self._get_sympy_shape(node, 0)
for i, d in enumerate(input_sympy_shape):
new_dim = d * repeats_value[i]
new_sympy_shape.append(new_dim)
self._update_computed_dims(new_sympy_shape)
else:
new_sympy_shape = self._new_symbolic_shape(
self._get_shape_rank(node, 0), node)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
def _infer_TopK(self, node):
rank = self._get_shape_rank(node, 0)
axis = handle_negative_axis(get_attribute(node, 'axis', -1), rank)
new_shape = self._get_shape(node, 0)
if get_opset(self.out_mp_) <= 9:
k = get_attribute(node, 'k')
else:
k = self._get_int_values(node)[1]
        if k is None:
k = self._new_symbolic_dim_from_output(node)
else:
k = as_scalar(k)
if type(k) in [int, str]:
new_shape[axis] = k
else:
new_sympy_shape = self._get_sympy_shape(node, 0)
new_sympy_shape[axis] = k
self._update_computed_dims(
new_sympy_shape
) # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape
new_shape = get_shape_from_sympy_shape(new_sympy_shape)
for i_o in range(len(node.output)):
vi = self.known_vi_[node.output[i_o]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
i_o], vi.type.tensor_type.elem_type, new_shape))
def _infer_Transpose(self, node):
if node.input[0] in self.sympy_data_:
data_shape = self._get_shape(node, 0)
perm = get_attribute(node, 'perm',
reversed(list(range(len(data_shape)))))
input_data = self.sympy_data_[node.input[0]]
self.sympy_data_[node.output[0]] = np.transpose(
np.array(input_data).reshape(*data_shape),
axes=tuple(perm)).flatten().tolist()
def _infer_Unsqueeze(self, node):
input_shape = self._get_shape(node, 0)
op_set = get_opset(self.out_mp_)
# Depending on op-version 'axes' are provided as attribute or via 2nd input
if op_set < 13:
axes = get_attribute(node, 'axes')
assert self._try_get_value(node, 1) is None
else:
axes = self._try_get_value(node, 1)
assert get_attribute(node, 'axes') is None
output_rank = len(input_shape) + len(axes)
axes = [handle_negative_axis(a, output_rank) for a in axes]
input_axis = 0
output_shape = []
for i in range(output_rank):
if i in axes:
output_shape.append(1)
else:
output_shape.append(input_shape[input_axis])
input_axis += 1
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
self._pass_on_sympy_data(node)
def _infer_ZipMap(self, node):
map_key_type = None
if get_attribute(node, 'classlabels_int64s') is not None:
map_key_type = onnx.TensorProto.INT64
elif get_attribute(node, 'classlabels_strings') is not None:
map_key_type = onnx.TensorProto.STRING
assert map_key_type is not None
new_vi = onnx.ValueInfoProto()
new_vi.name = node.output[0]
new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(new_vi)
def _infer_Attention(self, node):
shape = self._get_shape(node, 0)
shape_bias = self._get_shape(node, 2)
assert len(shape) == 3 and len(shape_bias) == 1
qkv_hidden_sizes_attr = get_attribute(node, 'qkv_hidden_sizes')
if qkv_hidden_sizes_attr is not None:
assert len(qkv_hidden_sizes_attr) == 3
shape[2] = int(qkv_hidden_sizes_attr[2])
else:
shape[2] = int(shape_bias[0] / 3)
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], output_dtype, shape))
if len(node.output) > 1:
# input shape: (batch_size, sequence_length, hidden_size)
# past shape: (2, batch_size, num_heads, past_sequence_length, head_size)
# mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len)
# present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length
input_shape = self._get_shape(node, 0)
past_shape = self._get_shape(node, 4)
mask_shape = self._get_shape(node, 3)
if len(past_shape) == 5:
if len(mask_shape) in [2, 3]:
past_shape[3] = mask_shape[-1]
elif isinstance(input_shape[1], int) and isinstance(
past_shape[3], int):
past_shape[3] = input_shape[1] + past_shape[3]
else:
past_shape[3] = f"{past_shape[3]}+{input_shape[1]}"
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(
helper.make_tensor_value_info(vi.name, output_dtype,
past_shape))
def _infer_BiasGelu(self, node):
self._propagate_shape_and_type(node)
def _infer_FastGelu(self, node):
self._propagate_shape_and_type(node)
def _infer_Gelu(self, node):
self._propagate_shape_and_type(node)
def _infer_LayerNormalization(self, node):
self._propagate_shape_and_type(node)
def _infer_LongformerAttention(self, node):
self._propagate_shape_and_type(node)
def _infer_EmbedLayerNormalization(self, node):
input_ids_shape = self._get_shape(node, 0)
word_embedding_shape = self._get_shape(node, 2)
assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2
output_shape = input_ids_shape + [word_embedding_shape[1]]
word_embedding_dtype = self.known_vi_[node.input[
2]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], word_embedding_dtype,
output_shape))
mask_index_shape = [input_ids_shape[0]]
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
1], onnx.TensorProto.INT32, mask_index_shape))
if len(node.output) > 2:
            # Optional output of add before layer normalization is done
# shape is same as the output
vi = self.known_vi_[node.output[2]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
2], word_embedding_dtype, output_shape))
def _infer_SkipLayerNormalization(self, node):
self._propagate_shape_and_type(node)
def _infer_PythonOp(self, node):
output_tensor_types = get_attribute(node, 'output_tensor_types')
assert output_tensor_types
output_tensor_ranks = get_attribute(node, 'output_tensor_ranks')
assert output_tensor_ranks
        # set the context output separately.
# The first output is autograd's context.
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0],
onnx.TensorProto.INT64, []))
# Outputs after autograd's context are tensors.
# We assume their ranks are fixed for different model inputs.
for i in range(len(node.output) - 1):
# Process the i-th tensor outputs.
vi = self.known_vi_[node.output[i + 1]]
sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node)
shape = get_shape_from_sympy_shape(sympy_shape)
value_info = helper.make_tensor_value_info(
node.output[i + 1], output_tensor_types[i], shape)
vi.CopyFrom(value_info)
def _propagate_shape_and_type(self, node, input_index=0, output_index=0):
shape = self._get_shape(node, input_index)
output_dtype = self.known_vi_[node.input[
input_index]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[output_index]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[output_index],
output_dtype, shape))
def _is_none_dim(self, dim_value):
if type(dim_value) != str:
return False
if "unk__" not in dim_value:
return False
if dim_value in self.symbolic_dims_.keys():
return False
return True
def _is_shape_contains_none_dim(self, out_shape):
for out in out_shape:
if self._is_none_dim(out):
return out
return None
def _infer_impl(self, start_sympy_data=None):
self.sympy_data_ = start_sympy_data or {}
self.out_mp_.graph.ClearField('value_info')
self._apply_suggested_merge(graph_input_only=True)
self.input_symbols_ = set()
for i in self.out_mp_.graph.input:
input_shape = get_shape_from_value_info(i)
if input_shape is None:
continue
if is_sequence(i.type):
input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
else:
input_dims = i.type.tensor_type.shape.dim
for i_dim, dim in enumerate(input_shape):
if dim is None:
# some models use None for symbolic dim in input, replace it with a string
input_dims[i_dim].dim_param = str(
self._new_symbolic_dim(i.name, i_dim))
self.input_symbols_.update(
[d for d in input_shape if type(d) == str])
for s in self.input_symbols_:
if s in self.suggested_merge_:
s_merge = self.suggested_merge_[s]
assert s_merge in self.symbolic_dims_
self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
else:
# Since inputs are not produced by other ops, we can assume positivity
self.symbolic_dims_[s] = sympy.Symbol(
s, integer=True, positive=True)
# create a temporary ModelProto for single node inference
# note that we remove initializer to have faster inference
# for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways
self.tmp_mp_ = onnx.ModelProto()
self.tmp_mp_.CopyFrom(self.out_mp_)
self.tmp_mp_.graph.ClearField('initializer')
        # compute prerequisites for each node for topological sort
# node with subgraphs may have dependency on implicit inputs, which will affect topological sort
prereq_for_node = {
} # map from node to all its inputs, including implicit ones in subgraph
def get_prereq(node):
names = set(i for i in node.input if i)
subgraphs = []
if 'If' == node.op_type:
subgraphs = [
get_attribute(node, 'then_branch'),
get_attribute(node, 'else_branch')
]
elif node.op_type in ['Loop', 'Scan']:
subgraphs = [get_attribute(node, 'body')]
for g in subgraphs:
g_outputs_and_initializers = {i.name for i in g.initializer}
g_prereq = set()
for n in g.node:
g_outputs_and_initializers.update(n.output)
for n in g.node:
g_prereq.update([
i for i in get_prereq(n)
if i not in g_outputs_and_initializers
])
names.update(g_prereq)
# remove subgraph inputs from g_prereq since those are local-only
for i in g.input:
if i.name in names:
names.remove(i.name)
return names
for n in self.tmp_mp_.graph.node:
prereq_for_node[n.output[0]] = get_prereq(n)
# topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate
sorted_nodes = []
sorted_known_vi = set([
i.name
for i in list(self.out_mp_.graph.input) + list(
self.out_mp_.graph.initializer)
])
if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
# Loop/Scan will have some graph output in graph inputs, so don't do topological sort
sorted_nodes = self.out_mp_.graph.node
else:
while not all(
[o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
old_sorted_nodes_len = len(sorted_nodes)
for node in self.out_mp_.graph.node:
if (node.output[0] not in sorted_known_vi) and all([
i in sorted_known_vi
for i in prereq_for_node[node.output[0]] if i
]):
sorted_known_vi.update(node.output)
sorted_nodes.append(node)
if old_sorted_nodes_len == len(sorted_nodes) and not all([
o.name in sorted_known_vi
for o in self.out_mp_.graph.output
]):
raise Exception('Invalid model with cyclic graph')
for node in sorted_nodes:
assert all([i in self.known_vi_ for i in node.input if i])
self._onnx_infer_single_node(node)
known_aten_op = False
if node.op_type in self.dispatcher_:
self.dispatcher_[node.op_type](node)
elif node.op_type in ['ConvTranspose']:
# onnx shape inference ops like ConvTranspose may have empty shape for symbolic input
# before adding symbolic compute for them
# mark the output type as UNDEFINED to allow guessing of rank
vi = self.known_vi_[node.output[0]]
if len(vi.type.tensor_type.shape.dim) == 0:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
elif node.op_type == 'ATen' and node.domain == 'org.pytorch.aten':
for attr in node.attribute:
# TODO: Is overload_name needed?
if attr.name == 'operator':
aten_op_name = attr.s.decode('utf-8') if isinstance(
attr.s, bytes) else attr.s
if aten_op_name in self.aten_op_dispatcher_:
known_aten_op = True
self.aten_op_dispatcher_[aten_op_name](node)
break
if self.verbose_ > 2:
logger.debug(node.op_type + ': ' + node.name)
for i, name in enumerate(node.input):
logger.debug(' Input {}: {} {}'.format(
i, name, 'initializer'
if name in self.initializers_ else ''))
            # onnx automatically merges dims with values, i.e. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb']
# symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case
if node.op_type in [
'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger',
'MatMulInteger16', 'Where', 'Sum'
]:
vi = self.known_vi_[node.output[0]]
out_rank = len(get_shape_from_type_proto(vi.type))
in_shapes = [
self._get_shape(node, i) for i in range(len(node.input))
]
for d in range(out_rank - (2 if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16'
] else 0)):
in_dims = [
s[len(s) - out_rank + d] for s in in_shapes
if len(s) + d >= out_rank
]
if len(in_dims) > 1:
self._check_merged_dims(in_dims, allow_broadcast=True)
for i_o in range(len(node.output)):
vi = self.known_vi_[node.output[i_o]]
out_type = vi.type
out_type_kind = out_type.WhichOneof('value')
# do not process shape for non-tensors
if out_type_kind not in [
'tensor_type', 'sparse_tensor_type', None
]:
if self.verbose_ > 2:
if out_type_kind == 'sequence_type':
seq_cls_type = out_type.sequence_type.elem_type.WhichOneof(
'value')
if 'tensor_type' == seq_cls_type:
logger.debug(' {}: sequence of {} {}'.format(
node.output[i_o],
str(get_shape_from_value_info(vi)),
onnx.TensorProto.DataType.Name(
vi.type.sequence_type.elem_type.
tensor_type.elem_type)))
else:
logger.debug(' {}: sequence of {}'.format(
node.output[i_o], seq_cls_type))
else:
logger.debug(' {}: {}'.format(node.output[i_o],
out_type_kind))
continue
out_shape = get_shape_from_value_info(vi)
out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
if self.verbose_ > 2:
logger.debug(' {}: {} {}'.format(
node.output[i_o],
str(out_shape),
onnx.TensorProto.DataType.Name(
vi.type.tensor_type.elem_type)))
if node.output[i_o] in self.sympy_data_:
logger.debug(' Sympy Data: ' + str(self.sympy_data_[
node.output[i_o]]))
# onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain
if (out_shape is not None and
(None in out_shape or
self._is_shape_contains_none_dim(out_shape))
) or out_type_undefined:
if self.auto_merge_:
if node.op_type in [
'Add', 'Sub', 'Mul', 'Div', 'MatMul',
'MatMulInteger', 'MatMulInteger16', 'Concat',
'Where', 'Sum', 'Equal', 'Less', 'Greater',
'LessOrEqual', 'GreaterOrEqual'
]:
shapes = [
self._get_shape(node, i)
for i in range(len(node.input))
]
if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16'
]:
if None in out_shape or self._is_shape_contains_none_dim(
out_shape):
if None in out_shape:
idx = out_shape.index(None)
else:
idx = out_shape.index(
self._is_shape_contains_none_dim(
out_shape))
dim_idx = [
len(s) - len(out_shape) + idx
for s in shapes
]
# only support auto merge for MatMul for dim < rank-2 when rank > 2
assert len(
shapes[0]) > 2 and dim_idx[0] < len(
shapes[0]) - 2
assert len(
shapes[1]) > 2 and dim_idx[1] < len(
shapes[1]) - 2
elif node.op_type == 'Expand':
# auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq])
shapes = [
self._get_shape(node, 0), self._get_value(node,
1)
]
else:
shapes = []
if shapes:
for idx in range(len(out_shape)):
if out_shape[
idx] is not None and not self._is_none_dim(
out_shape[idx]):
continue
# note that the broadcasting rule aligns from right to left
# if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge
dim_idx = [
len(s) - len(out_shape) + idx
for s in shapes
]
if len(dim_idx) > 0:
self._add_suggested_merge([
s[i] if is_literal(s[i]) else str(s[i])
for s, i in zip(shapes, dim_idx)
if i >= 0
])
self.run_ = True
else:
self.run_ = False
else:
self.run_ = False
# create new dynamic dims for ops not handled by symbolic shape inference
if self.run_ == False and not node.op_type in self.dispatcher_ and not known_aten_op:
is_unknown_op = out_type_undefined and (
out_shape is None or len(out_shape) == 0)
if is_unknown_op:
# unknown op to ONNX, maybe from higher opset or other domain
# only guess the output rank from input 0 when using guess_output_rank option
out_rank = self._get_shape_rank(
node, 0) if self.guess_output_rank_ else -1
else:
# valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape
out_rank = len(out_shape)
if out_rank >= 0:
new_shape = self._new_symbolic_shape(out_rank, node,
i_o)
if out_type_undefined:
# guess output data type from input vi if not defined
out_dtype = self.known_vi_[node.input[
0]].type.tensor_type.elem_type
else:
# otherwise, use original data type
out_dtype = vi.type.tensor_type.elem_type
vi.CopyFrom(
helper.make_tensor_value_info(
vi.name, out_dtype,
get_shape_from_sympy_shape(new_shape)))
if self.verbose_ > 0:
if is_unknown_op:
logger.debug(
"Possible unknown op: {} node: {}, guessing {} shape".
format(node.op_type, node.name,
vi.name))
if self.verbose_ > 2:
logger.debug(' {}: {} {}'.format(
node.output[i_o],
str(new_shape),
vi.type.tensor_type.elem_type))
self.run_ = True
continue # continue the inference after guess, no need to stop as no merge is needed
if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
logger.debug(
'Stopping at incomplete shape inference at ' +
node.op_type + ': ' + node.name)
logger.debug('node inputs:')
for i in node.input:
logger.debug(self.known_vi_[i])
logger.debug('node outputs:')
for o in node.output:
logger.debug(self.known_vi_[o])
if self.auto_merge_ and not out_type_undefined:
logger.debug('Merging: ' + str(
self.suggested_merge_))
return False
self.run_ = False
return True
def _update_output_from_vi(self):
for output in self.out_mp_.graph.output:
if output.name in self.known_vi_:
output.CopyFrom(self.known_vi_[output.name])
@staticmethod
def infer_shapes(in_mp,
int_max=2**31 - 1,
auto_merge=False,
guess_output_rank=False,
verbose=0):
onnx_opset = get_opset(in_mp)
if (not onnx_opset) or onnx_opset < 7:
logger.warning('Only support models of onnx opset 7 and above.')
return None
symbolic_shape_inference = SymbolicShapeInference(
int_max, auto_merge, guess_output_rank, verbose)
all_shapes_inferred = False
symbolic_shape_inference._preprocess(in_mp)
while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl()
symbolic_shape_inference._update_output_from_vi()
if not all_shapes_inferred:
raise Exception("Incomplete symbolic shape inference")
return symbolic_shape_inference.out_mp_
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True, help='The input model file')
parser.add_argument('--output', help='The output model file')
parser.add_argument(
'--auto_merge',
        help='Automatically merge symbolic dims when a conflict happens',
action='store_true',
default=False)
parser.add_argument(
'--int_max',
help='maximum value for integer to be treated as boundless for ops like slice',
type=int,
default=2**31 - 1)
parser.add_argument(
'--guess_output_rank',
help='guess output rank to be the same as input 0 for unknown ops',
action='store_true',
default=False)
parser.add_argument(
'--verbose',
help='Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed',
type=int,
default=0)
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
logger.info('input model: ' + args.input)
if args.output:
logger.info('output model ' + args.output)
logger.info('Doing symbolic shape inference...')
out_mp = SymbolicShapeInference.infer_shapes(
onnx.load(args.input), args.int_max, args.auto_merge,
args.guess_output_rank, args.verbose)
if args.output and out_mp:
onnx.save(out_mp, args.output)
logger.info('Done!')
#!/bin/bash
set -e
if [ $# != 3 ];then
# ./local/onnx_opt.sh model.old.onnx model.opt.onnx "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
echo "usage: $0 onnx.model.in onnx.model.out input_shape "
exit 1
fi
# onnx optimizer
pip install onnx-simplifier
in=$1
out=$2
input_shape=$3
check_n=3
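# check_n: number of random-input checks onnx-simplifier runs to verify the simplified model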
onnxsim $in $out $check_n --dynamic-input-shape --input-shape $input_shape
\ No newline at end of file
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# prune model by output names
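# example usage (script, model and tensor names below are illustrative):
#   python3 prune_onnx_model.py --model model.onnx --output_names logits --save_file model.pruned.onnx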
import argparse
import copy
import sys
import onnx
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model',
required=True,
        help='Path of the input onnx model file.')
parser.add_argument(
'--output_names',
required=True,
nargs='+',
help='The outputs of pruned model.')
parser.add_argument(
'--save_file', required=True, help='Path to save the new onnx model.')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
if len(set(args.output_names)) < len(args.output_names):
        print(
            "[ERROR] There's a duplicate name in --output_names, which is not allowed."
        )
sys.exit(-1)
model = onnx.load(args.model)
# collect all node outputs and graph output
output_tensor_names = set()
for node in model.graph.node:
for out in node.output:
# may contain model output
output_tensor_names.add(out)
# for out in model.graph.output:
# output_tensor_names.add(out.name)
for output_name in args.output_names:
if output_name not in output_tensor_names:
print(
"[ERROR] Cannot find output tensor name '{}' in onnx model graph.".
format(output_name))
sys.exit(-1)
output_node_indices = set() # has output names
output_to_node = dict() # all node outputs
for i, node in enumerate(model.graph.node):
for out in node.output:
output_to_node[out] = i
if out in args.output_names:
output_node_indices.add(i)
# from outputs find all the ancestors
reserved_node_indices = copy.deepcopy(
output_node_indices) # nodes need to keep
reserved_inputs = set() # model input to keep
new_output_node_indices = copy.deepcopy(output_node_indices)
while True and len(new_output_node_indices) > 0:
output_node_indices = copy.deepcopy(new_output_node_indices)
new_output_node_indices = set()
for out_node_idx in output_node_indices:
            # backtrace to parent
for ipt in model.graph.node[out_node_idx].input:
if ipt in output_to_node:
reserved_node_indices.add(output_to_node[ipt])
new_output_node_indices.add(output_to_node[ipt])
else:
reserved_inputs.add(ipt)
num_inputs = len(model.graph.input)
num_outputs = len(model.graph.output)
num_nodes = len(model.graph.node)
print(
f"old graph has {num_inputs} inputs, {num_outputs} outpus, {num_nodes} nodes"
)
print(f"{len(reserved_node_indices)} node to keep.")
# del node not to keep
for idx in range(num_nodes - 1, -1, -1):
if idx not in reserved_node_indices:
del model.graph.node[idx]
# del graph input not to keep
for idx in range(num_inputs - 1, -1, -1):
if model.graph.input[idx].name not in reserved_inputs:
del model.graph.input[idx]
# del old graph outputs
for i in range(num_outputs):
del model.graph.output[0]
    # add new graph outputs as specified by the user
for out in args.output_names:
model.graph.output.extend([onnx.ValueInfoProto(name=out)])
# infer shape
try:
from onnx_infer_shape import SymbolicShapeInference
model = SymbolicShapeInference.infer_shapes(
model,
int_max=2**31 - 1,
auto_merge=True,
guess_output_rank=False,
verbose=1)
except Exception as e:
print(f"skip infer shape step: {e}")
# check onnx model
onnx.checker.check_model(model)
# save onnx model
onnx.save(model, args.save_file)
print("[Finished] The new model saved in {}.".format(args.save_file))
print("[DEBUG INFO] The inputs of new model: {}".format(
[x.name for x in model.graph.input]))
print("[DEBUG INFO] The outputs of new model: {}".format(
[x.name for x in model.graph.output]))
#!/usr/bin/env python3 -W ignore::DeprecationWarning
# rename node to new names
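# Example usage (illustrative; the script file name and the new names are
# assumptions, the flags are those defined in parse_arguments() below):
#   python3 rename_node.py --model model.onnx \
#       --origin_names audio_chunk audio_chunk_lens \
#       --new_names audio audio_lens \
#       --save_file model.rename.onnx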
import argparse
import sys
import onnx
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model',
required=True,
help='Path of directory saved the input model.')
parser.add_argument(
'--origin_names',
required=True,
nargs='+',
help='The original name you want to modify.')
parser.add_argument(
'--new_names',
required=True,
nargs='+',
        help='The new names you want to change to; the number of new_names should be the same as the number of origin_names'
)
parser.add_argument(
'--save_file', required=True, help='Path to save the new onnx model.')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
if len(set(args.origin_names)) < len(args.origin_names):
print(
"[ERROR] There's dumplicate name in --origin_names, which is not allowed."
)
sys.exit(-1)
if len(set(args.new_names)) < len(args.new_names):
print(
"[ERROR] There's dumplicate name in --new_names, which is not allowed."
)
sys.exit(-1)
if len(args.new_names) != len(args.origin_names):
print(
"[ERROR] Number of --new_names must be same with the number of --origin_names."
)
sys.exit(-1)
model = onnx.load(args.model)
# collect input and all node output
output_tensor_names = set()
for ipt in model.graph.input:
output_tensor_names.add(ipt.name)
for node in model.graph.node:
for out in node.output:
output_tensor_names.add(out)
for origin_name in args.origin_names:
if origin_name not in output_tensor_names:
print(
f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph."
)
sys.exit(-1)
for new_name in args.new_names:
if new_name in output_tensor_names:
print(
"[ERROR] The defined new_name '{}' is already exist in the onnx model, which is not allowed."
)
sys.exit(-1)
# rename graph input
for i, ipt in enumerate(model.graph.input):
if ipt.name in args.origin_names:
idx = args.origin_names.index(ipt.name)
model.graph.input[i].name = args.new_names[idx]
# rename node input and output
for i, node in enumerate(model.graph.node):
for j, ipt in enumerate(node.input):
if ipt in args.origin_names:
idx = args.origin_names.index(ipt)
model.graph.node[i].input[j] = args.new_names[idx]
for j, out in enumerate(node.output):
if out in args.origin_names:
idx = args.origin_names.index(out)
model.graph.node[i].output[j] = args.new_names[idx]
# rename graph output
for i, out in enumerate(model.graph.output):
if out.name in args.origin_names:
idx = args.origin_names.index(out.name)
model.graph.output[i].name = args.new_names[idx]
# check onnx model
onnx.checker.check_model(model)
# save model
onnx.save(model, args.save_file)
print("[Finished] The new model saved in {}.".format(args.save_file))
print("[DEBUG INFO] The inputs of new model: {}".format(
[x.name for x in model.graph.input]))
print("[DEBUG INFO] The outputs of new model: {}".format(
[x.name for x in model.graph.output]))
#!/usr/bin/env python3
import argparse
from onnxruntime.quantization import quantize_dynamic
from onnxruntime.quantization import QuantType
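# Example usage (illustrative; the flags are defined in main() below and the
# excluded Conv node names come from the ds2 example in this PR; excluded nodes
# are kept in float precision):
#   python3 ort_dyanmic_quant.py --model-in model.onnx \
#       --model-out model.quant.onnx --nodes-to-exclude 'p2o.Conv.0,p2o.Conv.2'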
def quantize_onnx_model(onnx_model_path,
quantized_model_path,
nodes_to_exclude=[]):
print("Starting quantization...")
quantize_dynamic(
onnx_model_path,
quantized_model_path,
weight_type=QuantType.QInt8,
nodes_to_exclude=nodes_to_exclude)
print(f"Quantized model saved to: {quantized_model_path}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-in",
type=str,
required=True,
help="ONNX model", )
parser.add_argument(
"--model-out",
type=str,
required=True,
default='model.quant.onnx',
help="ONNX model", )
parser.add_argument(
"--nodes-to-exclude",
type=str,
required=True,
help="nodes to exclude. e.g. conv,linear.", )
args = parser.parse_args()
nodes_to_exclude = args.nodes_to_exclude.split(',')
quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude)
if __name__ == "__main__":
main()
#!/usr/bin/env python3
import argparse
import onnxruntime as ort
# onnxruntime optimizer.
# https://onnxruntime.ai/docs/performance/graph-optimizations.html
# https://onnxruntime.ai/docs/api/python/api_summary.html#api
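# Example usage (illustrative; the flags are those defined in parse_arguments() below):
#   python3 ort_opt.py --model_in model.onnx --opt_level 0 --model_out model.ort.opt.onnx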
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_in', required=True, type=str, help='Path to onnx model.')
parser.add_argument(
'--opt_level',
required=True,
type=int,
default=0,
choices=[0, 1, 2],
        help='Graph optimization level: 0=basic, 1=extended, 2=all.')
parser.add_argument(
'--model_out', required=True, help='path to save the optimized model.')
parser.add_argument('--debug', default=False, help='output debug info.')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
sess_options = ort.SessionOptions()
# Set graph optimization level
print(f"opt level: {args.opt_level}")
if args.opt_level == 0:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
elif args.opt_level == 1:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
else:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# To enable model serialization after graph optimization set this
sess_options.optimized_model_filepath = args.model_out
session = ort.InferenceSession(args.model_in, sess_options)
#!/bin/bash
if [ $# != 4 ];then
# local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx
echo "usage: $0 model_dir model_name param_name onnx_output_name"
exit 1
fi
dir=$1
model=$2
param=$3
output=$4
pip install paddle2onnx
pip install onnx
# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2
# opset10 support quantize
paddle2onnx --model_dir $dir \
--model_filename $model \
--params_filename $param \
--save_file $output \
--enable_dev_version True \
--opset_version 11 \
--enable_onnx_checker True
\ No newline at end of file
# This contains the locations of the binaries required for running the examples.
MAIN_ROOT=`realpath $PWD/../../../../`
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project was built successfully"; }
export LC_ALL=C
export PATH=$PATH:$TOOLS_BIN
#!/bin/bash
set -e
. path.sh
stage=0
stop_stage=50
tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz
#tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz
model_prefix=avg_10.jit
#model_prefix=avg_1.jit
model=${model_prefix}.pdmodel
param=${model_prefix}.pdiparams
. utils/parse_options.sh
data=data
exp=exp
mkdir -p $data $exp
dir=$data/exp/deepspeech2_online/checkpoints
# wenetspeech or aishell
model_type=$(echo $tarfile | cut -d '_' -f 4)
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
test -f $data/$tarfile || wget -P $data -c https://paddlespeech.bj.bcebos.com/s2t/$model_type/asr0/$tarfile
# wenetspeech ds2 model
pushd $data
tar zxvf $tarfile
popd
# ds2 model demo inputs
pushd $exp
wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle
popd
fi
input_file=$exp/static_ds2online_inputs.pickle
test -e $input_file
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
# to onnx
./local/tonnx.sh $dir $model $param $exp/model.onnx
./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.onnx
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ] ;then
    # ort graph optimize
./local/ort_opt.py --model_in $exp/model.onnx --opt_level 0 --model_out $exp/model.ort.opt.onnx
./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.ort.opt.onnx
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
# convert opset_num to 11
./local/onnx_convert_opset.py --target-opset 11 --model-file $exp/model.ort.opt.onnx --save-model $exp/model.optset11.onnx
# quant model
nodes_to_exclude='p2o.Conv.0,p2o.Conv.2'
./local/ort_dyanmic_quant.py --model-in $exp/model.optset11.onnx --model-out $exp/model.optset11.quant.onnx --nodes-to-exclude "${nodes_to_exclude}"
./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.optset11.quant.onnx
fi
# aishell rnn hidden is 1024
# wenetspeech rnn hidden is 2048
if [ $model_type == 'aishell' ];then
input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024"
elif [ $model_type == 'wenetspeech' ];then
input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,2048 chunk_state_h_box:5,1,2048"
else
echo "not support: $model_type"
exit -1
fi
if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ] ;then
    # the wenetspeech ds2 model exceeds the 2GB limit and will error.
# simplifying onnx model
./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape"
./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.opt.onnx
fi
../../../../utils/
\ No newline at end of file
# Streaming DeepSpeech2 Server with WebSocket
This example shows how to use `websocket` as a streaming deepspeech2 server. For deepspeech2 model training please see [here](../../../../examples/aishell/asr0/).
The websocket protocol is the same as [PaddleSpeech Server](../../../../demos/streaming_asr_server/);
for details of the implementation please see [here](../../../speechx/protocol/websocket/).
## Source path.sh
```bash
. path.sh
```
The SpeechX binaries are under `$SPEECHX_BUILD` (run `echo $SPEECHX_BUILD` to see the path); for more info please see `path.sh`.
## Start WebSocket Server
```bash
bash websocket_server.sh
```
The output looks like the following:
```text
I1130 02:19:32.029882 12856 cmvn_json2kaldi_main.cc:39] cmvn josn path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/data/mean_std.json
I1130 02:19:32.032230 12856 cmvn_json2kaldi_main.cc:73] nframe: 907497
I1130 02:19:32.032564 12856 cmvn_json2kaldi_main.cc:85] cmvn stats have write into: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/cmvn.ark
I1130 02:19:32.032579 12856 cmvn_json2kaldi_main.cc:86] Binary: 1
I1130 02:19:32.798342 12937 feature_pipeline.h:53] cmvn file: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/cmvn.ark
I1130 02:19:32.798542 12937 feature_pipeline.h:58] dither: 0
I1130 02:19:32.798583 12937 feature_pipeline.h:60] frame shift ms: 10
I1130 02:19:32.798588 12937 feature_pipeline.h:62] feature type: linear
I1130 02:19:32.798596 12937 feature_pipeline.h:80] frame length ms: 20
I1130 02:19:32.798601 12937 feature_pipeline.h:88] subsampling rate: 4
I1130 02:19:32.798606 12937 feature_pipeline.h:90] nnet receptive filed length: 7
I1130 02:19:32.798611 12937 feature_pipeline.h:92] nnet chunk size: 1
I1130 02:19:32.798615 12937 feature_pipeline.h:94] frontend fill zeros: 0
I1130 02:19:32.798630 12937 nnet_itf.h:52] subsampling rate: 4
I1130 02:19:32.798635 12937 nnet_itf.h:54] model path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/exp/deepspeech2_online/checkpoints//avg_1.jit.pdmodel
I1130 02:19:32.798640 12937 nnet_itf.h:57] param path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/exp/deepspeech2_online/checkpoints//avg_1.jit.pdiparams
I1130 02:19:32.798643 12937 nnet_itf.h:59] DS2 param:
I1130 02:19:32.798647 12937 nnet_itf.h:61] cache names: chunk_state_h_box,chunk_state_c_box
I1130 02:19:32.798652 12937 nnet_itf.h:63] cache shape: 5-1-1024,5-1-1024
I1130 02:19:32.798656 12937 nnet_itf.h:65] input names: audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box
I1130 02:19:32.798660 12937 nnet_itf.h:67] output names: softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
I1130 02:19:32.798664 12937 ctc_tlg_decoder.h:41] fst path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/wfst//TLG.fst
I1130 02:19:32.798669 12937 ctc_tlg_decoder.h:42] fst symbole table: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/wfst//words.txt
I1130 02:19:32.798673 12937 ctc_tlg_decoder.h:47] LatticeFasterDecoder max active: 7500
I1130 02:19:32.798677 12937 ctc_tlg_decoder.h:49] LatticeFasterDecoder beam: 15
I1130 02:19:32.798681 12937 ctc_tlg_decoder.h:50] LatticeFasterDecoder lattice_beam: 7.5
I1130 02:19:32.798708 12937 websocket_server_main.cc:37] Listening at port 8082
```
## Start WebSocket Client
```bash
bash websocket_client.sh
```
This script uses AISHELL-1 test data to call the websocket server.
The input is specified by `--wav_rspecifier=scp:$data/$aishell_wav_scp`.
The `scp` file looks like this:
```text
# head data/split1/1/aishell_test.scp
BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
...
BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
```
If you want to recognize a single wav, you can make an `scp` file like this:
```text
key path/to/wav/file
```
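A minimal way to generate such a one-line `scp` file (the utterance key and output file name below are placeholders):
```bash
# write a single-utterance scp line: "<utt_key> <wav_path>"
echo "my_utt /path/to/wav/file.wav" > data/one_utt.scp
```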
# This contains the locations of the binaries required for running the examples.
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project was built successfully"; }
export LC_ALL=C
SPEECHX_BIN=$SPEECHX_BUILD/protocol/websocket:$SPEECHX_BUILD/frontend/audio
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
# output
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
export GLOG_logtostderr=1
# websocket client
websocket_client_main \
--wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.5
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
if [ ! -f $ckpt_dir/data/mean_std.json ]; then
mkdir -p $ckpt_dir
pushd $ckpt_dir
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
export GLOG_logtostderr=1
# 3. gen cmvn
cmvn=$data/cmvn.ark
cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
wfst=$data/wfst/
mkdir -p $wfst
if [ ! -f $wfst/aishell_graph.zip ]; then
pushd $wfst
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
unzip aishell_graph.zip
mv aishell_graph/* $wfst
popd
fi
# 5. test websocket server
websocket_server_main \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$wfst/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$wfst/TLG.fst --max_active=7500 \
--acoustic_scale=1.2
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
unset GREP_OPTIONS unset GREP_OPTIONS
SPEECHX_ROOT=$PWD/../../../ SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx/asr
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
...@@ -12,7 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin ...@@ -12,7 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/../common/frontend/audio:$SPEECHX_BUILD/recognizer
PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);") PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);")
export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
set(srcs) set(srcs)
if (USING_DS2)
list(APPEND srcs list(APPEND srcs
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_beam_search_decoder.cc
ctc_tlg_decoder.cc
)
endif()
if (USING_U2)
list(APPEND srcs
ctc_prefix_beam_search_decoder.cc ctc_prefix_beam_search_decoder.cc
) )
endif()
add_library(decoder STATIC ${srcs}) add_library(decoder STATIC ${srcs})
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder)
# test # test
if (USING_DS2) set(TEST_BINS
set(BINS
ctc_beam_search_decoder_main
nnet_logprob_decoder_main
ctc_tlg_decoder_main
)
foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
endforeach()
endif()
if (USING_U2)
set(TEST_BINS
ctc_prefix_beam_search_decoder_main ctc_prefix_beam_search_decoder_main
) )
foreach(bin_name IN LISTS TEST_BINS) foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endforeach() endforeach()
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h"
#include "base/common.h"
#include "decoder/ctc_decoders/decoder_utils.h"
#include "utils/file_utils.h"
namespace ppspeech {
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts), init_ext_scorer_(nullptr), space_id_(-1), root_(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed";
}
LOG(INFO) << "read the vocabulary success, dict size: "
<< vocabulary_.size();
LOG(INFO) << "language model path: " << opts_.lm_path;
if (opts_.lm_path != "") {
init_ext_scorer_ = std::make_shared<Scorer>(
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
}
CHECK_EQ(opts_.blank, 0);
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id_ = it - vocabulary_.begin();
// if no space in vocabulary
if (static_cast<size_t>(space_id_) >= vocabulary_.size()) {
space_id_ = -2;
}
}
void CTCBeamSearch::Reset() {
// num_frame_decoded_ = 0;
// ResetPrefixes();
InitDecoder();
}
void CTCBeamSearch::InitDecoder() {
num_frame_decoded_ = 0;
// ResetPrefixes();
prefixes_.clear();
root_ = std::make_shared<PathTrie>();
root_->score = root_->log_prob_b_prev = 0.0;
prefixes_.push_back(root_.get());
if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root_->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root_->set_matcher(matcher);
}
}
void CTCBeamSearch::Decode(
std::shared_ptr<kaldi::DecodableInterface> decodable) {
return;
}
// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (1) {
vector<vector<BaseFloat>> likelihood;
vector<BaseFloat> frame_prob;
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);
}
}
void CTCBeamSearch::ResetPrefixes() {
for (size_t i = 0; i < prefixes_.size(); i++) {
if (prefixes_[i] != nullptr) {
delete prefixes_[i];
prefixes_[i] = nullptr;
}
}
prefixes_.clear();
}
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
const vector<string>& nbest_words) {
kaldi::Timer timer;
AdvanceDecoding(probs);
LOG(INFO) << "ctc decoding elapsed time(s) "
<< static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0;
}
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath(int n) {
int beam_size = n == -1 ? opts_.beam_size : std::min(n, opts_.beam_size);
return get_beam_search_result(prefixes_, vocabulary_, beam_size);
}
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return GetNBestPath(-1);
}
string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
return result[0].second;
}
string CTCBeamSearch::GetFinalBestPath() {
CalculateApproxScore();
LMRescore();
return GetBestPath();
}
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
size_t num_time_steps = probs.size();
size_t beam_size = opts_.beam_size;
double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(),
vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size();
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
probs_seq[i][j] = static_cast<double>(probs[i][j]);
}
}
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
std::sort(prefixes_.begin(),
prefixes_.begin() + num_prefixes_,
prefix_compare);
if (num_prefixes_ == 0) {
continue;
}
min_cutoff = prefixes_[num_prefixes_ - 1]->score +
std::log(prob[opts_.blank]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes_ == beam_size);
}
vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
size_t log_prob_idx_len = log_prob_idx.size();
for (size_t index = 0; index < log_prob_idx_len; index++) {
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
}
prefixes_.clear();
// update log probs
root_->iterate_to_vec(prefixes_);
// only preserve top beam_size prefixes_
if (prefixes_.size() >= beam_size) {
std::nth_element(prefixes_.begin(),
prefixes_.begin() + beam_size,
prefixes_.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes_.size(); ++i) {
prefixes_[i]->remove();
}
} // end if
num_frame_decoded_++;
} // end for probs_seq
}
int32 CTCBeamSearch::SearchOneChar(
const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff) {
size_t beam_size = opts_.beam_size;
const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes_.size(), beam_size);
for (size_t i = 0; i < prefixes_len; ++i) {
auto prefix = prefixes_[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
if (c == opts_.blank) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id_ || init_ext_scorer_->is_character_based())) {
PathTrie* prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
vector<string> ngram;
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
// lm score: p_{lm}(W)^{\alpha} + \beta
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
log_p += score;
log_p += init_ext_scorer_->beta;
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
return 0;
}
void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = opts_.beam_size;
size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
std::sort(
prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare);
    // compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
double approx_ctc = prefixes_[i]->score;
if (init_ext_scorer_ != nullptr) {
vector<int> output;
prefixes_[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = init_ext_scorer_->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove language model weight:
approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
init_ext_scorer_->alpha;
}
prefixes_[i]->approx_ctc = approx_ctc;
}
}
void CTCBeamSearch::LMRescore() {
size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
auto prefix = prefixes_[i];
if (!prefix->is_empty() && prefix->character != space_id_) {
float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
score += init_ext_scorer_->beta;
prefix->score += score;
}
}
}
}
} // namespace ppspeech
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// used by deepspeech2
#pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/decoder_itf.h"
namespace ppspeech {
class CTCBeamSearch : public DecoderBase {
public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {}
void InitDecoder();
void Reset();
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath(int n);
std::string GetFinalBestPath();
std::string GetPartialResult() {
CHECK(false) << "Not implement.";
return {};
}
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
const std::vector<std::string>& nbest_words);
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<std::string> vocabulary_; // todo remove later
int space_id_;
std::shared_ptr<PathTrie> root_;
std::vector<PathTrie*> prefixes_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// used by deepspeech2
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "", "language model");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(subsampling_rate,
4,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK_NE(FLAGS_result_wspecifier, "");
CHECK_NE(FLAGS_feature_rspecifier, "");
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_path = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_path;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
decodable->Reset();
decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
return (num_done != 0 ? 0 : 1);
}
ThreadPool/
build/
dist/
kenlm/
openfst-1.6.3/
openfst-1.6.3.tar.gz
swig_decoders.egg-info/
decoders_wrap.cxx
swig_decoders.py
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ctc_beam_search_decoder.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <utility>
#include "ThreadPool.h"
#include "fst/fstlib.h"
#include "decoder_utils.h"
#include "path_trie.h"
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<std::pair<double, std::string>> ctc_beam_search_decoding(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
// vocabulary.size() + 1,
vocabulary.size(),
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
// assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// init prefixes' root
PathTrie root;
root.score = root.log_prob_b_prev = 0.0;
std::vector<PathTrie *> prefixes;
prefixes.push_back(&root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
}
// prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
auto &prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}
std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur =
log_sum_exp(prefix->log_prob_nb_cur,
log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
prefixes.clear();
// update log probs
root.iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
}
} // end of loop over time
// score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score += score;
}
}
}
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (ext_scorer != nullptr) {
std::vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -=
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
return get_beam_search_result(prefixes, vocabulary, beam_size);
}
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoding_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &vocabulary,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id) {
    VALID_CHECK_GT(num_processes, 0, "num_processes must be greater than 0!");
// thread pool
ThreadPool pool(num_processes);
// number of samples
size_t batch_size = probs_split.size();
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoding,
probs_split[i],
vocabulary,
beam_size,
cutoff_prob,
cutoff_top_n,
ext_scorer,
blank_id));
}
// get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}
void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer) {
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher);
}
}
void ctc_beam_search_decode_chunk(
PathTrie *root,
std::vector<PathTrie *> &prefixes,
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
// vocabulary.size() + 1,
vocabulary.size(),
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
// assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// init prefixes' root
//
// prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
auto &prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}
std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur =
log_sum_exp(prefix->log_prob_nb_cur,
log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
prefixes.clear();
// update log probs
root->iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
}
} // end of loop over time
return;
}
std::vector<std::pair<double, std::string>> get_decode_result(
std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size,
Scorer *ext_scorer) {
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score += score;
}
}
}
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
    // compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (ext_scorer != nullptr) {
std::vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -=
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
std::vector<std::pair<double, std::string>> res =
get_beam_search_result(prefixes, vocabulary, beam_size);
// pay back the last word of each prefix that doesn't end with space (for
// decoding by chunk)
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score -= score;
}
}
}
return res;
}
void free_storage(std::unique_ptr<CtcBeamSearchDecoderStorage> &storage) {
storage = nullptr;
}
CtcBeamSearchDecoderBatch::~CtcBeamSearchDecoderBatch() {}
CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch(
const std::vector<std::string> &vocabulary,
size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id)
: batch_size(batch_size),
beam_size(beam_size),
num_processes(num_processes),
cutoff_prob(cutoff_prob),
cutoff_top_n(cutoff_top_n),
ext_scorer(ext_scorer),
blank_id(blank_id) {
VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!");
VALID_CHECK_GT(
this->num_processes, 0, "num_processes must be nonnegative!");
this->vocabulary = vocabulary;
for (size_t i = 0; i < batch_size; i++) {
this->decoder_storage_vector.push_back(
std::unique_ptr<CtcBeamSearchDecoderStorage>(
new CtcBeamSearchDecoderStorage()));
ctc_beam_search_decode_chunk_begin(
this->decoder_storage_vector[i]->root, ext_scorer);
}
};
/**
* Input
* probs_split: shape [B, T, D]
*/
void CtcBeamSearchDecoderBatch::next(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &has_value) {
    VALID_CHECK_GT(num_processes, 0, "num_processes must be greater than 0!");
// thread pool
size_t num_has_value = 0;
for (int i = 0; i < has_value.size(); i++)
if (has_value[i] == "true") num_has_value += 1;
ThreadPool pool(std::min(num_processes, num_has_value));
// number of samples
size_t probs_num = probs_split.size();
VALID_CHECK_EQ(this->batch_size,
probs_num,
"The batch size of the current input data should be same "
"with the input data before");
// enqueue the tasks of decoding
std::vector<std::future<void>> res;
for (size_t i = 0; i < batch_size; ++i) {
if (has_value[i] == "true") {
res.emplace_back(pool.enqueue(
ctc_beam_search_decode_chunk,
std::ref(this->decoder_storage_vector[i]->root),
std::ref(this->decoder_storage_vector[i]->prefixes),
probs_split[i],
this->vocabulary,
this->beam_size,
this->cutoff_prob,
this->cutoff_top_n,
this->ext_scorer,
this->blank_id));
}
}
for (size_t i = 0; i < batch_size; ++i) {
res[i].get();
}
return;
};
/**
* Return
* batch_result: shape[B, beam_size,(-approx_ctc score, string)]
*/
std::vector<std::vector<std::pair<double, std::string>>>
CtcBeamSearchDecoderBatch::decode() {
VALID_CHECK_GT(
this->num_processes, 0, "num_processes must be nonnegative!");
// thread pool
ThreadPool pool(this->num_processes);
// number of samples
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < this->batch_size; ++i) {
res.emplace_back(
pool.enqueue(get_decode_result,
std::ref(this->decoder_storage_vector[i]->prefixes),
this->vocabulary,
this->beam_size,
this->ext_scorer));
}
// get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (size_t i = 0; i < this->batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}
/**
 * reset the state of CtcBeamSearchDecoderBatch
*/
void CtcBeamSearchDecoderBatch::reset_state(size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n) {
this->batch_size = batch_size;
this->beam_size = beam_size;
this->num_processes = num_processes;
this->cutoff_prob = cutoff_prob;
this->cutoff_top_n = cutoff_top_n;
VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!");
VALID_CHECK_GT(
this->num_processes, 0, "num_processes must be nonnegative!");
// thread pool
ThreadPool pool(this->num_processes);
// number of samples
// enqueue the tasks of decoding
std::vector<std::future<void>> res;
size_t storage_size = decoder_storage_vector.size();
for (size_t i = 0; i < storage_size; i++) {
res.emplace_back(pool.enqueue(
free_storage, std::ref(this->decoder_storage_vector[i])));
}
for (size_t i = 0; i < storage_size; ++i) {
res[i].get();
}
std::vector<std::unique_ptr<CtcBeamSearchDecoderStorage>>().swap(
decoder_storage_vector);
for (size_t i = 0; i < this->batch_size; i++) {
this->decoder_storage_vector.push_back(
std::unique_ptr<CtcBeamSearchDecoderStorage>(
new CtcBeamSearchDecoderStorage()));
ctc_beam_search_decode_chunk_begin(
this->decoder_storage_vector[i]->root, this->ext_scorer);
}
}
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#include <string>
#include <utility>
#include <vector>
#include "scorer.h"
/* CTC Beam Search Decoder
* Parameters:
 *     probs_seq: 2-D vector in which each element is a vector of probabilities
 *                over the vocabulary for one time step.
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
 *     A vector in which each element is a pair of score and decoding result,
 *     in descending order.
*/
std::vector<std::pair<double, std::string>> ctc_beam_search_decoding(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
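/* Usage sketch (the tiny vocabulary, the probabilities and the blank_id below
 * are made-up example values, not fixed by this API):
 *
 *   std::vector<std::vector<double>> probs_seq = {
 *       {0.1, 0.7, 0.2},   // t = 0: probs over {"a", "b", <blank>}
 *       {0.2, 0.1, 0.7}};  // t = 1
 *   std::vector<std::string> vocabulary = {"a", "b", "<blank>"};
 *   auto results = ctc_beam_search_decoding(probs_seq, vocabulary,
 *                                           10,        // beam_size
 *                                           1.0, 40,   // cutoff_prob, cutoff_top_n
 *                                           nullptr,   // no external scorer
 *                                           2);        // blank_id
 *   // results is sorted by score; results[0].second is the best transcript.
 */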
/* CTC Beam Search Decoder for batch data
* Parameters:
 *     probs_split: 3-D vector in which each element is a 2-D probability
 *                  matrix that can be passed to ctc_beam_search_decoding().
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
 *     A 2-D vector in which each element is the vector of beam search
 *     decoding results for one audio sample.
*/
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoding_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &vocabulary,
size_t beam_size,
size_t num_processes,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
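/* Batch usage sketch (reusing the made-up example values above): each entry of
 * probs_split is one utterance's probability matrix, and the utterances are
 * decoded on a num_processes-sized thread pool.
 *
 *   std::vector<std::vector<std::vector<double>>> probs_split = {probs_seq,
 *                                                                probs_seq};
 *   auto batch_results = ctc_beam_search_decoding_batch(
 *       probs_split, vocabulary, 10, 2, 1.0, 40, nullptr, 2);
 *   // batch_results[i] holds the beams for probs_split[i], in input order.
 */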
/**
* Store the root and prefixes for decoder
*/
class CtcBeamSearchDecoderStorage {
public:
PathTrie *root = nullptr;
std::vector<PathTrie *> prefixes;
CtcBeamSearchDecoderStorage() {
// init prefixes' root
this->root = new PathTrie();
this->root->log_prob_b_prev = 0.0;
        // The score of root is in log scale. Since prob = 1.0, the log-scale
        // score is 0.0.
this->root->score = root->log_prob_b_prev;
// std::vector<PathTrie *> prefixes;
this->prefixes.push_back(root);
};
~CtcBeamSearchDecoderStorage() {
if (root != nullptr) {
delete root;
root = nullptr;
}
};
};
/**
 * The CTC beam search decoder, supporting batch size >= 1.
*/
class CtcBeamSearchDecoderBatch {
public:
CtcBeamSearchDecoderBatch(const std::vector<std::string> &vocabulary,
size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id);
~CtcBeamSearchDecoderBatch();
void next(const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &has_value);
std::vector<std::vector<std::pair<double, std::string>>> decode();
void reset_state(size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n);
private:
std::vector<std::string> vocabulary;
size_t batch_size;
size_t beam_size;
size_t num_processes;
double cutoff_prob;
size_t cutoff_top_n;
Scorer *ext_scorer;
size_t blank_id;
std::vector<std::unique_ptr<CtcBeamSearchDecoderStorage>>
decoder_storage_vector;
};
/**
* function for chunk decoding
*/
void ctc_beam_search_decode_chunk(
PathTrie *root,
std::vector<PathTrie *> &prefixes,
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id);
std::vector<std::pair<double, std::string>> get_decode_result(
std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size,
Scorer *ext_scorer);
/**
* free the CtcBeamSearchDecoderStorage
*/
void free_storage(std::unique_ptr<CtcBeamSearchDecoderStorage> &storage);
/**
* initialize the root
*/
void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer);
#endif // CTC_BEAM_SEARCH_DECODER_H_
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ctc_greedy_decoder.h"
#include "decoder_utils.h"
std::string ctc_greedy_decoding(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t blank_id) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size(),
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
// size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec(num_time_steps, 0);
std::vector<size_t> idx_vec;
for (size_t i = 0; i < num_time_steps; ++i) {
double max_prob = 0.0;
size_t max_idx = 0;
const std::vector<double> &probs_step = probs_seq[i];
for (size_t j = 0; j < probs_step.size(); ++j) {
if (max_prob < probs_step[j]) {
max_idx = j;
max_prob = probs_step[j];
}
}
// id with maximum probability in current time step
max_idx_vec[i] = max_idx;
// deduplicate
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
idx_vec.push_back(max_idx_vec[i]);
}
}
std::string best_path_result;
for (size_t i = 0; i < idx_vec.size(); ++i) {
if (idx_vec[i] != blank_id) {
std::string ch = vocabulary[idx_vec[i]];
best_path_result += (ch == kSPACE) ? tSPACE : ch;
}
}
return best_path_result;
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CTC_GREEDY_DECODER_H
#define CTC_GREEDY_DECODER_H
#include <string>
#include <vector>
/* CTC Greedy (Best Path) Decoder
*
* Parameters:
 *     probs_seq: 2-D vector in which each element is a vector of
 *                probabilities over the vocabulary at one time step.
* vocabulary: A vector of vocabulary.
* Return:
 *     The decoding result as a string.
*/
std::string ctc_greedy_decoding(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary,
size_t blank_id);
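/* Worked example (made-up token ids): with vocabulary = {"a", "b", "<blank>"}
 * and blank_id = 2, the per-frame argmax sequence [0, 0, 2, 0, 1, 1] first
 * collapses repeats to [0, 2, 0, 1] and then drops the blank, giving "aab".
 */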
#endif // CTC_GREEDY_DECODER_H
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder_utils.h"
#include <algorithm>
#include <cmath>
#include <limits>
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step,
double cutoff_prob,
size_t cutoff_top_n) {
std::vector<std::pair<int, double>> prob_idx;
for (size_t i = 0; i < prob_step.size(); ++i) {
prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
}
// pruning of vocabulary
size_t cutoff_len = prob_step.size();
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort(prob_idx.begin(),
prob_idx.end(),
pair_comp_second_rev<int, double>);
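        // When cutoff_prob < 1.0, keep the highest-probability tokens whose
        // cumulative probability reaches cutoff_prob, capped at cutoff_top_n.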
if (cutoff_prob < 1.0) {
double cum_prob = 0.0;
cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n)
break;
}
}
prob_idx = std::vector<std::pair<int, double>>(
prob_idx.begin(), prob_idx.begin() + cutoff_len);
}
std::vector<std::pair<size_t, float>> log_prob_idx;
for (size_t i = 0; i < cutoff_len; ++i) {
log_prob_idx.push_back(std::pair<int, float>(
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
}
return log_prob_idx;
}
std::vector<std::pair<double, std::string>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size) {
// allow for the post processing
std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
space_prefixes.push_back(prefixes[i]);
}
}
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, std::string>> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
std::vector<int> output;
space_prefixes[i]->get_path_vec(output);
// convert index to string
std::string output_str;
for (size_t j = 0; j < output.size(); j++) {
std::string ch = vocabulary[output[j]];
output_str += (ch == kSPACE) ? tSPACE : ch;
}
std::pair<double, std::string> output_pair(
-space_prefixes[i]->approx_ctc, output_str);
output_vecs.emplace_back(output_pair);
}
return output_vecs;
}
size_t get_utf8_str_len(const std::string &str) {
size_t str_len = 0;
for (char c : str) {
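        // Count only UTF-8 lead bytes; continuation bytes match 10xxxxxx,
        // i.e. (c & 0xc0) == 0x80, so characters (not bytes) are counted.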
str_len += ((c & 0xc0) != 0x80);
}
return str_len;
}
std::vector<std::string> split_utf8_str(const std::string &str) {
std::vector<std::string> result;
std::string out_str;
for (char c : str) {
if ((c & 0xc0) != 0x80) // new UTF-8 character
{
if (!out_str.empty()) {
result.push_back(out_str);
out_str.clear();
}
}
out_str.append(1, c);
}
result.push_back(out_str);
return result;
}
std::vector<std::string> split_str(const std::string &s,
const std::string &delim) {
std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size();
while (true) {
std::size_t end = s.find(delim, start);
if (end == std::string::npos) {
if (start < s.size()) {
result.push_back(s.substr(start));
}
break;
}
if (end > start) {
result.push_back(s.substr(start, end - start));
}
start = end + delim_len;
}
return result;
}
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
if (x->score == y->score) {
if (x->character == y->character) {
return false;
} else {
return (x->character < y->character);
}
} else {
return x->score > y->score;
}
}
void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst *dictionary) {
if (dictionary->NumStates() == 0) {
fst::StdVectorFst::StateId start = dictionary->AddState();
assert(start == 0);
dictionary->SetStart(start);
}
fst::StdVectorFst::StateId src = dictionary->Start();
fst::StdVectorFst::StateId dst;
for (auto c : word) {
dst = dictionary->AddState();
dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
src = dst;
}
dictionary->SetFinal(dst, fst::StdArc::Weight::One());
}
bool add_word_to_dictionary(
const std::string &word,
const std::unordered_map<std::string, int> &char_map,
bool add_space,
int SPACE_ID,
fst::StdVectorFst *dictionary) {
auto characters = split_utf8_str(word);
std::vector<int> int_word;
for (auto &c : characters) {
if (c == " ") {
int_word.push_back(SPACE_ID);
} else {
auto int_c = char_map.find(c);
if (int_c != char_map.end()) {
int_word.push_back(int_c->second);
} else {
return false; // return without adding
}
}
}
if (add_space) {
int_word.push_back(SPACE_ID);
}
add_word_to_fst(int_word, dictionary);
return true; // return with successful adding
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_
#include <string>
#include <utility>
#include "fst/log.h"
#include "path_trie.h"
const std::string kSPACE = "<space>";
const std::string tSPACE = " ";
const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// inline function for validation check
inline void check(
bool x, const char *expr, const char *file, int line, const char *err) {
if (!x) {
std::cout << "[" << file << ":" << line << "] ";
LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
}
}
#define VALID_CHECK(x, info) \
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
// Function template for comparing two pairs
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) {
return a.first > b.first;
}
// Function template for comparing two pairs
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) {
return a.second > b.second;
}
// Return the sum of two probabilities in log scale
template <typename T>
T log_sum_exp(const T &x, const T &y) {
static T num_min = -std::numeric_limits<T>::max();
if (x <= num_min) return y;
if (y <= num_min) return x;
T xmax = std::max(x, y);
return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}
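// Example: log_sum_exp(std::log(0.3), std::log(0.2)) equals std::log(0.5);
// subtracting xmax before exponentiating keeps the exp() arguments <= 0 and
// avoids overflow.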
// Get pruned probability vector for each time step's beam search
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step,
double cutoff_prob,
size_t cutoff_top_n);
// Get beam search result from prefixes in trie tree
std::vector<std::pair<double, std::string>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size);
// Functor for prefix comparison
bool prefix_compare(const PathTrie *x, const PathTrie *y);
/* Get length of utf8 encoding string
* See: http://stackoverflow.com/a/4063229
*/
size_t get_utf8_str_len(const std::string &str);
/* Split a string into a list of strings on a given string
* delimiter. NB: delimiters on beginning / end of string are
* trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
*/
std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
/* Splits string into vector of strings representing
* UTF-8 characters (not same as chars)
*/
std::vector<std::string> split_utf8_str(const std::string &str);
// Add a word (as a vector of character indices) to the FST dictionary
void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst *dictionary);
// Add a word (as a string) to the dictionary
bool add_word_to_dictionary(
const std::string &word,
const std::unordered_map<std::string, int> &char_map,
bool add_space,
int SPACE_ID,
fst::StdVectorFst *dictionary);
#endif // DECODER_UTILS_H_
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "path_trie.h"
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "decoder_utils.h"
PathTrie::PathTrie() {
log_prob_b_prev = -NUM_FLT_INF;
log_prob_nb_prev = -NUM_FLT_INF;
log_prob_b_cur = -NUM_FLT_INF;
log_prob_nb_cur = -NUM_FLT_INF;
score = -NUM_FLT_INF;
ROOT_ = -1;
character = ROOT_;
exists_ = true;
parent = nullptr;
dictionary_ = nullptr;
dictionary_state_ = 0;
has_dictionary_ = false;
matcher_ = nullptr;
}
PathTrie::~PathTrie() {
for (auto child : children_) {
delete child.second;
child.second = nullptr;
}
}
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
auto child = children_.begin();
for (child = children_.begin(); child != children_.end(); ++child) {
if (child->first == new_char) {
break;
}
}
if (child != children_.end()) {
if (!child->second->exists_) {
child->second->exists_ = true;
child->second->log_prob_b_prev = -NUM_FLT_INF;
child->second->log_prob_nb_prev = -NUM_FLT_INF;
child->second->log_prob_b_cur = -NUM_FLT_INF;
child->second->log_prob_nb_cur = -NUM_FLT_INF;
}
return (child->second);
} else {
if (has_dictionary_) {
matcher_->SetState(dictionary_state_);
bool found = matcher_->Find(new_char + 1);
if (!found) {
// Adding this character causes word outside dictionary
auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = dictionary_->Final(dictionary_state_);
bool is_final = (final_weight != FSTZERO);
if (is_final && reset) {
dictionary_state_ = dictionary_->Start();
}
return nullptr;
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->parent = this;
new_path->dictionary_ = dictionary_;
new_path->dictionary_state_ = matcher_->Value().nextstate;
new_path->has_dictionary_ = true;
new_path->matcher_ = matcher_;
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->parent = this;
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
}
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
return get_path_vec(output, ROOT_);
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
int stop,
size_t max_steps) {
if (character == stop || character == ROOT_ || output.size() == max_steps) {
std::reverse(output.begin(), output.end());
return this;
} else {
output.push_back(character);
return parent->get_path_vec(output, stop, max_steps);
}
}
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
if (exists_) {
log_prob_b_prev = log_prob_b_cur;
log_prob_nb_prev = log_prob_nb_cur;
log_prob_b_cur = -NUM_FLT_INF;
log_prob_nb_cur = -NUM_FLT_INF;
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
output.push_back(this);
}
for (auto child : children_) {
child.second->iterate_to_vec(output);
}
}
void PathTrie::remove() {
exists_ = false;
if (children_.size() == 0) {
if (parent != nullptr) {
auto child = parent->children_.begin();
for (child = parent->children_.begin();
child != parent->children_.end();
++child) {
if (child->first == character) {
parent->children_.erase(child);
break;
}
}
if (parent->children_.size() == 0 && !parent->exists_) {
parent->remove();
}
}
delete this;
}
}
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
dictionary_ = dictionary;
dictionary_state_ = dictionary->Start();
has_dictionary_ = true;
}
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
matcher_ = matcher;
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "fst/fstlib.h"
/* Trie tree for storing and manipulating prefixes, with a dictionary in a
 * finite-state transducer for spelling correction.
*/
class PathTrie {
public:
PathTrie();
~PathTrie();
// get new prefix after appending new char
PathTrie* get_path_trie(int new_char, bool reset = true);
// get the prefix in index from root to current node
PathTrie* get_path_vec(std::vector<int>& output);
    // get the prefix (as indices) from some stop node to the current node
PathTrie* get_path_vec(
std::vector<int>& output,
int stop,
size_t max_steps = std::numeric_limits<size_t>::max());
// update log probs
void iterate_to_vec(std::vector<PathTrie*>& output);
// set dictionary for FST
void set_dictionary(fst::StdVectorFst* dictionary);
void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
bool is_empty() { return ROOT_ == character; }
// remove current path from root
void remove();
float log_prob_b_prev;
float log_prob_nb_prev;
float log_prob_b_cur;
float log_prob_nb_cur;
float score;
float approx_ctc;
int character;
PathTrie* parent;
private:
int ROOT_;
bool exists_;
bool has_dictionary_;
std::vector<std::pair<int, PathTrie*>> children_;
// pointer to dictionary of FST
fst::StdVectorFst* dictionary_;
fst::StdVectorFst::StateId dictionary_state_;
    // matcher used to find arcs in the FST
std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
};
#endif // PATH_TRIE_H
// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the
// "COPYING.LESSER.3");
#include "scorer.h"
#include <unistd.h>
#include <iostream>
#include "lm/config.hh"
#include "lm/model.hh"
#include "lm/state.hh"
#include "decoder_utils.h"
using namespace lm::ngram;
// If your platform is Windows, you need to add this define:
#define F_OK 0
Scorer::Scorer(double alpha,
double beta,
const std::string& lm_path,
const std::vector<std::string>& vocab_list) {
this->alpha = alpha;
this->beta = beta;
dictionary = nullptr;
is_character_based_ = true;
language_model_ = nullptr;
max_order_ = 0;
dict_size_ = 0;
SPACE_ID_ = -1;
setup(lm_path, vocab_list);
}
Scorer::~Scorer() {
if (language_model_ != nullptr) {
delete static_cast<lm::base::Model*>(language_model_);
}
if (dictionary != nullptr) {
delete static_cast<fst::StdVectorFst*>(dictionary);
}
}
void Scorer::setup(const std::string& lm_path,
const std::vector<std::string>& vocab_list) {
// load language model
load_lm(lm_path);
// set char map for scorer
set_char_map(vocab_list);
// fill the dictionary for FST
if (!is_character_based()) {
fill_dictionary(true);
}
}
void Scorer::load_lm(const std::string& lm_path) {
const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
RetriveStrEnumerateVocab enumerate;
lm::ngram::Config config;
config.enumerate_vocab = &enumerate;
language_model_ = lm::ngram::LoadVirtual(filename, config);
max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
vocabulary_ = enumerate.vocabulary;
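    // Heuristic (see the loop below): treat the LM as character based unless
    // its vocabulary contains a multi-character token other than <s>, </s>
    // and <unk>.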
for (size_t i = 0; i < vocabulary_.size(); ++i) {
if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
is_character_based_ = false;
}
}
}
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
double cond_prob;
lm::ngram::State state, tmp_state, out_state;
    // avoid inserting <s> at the beginning
model->NullContextWrite(&state);
for (size_t i = 0; i < words.size(); ++i) {
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
// encounter OOV
if (word_index == 0) {
return OOV_SCORE;
}
cond_prob = model->BaseScore(&state, word_index, &out_state);
tmp_state = state;
state = out_state;
out_state = tmp_state;
}
// return log10 prob
return cond_prob;
}
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
std::vector<std::string> sentence;
if (words.size() == 0) {
for (size_t i = 0; i < max_order_; ++i) {
sentence.push_back(START_TOKEN);
}
} else {
for (size_t i = 0; i < max_order_ - 1; ++i) {
sentence.push_back(START_TOKEN);
}
sentence.insert(sentence.end(), words.begin(), words.end());
}
sentence.push_back(END_TOKEN);
return get_log_prob(sentence);
}
double Scorer::get_log_prob(const std::vector<std::string>& words) {
assert(words.size() > max_order_);
double score = 0.0;
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
std::vector<std::string> ngram(words.begin() + i,
words.begin() + i + max_order_);
score += get_log_cond_prob(ngram);
}
return score;
}
void Scorer::reset_params(float alpha, float beta) {
this->alpha = alpha;
this->beta = beta;
}
std::string Scorer::vec2str(const std::vector<int>& input) {
std::string word;
for (auto ind : input) {
word += char_list_[ind];
}
return word;
}
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
if (labels.empty()) return {};
std::string s = vec2str(labels);
std::vector<std::string> words;
if (is_character_based_) {
words = split_utf8_str(s);
} else {
words = split_str(s, " ");
}
return words;
}
void Scorer::set_char_map(const std::vector<std::string>& char_list) {
char_list_ = char_list;
char_map_.clear();
// Set the char map for the FST for spelling correction
for (size_t i = 0; i < char_list_.size(); i++) {
if (char_list_[i] == kSPACE) {
SPACE_ID_ = i;
}
// The initial state of FST is state 0, hence the index of chars in
// the FST should start from 1 to avoid the conflict with the initial
// state, otherwise wrong decoding results would be given.
char_map_[char_list_[i]] = i + 1;
}
}
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<std::string> ngram;
PathTrie* current_node = prefix;
PathTrie* new_node = nullptr;
for (int order = 0; order < max_order_; order++) {
std::vector<int> prefix_vec;
if (is_character_based_) {
new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
current_node = new_node;
} else {
new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
current_node = new_node->parent; // Skipping spaces
}
// reconstruct word
std::string word = vec2str(prefix_vec);
ngram.push_back(word);
if (new_node->character == -1) {
// No more spaces, but still need order
for (int i = 0; i < max_order_ - order - 1; i++) {
ngram.push_back(START_TOKEN);
}
break;
}
}
std::reverse(ngram.begin(), ngram.end());
return ngram;
}
void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary;
// For each unigram convert to ints and put in trie
int dict_size = 0;
for (const auto& word : vocabulary_) {
bool added = add_word_to_dictionary(
word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
dict_size += added ? 1 : 0;
}
dict_size_ = dict_size;
/* Simplify FST
* This gets rid of "epsilon" transitions in the FST.
* These are transitions that don't require a string input to be taken.
* Getting rid of them is necessary to make the FST deterministic, but
* can greatly increase the size of the FST
*/
fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
/* This makes the FST deterministic, meaning for any string input there's
* only one possible state the FST could be in. It is assumed our
* dictionary is deterministic when using it.
* (lest we'd have to check for multiple transitions at each state)
*/
fst::Determinize(dictionary, new_dict);
/* Finds the simplest equivalent fst. This is unnecessary but decreases
* memory usage of the dictionary
*/
fst::Minimize(new_dict);
this->dictionary = new_dict;
}
// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the
// "COPYING.LESSER.3");
#ifndef SCORER_H_
#define SCORER_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "path_trie.h"
const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";
// Implement a callback to retrieve the vocabulary of the language model.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public:
RetriveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece &str) {
vocabulary.push_back(std::string(str.data(), str.length()));
}
std::vector<std::string> vocabulary;
};
/* External scorer to query score for n-gram or sentence, including language
* model scoring and word insertion.
*
* Example:
* Scorer scorer(alpha, beta, "path_of_language_model");
* scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/
class Scorer {
public:
Scorer(double alpha,
double beta,
const std::string &lm_path,
const std::vector<std::string> &vocabulary);
~Scorer();
double get_log_cond_prob(const std::vector<std::string> &words);
double get_sent_log_prob(const std::vector<std::string> &words);
// return the max order
size_t get_max_order() const { return max_order_; }
// return the dictionary size of language model
size_t get_dict_size() const { return dict_size_; }
    // return true if the language model is character based
bool is_character_based() const { return is_character_based_; }
// reset params alpha & beta
void reset_params(float alpha, float beta);
// make ngram for a given prefix
std::vector<std::string> make_ngram(PathTrie *prefix);
    // transform the labels (as indices) into a vector of words (word-based
    // LM) or a vector of characters (character-based LM)
std::vector<std::string> split_labels(const std::vector<int> &labels);
// language model weight
double alpha;
// word insertion weight
double beta;
// pointer to the dictionary of FST
void *dictionary;
protected:
// necessary setup: load language model, set char map, fill FST's dictionary
void setup(const std::string &lm_path,
const std::vector<std::string> &vocab_list);
// load language model from given path
void load_lm(const std::string &lm_path);
// fill dictionary for FST
void fill_dictionary(bool add_space);
// set char map
void set_char_map(const std::vector<std::string> &char_list);
double get_log_prob(const std::vector<std::string> &words);
// translate the vector in index to string
std::string vec2str(const std::vector<int> &input);
private:
void *language_model_;
bool is_character_based_;
size_t max_order_;
size_t dict_size_;
int SPACE_ID_;
std::vector<std::string> char_list_;
std::unordered_map<std::string, int> char_map_;
std::vector<std::string> vocabulary_;
};
#endif // SCORER_H_
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO: refactor, replace with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test decoder by feeding nnet posterior probability
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader likelihood_reader(
FLAGS_nnet_prob_respecifier);
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nullptr, nullptr));
decoder.InitDecoder();
for (; !likelihood_reader.Done(); likelihood_reader.Next()) {
string utt = likelihood_reader.Key();
const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value();
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << likelihood.NumRows();
LOG(INFO) << "cols: " << likelihood.NumCols();
decodable->Acceptlikelihood(likelihood);
decoder.AdvanceDecode(decodable);
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h" //#include "decoder/ctc_tlg_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
// feature // feature
DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
......
set(srcs decodable.cc nnet_producer.cc) set(srcs decodable.cc nnet_producer.cc)
if(USING_DS2) list(APPEND srcs u2_nnet.cc)
list(APPEND srcs ds2_nnet.cc)
endif()
if(USING_U2)
list(APPEND srcs u2_nnet.cc)
endif()
add_library(nnet STATIC ${srcs}) add_library(nnet STATIC ${srcs})
target_link_libraries(nnet utils) target_link_libraries(nnet utils)
if(USING_U2) target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
endif()
if(USING_DS2)
set(bin_name ds2_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
target_link_libraries(${bin_name} ${DEPS})
endif()
# test bin # test bin
#if(USING_U2) #if(USING_U2)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/ds2_nnet.h"
#include "utils/strings.h"
namespace ppspeech {
using kaldi::Matrix;
using kaldi::Vector;
using std::shared_ptr;
using std::string;
using std::vector;
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names;
cache_names = StrSplit(opts.cache_names, ",");
std::vector<std::string> cache_shapes;
cache_shapes = StrSplit(opts.cache_shape, ",");
assert(cache_shapes.size() == cache_names.size());
cache_encouts_.clear();
cache_names_idx_.clear();
for (size_t i = 0; i < cache_shapes.size(); i++) {
std::vector<std::string> tmp_shape;
tmp_shape = StrSplit(cache_shapes[i], "-");
std::vector<int> cur_shape;
std::transform(tmp_shape.begin(),
tmp_shape.end(),
std::back_inserter(cur_shape),
[](const std::string& s) { return atoi(s.c_str()); });
cache_names_idx_[cache_names[i]] = i;
std::shared_ptr<Tensor<BaseFloat>> cache_eout =
std::make_shared<Tensor<BaseFloat>>(cur_shape);
cache_encouts_.push_back(cache_eout);
}
}
PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
subsampling_rate_ = opts.subsample_rate;
paddle_infer::Config config;
config.SetModel(opts.model_path, opts.param_path);
if (opts.use_gpu) {
config.EnableUseGpu(500, 0);
}
config.SwitchIrOptim(opts.switch_ir_optim);
if (opts.enable_fc_padding == false) {
config.DisableFCPadding();
}
if (opts.enable_profile) {
config.EnableProfile();
}
pool.reset(
new paddle_infer::services::PredictorPool(config, opts.thread_num));
if (pool == nullptr) {
LOG(ERROR) << "create the predictor pool failed";
}
pool_usages.resize(opts.thread_num);
std::fill(pool_usages.begin(), pool_usages.end(), false);
LOG(INFO) << "load paddle model success";
LOG(INFO) << "start to check the predictor input and output names";
LOG(INFO) << "input names: " << opts.input_names;
LOG(INFO) << "output names: " << opts.output_names;
std::vector<std::string> input_names_vec = StrSplit(opts.input_names, ",");
std::vector<std::string> output_names_vec = StrSplit(opts.output_names, ",");
paddle_infer::Predictor* predictor = GetPredictor();
std::vector<std::string> model_input_names = predictor->GetInputNames();
assert(input_names_vec.size() == model_input_names.size());
for (size_t i = 0; i < model_input_names.size(); i++) {
assert(input_names_vec[i] == model_input_names[i]);
}
std::vector<std::string> model_output_names = predictor->GetOutputNames();
assert(output_names_vec.size() == model_output_names.size());
for (size_t i = 0; i < output_names_vec.size(); i++) {
assert(output_names_vec[i] == model_output_names[i]);
}
ReleasePredictor(predictor);
InitCacheEncouts(opts);
}
void PaddleNnet::Reset() { InitCacheEncouts(opts_); }
paddle_infer::Predictor* PaddleNnet::GetPredictor() {
paddle_infer::Predictor* predictor = nullptr;
std::lock_guard<std::mutex> guard(pool_mutex);
int pred_id = 0;
while (pred_id < pool_usages.size()) {
if (pool_usages[pred_id] == false) {
predictor = pool->Retrive(pred_id);
break;
}
++pred_id;
}
if (predictor) {
pool_usages[pred_id] = true;
predictor_to_thread_id[predictor] = pred_id;
} else {
LOG(INFO) << "Failed to get predictor from pool !!!";
}
return predictor;
}
int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
std::lock_guard<std::mutex> guard(pool_mutex);
auto iter = predictor_to_thread_id.find(predictor);
if (iter == predictor_to_thread_id.end()) {
LOG(INFO) << "there is no such predictor";
return 0;
}
pool_usages[iter->second] = false;
predictor_to_thread_id.erase(predictor);
return 0;
}
shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
auto iter = cache_names_idx_.find(name);
if (iter == cache_names_idx_.end()) {
return nullptr;
}
assert(iter->second < cache_encouts_.size());
return cache_encouts_[iter->second];
}
void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) {
paddle_infer::Predictor* predictor = GetPredictor();
int feat_row = features.Dim() / feature_dim;
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
// feed inputs
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, feat_row, feature_dim};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(features.Data());
std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(feat_row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> state_h =
predictor->GetInputHandle(input_names[2]);
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
state_h->Reshape(h_cache->get_shape());
state_h->CopyFromCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> state_c =
predictor->GetInputHandle(input_names[3]);
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
state_c->Reshape(c_cache->get_shape());
state_c->CopyFromCpu(c_cache->get_data().data());
// forward
bool success = predictor->Run();
if (success == false) {
LOG(INFO) << "predictor run occurs error";
}
    // fetch outputs
std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_out =
predictor->GetOutputHandle(output_names[3]);
assert(c_cache->get_shape() == c_out->shape());
c_out->CopyToCpu(c_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
int32 row = output_shape[1];
int32 col = output_shape[2];
// inferences->Resize(row * col);
// *inference_dim = col;
out->logprobs.Resize(row * col);
out->vocab_dim = col;
output_tensor->CopyToCpu(out->logprobs.Data());
ReleasePredictor(predictor);
}
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <numeric>
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "paddle_inference_api.h"
namespace ppspeech {
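// Minimal host-side buffer (shape + flat data) used below to hold the
// recurrent cache states that the DS2 model carries between chunks.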
template <typename T>
class Tensor {
public:
Tensor() {}
explicit Tensor(const std::vector<int>& shape) : _shape(shape) {
        int nelem = std::accumulate(
            _shape.begin(), _shape.end(), 1, std::multiplies<int>());
        LOG(INFO) << "Tensor nelem: " << nelem;
        _data.resize(nelem, 0);
}
void reshape(const std::vector<int>& shape) {
_shape = shape;
        int nelem = std::accumulate(
            _shape.begin(), _shape.end(), 1, std::multiplies<int>());
        _data.resize(nelem, 0);
}
const std::vector<int>& get_shape() const { return _shape; }
std::vector<T>& get_data() { return _data; }
private:
std::vector<int> _shape;
std::vector<T> _data;
};
class PaddleNnet : public NnetBase {
public:
explicit PaddleNnet(const ModelOptions& opts);
void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) override;
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) override {
        VLOG(2) << "deepspeech2 does not have AttentionRescoring.";
}
void Dim();
void Reset() override;
bool IsLogProb() override { return false; }
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name);
void InitCacheEncouts(const ModelOptions& opts);
void EncoderOuts(std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out)
const override {}
private:
paddle_infer::Predictor* GetPredictor();
int ReleasePredictor(paddle_infer::Predictor* predictor);
std::unique_ptr<paddle_infer::services::PredictorPool> pool;
std::vector<bool> pool_usages;
std::mutex pool_mutex;
std::map<paddle_infer::Predictor*, int> predictor_to_thread_id;
std::map<std::string, int> cache_names_idx_;
std::vector<std::shared_ptr<Tensor<kaldi::BaseFloat>>> cache_encouts_;
ModelOptions opts_;
public:
DISALLOW_COPY_AND_ASSIGN(PaddleNnet);
};
} // namespace ppspeech
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet prob wspecifier");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_writer(FLAGS_nnet_prob_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
int32 num_done = 0, num_err = 0;
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
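    // Example with illustrative flag values: a receptive field of 7 frames,
    // subsampling rate 4 and nnet_decoder_chunk 1 give chunk_size = 7 and
    // chunk_stride = 4, so consecutive chunks advance by 4 input frames.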
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
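        // Pad the feature matrix so that (rows - chunk_size) becomes a
        // multiple of chunk_stride, i.e. the utterance splits into whole
        // chunks.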
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
int32 frame_idx = 0;
std::vector<kaldi::Vector<kaldi::BaseFloat>> prob_vec;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
vector<kaldi::BaseFloat> prob;
while (decodable->FrameLikelihood(frame_idx, &prob)) {
kaldi::Vector<kaldi::BaseFloat> vec_tmp(prob.size());
std::memcpy(vec_tmp.Data(),
prob.data(),
sizeof(kaldi::BaseFloat) * prob.size());
prob_vec.push_back(vec_tmp);
frame_idx++;
}
}
decodable->Reset();
if (prob_vec.size() == 0) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the nnet prob of " << utt << " is empty";
continue;
}
kaldi::Matrix<kaldi::BaseFloat> result(prob_vec.size(),
prob_vec[0].Dim());
for (int row_idx = 0; row_idx < prob_vec.size(); ++row_idx) {
for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) {
result(row_idx, col_idx) = prob_vec[row_idx](col_idx);
}
}
nnet_writer.Write(utt, result);
++num_done;
}
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
...@@ -65,7 +65,6 @@ bool NnetProducer::Compute() { ...@@ -65,7 +65,6 @@ bool NnetProducer::Compute() {
size_t nframes = logprobs.Dim() / vocab_dim; size_t nframes = logprobs.Dim() / vocab_dim;
VLOG(2) << "Forward out " << nframes << " decoder frames."; VLOG(2) << "Forward out " << nframes << " decoder frames.";
std::vector<BaseFloat> logprob(vocab_dim); std::vector<BaseFloat> logprob(vocab_dim);
// remove later.
for (size_t idx = 0; idx < nframes; ++idx) { for (size_t idx = 0; idx < nframes; ++idx) {
for (size_t prob_idx = 0; prob_idx < vocab_dim; ++prob_idx) { for (size_t prob_idx = 0; prob_idx < vocab_dim; ++prob_idx) {
logprob[prob_idx] = logprobs(idx * vocab_dim + prob_idx); logprob[prob_idx] = logprobs(idx * vocab_dim + prob_idx);
......
set(srcs) set(srcs)
if (USING_DS2)
list(APPEND srcs list(APPEND srcs
recognizer.cc
)
endif()
if (USING_U2)
list(APPEND srcs
u2_recognizer.cc u2_recognizer.cc
) )
endif()
add_library(recognizer STATIC ${srcs}) add_library(recognizer STATIC ${srcs})
target_link_libraries(recognizer PUBLIC decoder) target_link_libraries(recognizer PUBLIC decoder)
# test set(TEST_BINS
if (USING_DS2)
set(BINS recognizer_main)
foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
endforeach()
endif()
if (USING_U2)
set(TEST_BINS
u2_recognizer_main u2_recognizer_main
u2_recognizer_thread_main u2_recognizer_thread_main
) )
foreach(bin_name IN LISTS TEST_BINS) foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endforeach() endforeach()
\ No newline at end of file
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer.h"
namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr;
using std::vector;
Recognizer::Recognizer(const RecognizerResource& resource) {
// resource_ = resource;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(resource.model_opts));
BaseFloat ac_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale));
decoder_.reset(new TLGDecoder(resource.tlg_opts));
input_finished_ = false;
}
void Recognizer::Accept(const Vector<BaseFloat>& waves) {
feature_pipeline_->Accept(waves);
}
void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); }
std::string Recognizer::GetFinalResult() {
return decoder_->GetFinalBestPath();
}
std::string Recognizer::GetPartialResult() {
return decoder_->GetPartialResult();
}
void Recognizer::SetFinished() {
feature_pipeline_->SetFinished();
input_finished_ = true;
}
bool Recognizer::IsFinished() { return input_finished_; }
void Recognizer::Reset() {
feature_pipeline_->Reset();
decodable_->Reset();
decoder_->Reset();
}
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
DECLARE_double(acoustic_scale);
namespace ppspeech {
struct RecognizerResource {
kaldi::BaseFloat acoustic_scale{1.0};
FeaturePipelineOptions feature_pipeline_opts{};
ModelOptions model_opts{};
TLGDecoderOptions tlg_opts{};
// CTCBeamSearchOptions beam_search_opts;
static RecognizerResource InitFromFlags() {
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts =
FeaturePipelineOptions::InitFromFlags();
resource.feature_pipeline_opts.assembler_opts.fill_zero = true;
        LOG(INFO) << "ds2 needs fill_zero to be true: "
<< resource.feature_pipeline_opts.assembler_opts.fill_zero;
resource.model_opts = ModelOptions::InitFromFlags();
resource.tlg_opts = TLGDecoderOptions::InitFromFlags();
return resource;
}
};
class Recognizer {
public:
    explicit Recognizer(const RecognizerResource& resource);
void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
void Decode();
std::string GetFinalResult();
std::string GetPartialResult();
void SetFinished();
bool IsFinished();
void Reset();
private:
// std::shared_ptr<RecognizerResource> resource_;
// RecognizerResource resource_;
std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<TLGDecoder> decoder_;
bool input_finished_;
};
} // namespace ppspeech
\ No newline at end of file
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "recognizer/recognizer.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
ppspeech::RecognizerResource resource =
ppspeech::RecognizerResource::InitFromFlags();
ppspeech::Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
kaldi::Timer timer;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
tot_wav_duration += tot_samples * 1.0 / sample_rate;
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished();
}
recognizer.Decode();
// no overlap
sample_offset += cur_chunk_size;
}
std::string result;
result = recognizer.GetFinalResult();
recognizer.Reset();
if (result.empty()) {
            // the TokenWriter cannot write an empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
double elapsed = timer.Elapsed();
KALDI_LOG << "Done " << num_done << " out of " << (num_err + num_done);
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s";
KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration;
}
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(glog)
add_subdirectory(nnet)
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ds2_model_test_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// deepspeech2 online model info
#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <thread>
#include "base/flags.h"
#include "base/log.h"
#include "paddle_inference_api.h"
using std::cout;
using std::endl;
DEFINE_string(model_path, "", "xxx.pdmodel");
DEFINE_string(param_path, "", "xxx.pdiparams");
DEFINE_int32(chunk_size, 35, "feature chunk size, unit:frame");
DEFINE_int32(feat_dim, 161, "feature dim");
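// Example run (the model/param paths are placeholders following the flag help
// strings above; the binary name matches the CMake target in this commit):
//   ./ds2_model_test_main --model_path=xxx.pdmodel --param_path=xxx.pdiparams \
//       --chunk_size=35 --feat_dim=161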
void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
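// produce_data() fills a chunk_size x feat_dim matrix with the constant 0.201;
// it is only a synthetic stand-in for real features, used to exercise the
// forward pass in model_forward_test().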
void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = FLAGS_chunk_size; // chunk_size in frame
int col_size = FLAGS_feat_dim; // feat dim
cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl;
    data->reserve(chunk_size);
    for (int row = 0; row < chunk_size; ++row) {
        data->push_back(std::vector<float>());
        // Reserve the row's capacity only after the row exists; calling
        // back() on an empty outer vector would be undefined behavior.
        data->back().reserve(col_size);
        for (int col_idx = 0; col_idx < col_size; ++col_idx) {
data->back().push_back(0.201);
}
}
}
void model_forward_test() {
std::cout << "1. read the data" << std::endl;
std::vector<std::vector<float>> feats;
produce_data(&feats);
std::cout << "2. load the model" << std::endl;
;
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
CHECK_NE(model_graph, "");
CHECK_NE(model_params, "");
cout << "model path: " << model_graph << endl;
cout << "model param path : " << model_params << endl;
paddle_infer::Config config;
config.SetModel(model_graph, model_params);
config.SwitchIrOptim(false);
cout << "SwitchIrOptim: " << false << endl;
config.DisableFCPadding();
cout << "DisableFCPadding: " << endl;
auto predictor = paddle_infer::CreatePredictor(config);
std::cout << "3. feat shape, row=" << feats.size()
<< ",col=" << feats[0].size() << std::endl;
std::vector<float> pp_input_mat;
for (const auto& item : feats) {
pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
}
std::cout << "4. fead the data to model" << std::endl;
int row = feats.size();
int col = feats[0].size();
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
for (auto name : input_names) {
cout << "model input names: " << name << endl;
}
for (auto name : output_names) {
cout << "model output names: " << name << endl;
}
// input
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(pp_input_mat.data());
// input length
std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
// state_h
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {5, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size =
std::accumulate(chunk_state_h_box_shape.begin(),
chunk_state_h_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
// state_c
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {5, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size =
std::accumulate(chunk_state_c_box_shape.begin(),
chunk_state_c_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
// run
bool success = predictor->Run();
// state_h out
std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]);
std::vector<int> h_out_shape = h_out->shape();
int h_out_size = std::accumulate(
h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> h_out_data(h_out_size);
h_out->CopyToCpu(h_out_data.data());
    // state_c out
std::unique_ptr<paddle_infer::Tensor> c_out =
predictor->GetOutputHandle(output_names[3]);
std::vector<int> c_out_shape = c_out->shape();
int c_out_size = std::accumulate(
c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> c_out_data(c_out_size);
c_out->CopyToCpu(c_out_data.data());
// output tensor
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
std::vector<float> output_probs;
int output_size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
output_probs.resize(output_size);
output_tensor->CopyToCpu(output_probs.data());
row = output_shape[1];
col = output_shape[2];
// probs
std::vector<std::vector<float>> probs;
probs.reserve(row);
for (int i = 0; i < row; i++) {
probs.push_back(std::vector<float>());
probs.back().reserve(col);
for (int j = 0; j < col; j++) {
probs.back().push_back(output_probs[i * col + j]);
}
}
std::vector<std::vector<float>> log_feat = probs;
std::cout << "probs, row: " << log_feat.size()
<< " col: " << log_feat[0].size() << std::endl;
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
++col_idx) {
std::cout << log_feat[row_idx][col_idx] << " ";
}
std::cout << std::endl;
}
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
model_forward_test();
return 0;
}
...@@ -20,15 +20,12 @@
 #include "kaldi/matrix/kaldi-matrix.h"
 #include "kaldi/util/kaldi-io.h"
 #include "utils/file_utils.h"
-// #include "boost/json.hpp"
-#include <boost/json/src.hpp>
+#include "utils/picojson.h"
 DEFINE_string(json_file, "", "cmvn json file");
 DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
 DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
-using namespace boost::json;  // from <boost/json.hpp>
 int main(int argc, char* argv[]) {
     gflags::SetUsageMessage("Usage:");
     gflags::ParseCommandLineFlags(&argc, &argv, false);
...@@ -40,36 +37,49 @@ int main(int argc, char* argv[]) {
     auto ifs = std::ifstream(FLAGS_json_file);
     std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file);
-    auto value = boost::json::parse(json_str);
-    if (!value.is_object()) {
+    picojson::value value;
+    std::string err;
+    const char* json_end = picojson::parse(
+        value, json_str.c_str(), json_str.c_str() + json_str.size(), &err);
+    if (!value.is<picojson::object>()) {
         LOG(ERROR) << "Input json file format error.";
     }
-    for (auto obj : value.as_object()) {
-        if (obj.key() == "mean_stat") {
-            VLOG(2) << "mean_stat:" << obj.value();
+    const picojson::value::object& obj = value.get<picojson::object>();
+    for (picojson::value::object::const_iterator elem = obj.begin();
+         elem != obj.end();
+         ++elem) {
+        if (elem->first == "mean_stat") {
+            VLOG(2) << "mean_stat:" << elem->second;
+            // const picojson::value tmp =
+            //     elem->second.get(0);  //<picojson::array>();
+            double tmp =
+                elem->second.get(0).get<double>();  //<picojson::array>();
+            VLOG(2) << "tmp: " << tmp;
         }
-        if (obj.key() == "var_stat") {
-            VLOG(2) << "var_stat: " << obj.value();
+        if (elem->first == "var_stat") {
+            VLOG(2) << "var_stat: " << elem->second;
        }
-        if (obj.key() == "frame_num") {
-            VLOG(2) << "frame_num: " << obj.value();
+        if (elem->first == "frame_num") {
+            VLOG(2) << "frame_num: " << elem->second;
        }
    }
-    boost::json::array mean_stat = value.at("mean_stat").as_array();
+    const picojson::value::array& mean_stat =
+        value.get("mean_stat").get<picojson::array>();
     std::vector<kaldi::BaseFloat> mean_stat_vec;
     for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
-        mean_stat_vec.push_back(it->as_double());
+        mean_stat_vec.push_back((*it).get<double>());
    }
-    boost::json::array var_stat = value.at("var_stat").as_array();
+    const picojson::value::array& var_stat =
+        value.get("var_stat").get<picojson::array>();
     std::vector<kaldi::BaseFloat> var_stat_vec;
     for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
-        var_stat_vec.push_back(it->as_double());
+        var_stat_vec.push_back((*it).get<double>());
    }
-    kaldi::int32 frame_num = uint64_t(value.at("frame_num").as_int64());
+    kaldi::int32 frame_num = value.get("frame_num").get<int64_t>();
     LOG(INFO) << "nframe: " << frame_num;
     size_t mean_size = mean_stat_vec.size();
......
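// A minimal, self-contained sketch of the picojson-based parsing shown in the
// diff above (illustrative only; it assumes a cmvn JSON object with
// "mean_stat", "var_stat" and "frame_num" fields and the vendored
// "utils/picojson.h" header below).
#include <cstdint>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "utils/picojson.h"

int main(int argc, char* argv[]) {
    if (argc < 2) {
        std::cerr << "usage: " << argv[0] << " cmvn.json" << std::endl;
        return 1;
    }
    std::ifstream ifs(argv[1]);
    std::stringstream buffer;
    buffer << ifs.rdbuf();

    picojson::value value;
    // parse() returns a non-empty error string on malformed input.
    std::string err = picojson::parse(value, buffer.str());
    if (!err.empty() || !value.is<picojson::object>()) {
        std::cerr << "cmvn json parse error: " << err << std::endl;
        return 1;
    }

    std::vector<double> mean_stat, var_stat;
    for (const auto& v : value.get("mean_stat").get<picojson::array>()) {
        mean_stat.push_back(v.get<double>());
    }
    for (const auto& v : value.get("var_stat").get<picojson::array>()) {
        var_stat.push_back(v.get<double>());
    }
    int64_t frame_num = value.get("frame_num").get<int64_t>();
    std::cout << "dim=" << mean_stat.size() << " frames=" << frame_num
              << std::endl;
    return 0;
}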
/*
* Copyright 2009-2010 Cybozu Labs, Inc.
* Copyright 2011-2014 Kazuho Oku
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef picojson_h
#define picojson_h
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
#include <utility>
#define PICOJSON_USE_INT64 1
// for isnan/isinf
#if __cplusplus >= 201103L
#include <cmath>
#else
extern "C" {
#ifdef _MSC_VER
#include <float.h>
#elif defined(__INTEL_COMPILER)
#include <mathimf.h>
#else
#include <math.h>
#endif
}
#endif
#ifndef PICOJSON_USE_RVALUE_REFERENCE
#if (defined(__cpp_rvalue_references) && __cpp_rvalue_references >= 200610) || (defined(_MSC_VER) && _MSC_VER >= 1600)
#define PICOJSON_USE_RVALUE_REFERENCE 1
#else
#define PICOJSON_USE_RVALUE_REFERENCE 0
#endif
#endif // PICOJSON_USE_RVALUE_REFERENCE
#ifndef PICOJSON_NOEXCEPT
#if PICOJSON_USE_RVALUE_REFERENCE
#define PICOJSON_NOEXCEPT noexcept
#else
#define PICOJSON_NOEXCEPT throw()
#endif
#endif
// experimental support for int64_t (see README.mkdn for detail)
#ifdef PICOJSON_USE_INT64
#define __STDC_FORMAT_MACROS
#include <cerrno>
#if __cplusplus >= 201103L
#include <cinttypes>
#else
extern "C" {
#include <inttypes.h>
}
#endif
#endif
// to disable the use of localeconv(3), set PICOJSON_USE_LOCALE to 0
#ifndef PICOJSON_USE_LOCALE
#define PICOJSON_USE_LOCALE 1
#endif
#if PICOJSON_USE_LOCALE
extern "C" {
#include <locale.h>
}
#endif
#ifndef PICOJSON_ASSERT
#define PICOJSON_ASSERT(e) \
do { \
if (!(e)) \
throw std::runtime_error(#e); \
} while (0)
#endif
#ifdef _MSC_VER
#define SNPRINTF _snprintf_s
#pragma warning(push)
#pragma warning(disable : 4244) // conversion from int to char
#pragma warning(disable : 4127) // conditional expression is constant
#pragma warning(disable : 4702) // unreachable code
#pragma warning(disable : 4706) // assignment within conditional expression
#else
#define SNPRINTF snprintf
#endif
namespace picojson {
enum {
null_type,
boolean_type,
number_type,
string_type,
array_type,
object_type
#ifdef PICOJSON_USE_INT64
,
int64_type
#endif
};
enum { INDENT_WIDTH = 2, DEFAULT_MAX_DEPTHS = 100 };
struct null {};
class value {
public:
typedef std::vector<value> array;
typedef std::map<std::string, value> object;
union _storage {
bool boolean_;
double number_;
#ifdef PICOJSON_USE_INT64
int64_t int64_;
#endif
std::string *string_;
array *array_;
object *object_;
};
protected:
int type_;
_storage u_;
public:
value();
value(int type, bool);
explicit value(bool b);
#ifdef PICOJSON_USE_INT64
explicit value(int64_t i);
#endif
explicit value(double n);
explicit value(const std::string &s);
explicit value(const array &a);
explicit value(const object &o);
#if PICOJSON_USE_RVALUE_REFERENCE
explicit value(std::string &&s);
explicit value(array &&a);
explicit value(object &&o);
#endif
explicit value(const char *s);
value(const char *s, size_t len);
~value();
value(const value &x);
value &operator=(const value &x);
#if PICOJSON_USE_RVALUE_REFERENCE
value(value &&x) PICOJSON_NOEXCEPT;
value &operator=(value &&x) PICOJSON_NOEXCEPT;
#endif
void swap(value &x) PICOJSON_NOEXCEPT;
template <typename T> bool is() const;
template <typename T> const T &get() const;
template <typename T> T &get();
template <typename T> void set(const T &);
#if PICOJSON_USE_RVALUE_REFERENCE
template <typename T> void set(T &&);
#endif
bool evaluate_as_boolean() const;
const value &get(const size_t idx) const;
const value &get(const std::string &key) const;
value &get(const size_t idx);
value &get(const std::string &key);
bool contains(const size_t idx) const;
bool contains(const std::string &key) const;
std::string to_str() const;
template <typename Iter> void serialize(Iter os, bool prettify = false) const;
std::string serialize(bool prettify = false) const;
private:
template <typename T> value(const T *); // intentionally defined to block implicit conversion of pointer to bool
template <typename Iter> static void _indent(Iter os, int indent);
template <typename Iter> void _serialize(Iter os, int indent) const;
std::string _serialize(int indent) const;
void clear();
};
typedef value::array array;
typedef value::object object;
inline value::value() : type_(null_type), u_() {
}
inline value::value(int type, bool) : type_(type), u_() {
switch (type) {
#define INIT(p, v) \
case p##type: \
u_.p = v; \
break
INIT(boolean_, false);
INIT(number_, 0.0);
#ifdef PICOJSON_USE_INT64
INIT(int64_, 0);
#endif
INIT(string_, new std::string());
INIT(array_, new array());
INIT(object_, new object());
#undef INIT
default:
break;
}
}
inline value::value(bool b) : type_(boolean_type), u_() {
u_.boolean_ = b;
}
#ifdef PICOJSON_USE_INT64
inline value::value(int64_t i) : type_(int64_type), u_() {
u_.int64_ = i;
}
#endif
inline value::value(double n) : type_(number_type), u_() {
if (
#ifdef _MSC_VER
!_finite(n)
#elif __cplusplus >= 201103L
std::isnan(n) || std::isinf(n)
#else
isnan(n) || isinf(n)
#endif
) {
throw std::overflow_error("");
}
u_.number_ = n;
}
inline value::value(const std::string &s) : type_(string_type), u_() {
u_.string_ = new std::string(s);
}
inline value::value(const array &a) : type_(array_type), u_() {
u_.array_ = new array(a);
}
inline value::value(const object &o) : type_(object_type), u_() {
u_.object_ = new object(o);
}
#if PICOJSON_USE_RVALUE_REFERENCE
inline value::value(std::string &&s) : type_(string_type), u_() {
u_.string_ = new std::string(std::move(s));
}
inline value::value(array &&a) : type_(array_type), u_() {
u_.array_ = new array(std::move(a));
}
inline value::value(object &&o) : type_(object_type), u_() {
u_.object_ = new object(std::move(o));
}
#endif
inline value::value(const char *s) : type_(string_type), u_() {
u_.string_ = new std::string(s);
}
inline value::value(const char *s, size_t len) : type_(string_type), u_() {
u_.string_ = new std::string(s, len);
}
inline void value::clear() {
switch (type_) {
#define DEINIT(p) \
case p##type: \
delete u_.p; \
break
DEINIT(string_);
DEINIT(array_);
DEINIT(object_);
#undef DEINIT
default:
break;
}
}
inline value::~value() {
clear();
}
inline value::value(const value &x) : type_(x.type_), u_() {
switch (type_) {
#define INIT(p, v) \
case p##type: \
u_.p = v; \
break
INIT(string_, new std::string(*x.u_.string_));
INIT(array_, new array(*x.u_.array_));
INIT(object_, new object(*x.u_.object_));
#undef INIT
default:
u_ = x.u_;
break;
}
}
inline value &value::operator=(const value &x) {
if (this != &x) {
value t(x);
swap(t);
}
return *this;
}
#if PICOJSON_USE_RVALUE_REFERENCE
inline value::value(value &&x) PICOJSON_NOEXCEPT : type_(null_type), u_() {
swap(x);
}
inline value &value::operator=(value &&x) PICOJSON_NOEXCEPT {
swap(x);
return *this;
}
#endif
inline void value::swap(value &x) PICOJSON_NOEXCEPT {
std::swap(type_, x.type_);
std::swap(u_, x.u_);
}
#define IS(ctype, jtype) \
template <> inline bool value::is<ctype>() const { \
return type_ == jtype##_type; \
}
IS(null, null)
IS(bool, boolean)
#ifdef PICOJSON_USE_INT64
IS(int64_t, int64)
#endif
IS(std::string, string)
IS(array, array)
IS(object, object)
#undef IS
template <> inline bool value::is<double>() const {
return type_ == number_type
#ifdef PICOJSON_USE_INT64
|| type_ == int64_type
#endif
;
}
#define GET(ctype, var) \
template <> inline const ctype &value::get<ctype>() const { \
PICOJSON_ASSERT("type mismatch! call is<type>() before get<type>()" && is<ctype>()); \
return var; \
} \
template <> inline ctype &value::get<ctype>() { \
PICOJSON_ASSERT("type mismatch! call is<type>() before get<type>()" && is<ctype>()); \
return var; \
}
GET(bool, u_.boolean_)
GET(std::string, *u_.string_)
GET(array, *u_.array_)
GET(object, *u_.object_)
#ifdef PICOJSON_USE_INT64
GET(double,
(type_ == int64_type && (const_cast<value *>(this)->type_ = number_type, (const_cast<value *>(this)->u_.number_ = u_.int64_)),
u_.number_))
GET(int64_t, u_.int64_)
#else
GET(double, u_.number_)
#endif
#undef GET
#define SET(ctype, jtype, setter) \
template <> inline void value::set<ctype>(const ctype &_val) { \
clear(); \
type_ = jtype##_type; \
setter \
}
SET(bool, boolean, u_.boolean_ = _val;)
SET(std::string, string, u_.string_ = new std::string(_val);)
SET(array, array, u_.array_ = new array(_val);)
SET(object, object, u_.object_ = new object(_val);)
SET(double, number, u_.number_ = _val;)
#ifdef PICOJSON_USE_INT64
SET(int64_t, int64, u_.int64_ = _val;)
#endif
#undef SET
#if PICOJSON_USE_RVALUE_REFERENCE
#define MOVESET(ctype, jtype, setter) \
template <> inline void value::set<ctype>(ctype && _val) { \
clear(); \
type_ = jtype##_type; \
setter \
}
MOVESET(std::string, string, u_.string_ = new std::string(std::move(_val));)
MOVESET(array, array, u_.array_ = new array(std::move(_val));)
MOVESET(object, object, u_.object_ = new object(std::move(_val));)
#undef MOVESET
#endif
inline bool value::evaluate_as_boolean() const {
switch (type_) {
case null_type:
return false;
case boolean_type:
return u_.boolean_;
case number_type:
return u_.number_ != 0;
#ifdef PICOJSON_USE_INT64
case int64_type:
return u_.int64_ != 0;
#endif
case string_type:
return !u_.string_->empty();
default:
return true;
}
}
inline const value &value::get(const size_t idx) const {
static value s_null;
PICOJSON_ASSERT(is<array>());
return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null;
}
inline value &value::get(const size_t idx) {
static value s_null;
PICOJSON_ASSERT(is<array>());
return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null;
}
inline const value &value::get(const std::string &key) const {
static value s_null;
PICOJSON_ASSERT(is<object>());
object::const_iterator i = u_.object_->find(key);
return i != u_.object_->end() ? i->second : s_null;
}
inline value &value::get(const std::string &key) {
static value s_null;
PICOJSON_ASSERT(is<object>());
object::iterator i = u_.object_->find(key);
return i != u_.object_->end() ? i->second : s_null;
}
inline bool value::contains(const size_t idx) const {
PICOJSON_ASSERT(is<array>());
return idx < u_.array_->size();
}
inline bool value::contains(const std::string &key) const {
PICOJSON_ASSERT(is<object>());
object::const_iterator i = u_.object_->find(key);
return i != u_.object_->end();
}
inline std::string value::to_str() const {
switch (type_) {
case null_type:
return "null";
case boolean_type:
return u_.boolean_ ? "true" : "false";
#ifdef PICOJSON_USE_INT64
case int64_type: {
char buf[sizeof("-9223372036854775808")];
SNPRINTF(buf, sizeof(buf), "%" PRId64, u_.int64_);
return buf;
}
#endif
case number_type: {
char buf[256];
double tmp;
SNPRINTF(buf, sizeof(buf), fabs(u_.number_) < (1ULL << 53) && modf(u_.number_, &tmp) == 0 ? "%.f" : "%.17g", u_.number_);
#if PICOJSON_USE_LOCALE
char *decimal_point = localeconv()->decimal_point;
if (strcmp(decimal_point, ".") != 0) {
size_t decimal_point_len = strlen(decimal_point);
for (char *p = buf; *p != '\0'; ++p) {
if (strncmp(p, decimal_point, decimal_point_len) == 0) {
return std::string(buf, p) + "." + (p + decimal_point_len);
}
}
}
#endif
return buf;
}
case string_type:
return *u_.string_;
case array_type:
return "array";
case object_type:
return "object";
default:
PICOJSON_ASSERT(0);
#ifdef _MSC_VER
__assume(0);
#endif
}
return std::string();
}
template <typename Iter> void copy(const std::string &s, Iter oi) {
std::copy(s.begin(), s.end(), oi);
}
template <typename Iter> struct serialize_str_char {
Iter oi;
void operator()(char c) {
switch (c) {
#define MAP(val, sym) \
case val: \
copy(sym, oi); \
break
MAP('"', "\\\"");
MAP('\\', "\\\\");
MAP('/', "\\/");
MAP('\b', "\\b");
MAP('\f', "\\f");
MAP('\n', "\\n");
MAP('\r', "\\r");
MAP('\t', "\\t");
#undef MAP
default:
if (static_cast<unsigned char>(c) < 0x20 || c == 0x7f) {
char buf[7];
SNPRINTF(buf, sizeof(buf), "\\u%04x", c & 0xff);
copy(buf, buf + 6, oi);
} else {
*oi++ = c;
}
break;
}
}
};
template <typename Iter> void serialize_str(const std::string &s, Iter oi) {
*oi++ = '"';
serialize_str_char<Iter> process_char = {oi};
std::for_each(s.begin(), s.end(), process_char);
*oi++ = '"';
}
template <typename Iter> void value::serialize(Iter oi, bool prettify) const {
return _serialize(oi, prettify ? 0 : -1);
}
inline std::string value::serialize(bool prettify) const {
return _serialize(prettify ? 0 : -1);
}
template <typename Iter> void value::_indent(Iter oi, int indent) {
*oi++ = '\n';
for (int i = 0; i < indent * INDENT_WIDTH; ++i) {
*oi++ = ' ';
}
}
template <typename Iter> void value::_serialize(Iter oi, int indent) const {
switch (type_) {
case string_type:
serialize_str(*u_.string_, oi);
break;
case array_type: {
*oi++ = '[';
if (indent != -1) {
++indent;
}
for (array::const_iterator i = u_.array_->begin(); i != u_.array_->end(); ++i) {
if (i != u_.array_->begin()) {
*oi++ = ',';
}
if (indent != -1) {
_indent(oi, indent);
}
i->_serialize(oi, indent);
}
if (indent != -1) {
--indent;
if (!u_.array_->empty()) {
_indent(oi, indent);
}
}
*oi++ = ']';
break;
}
case object_type: {
*oi++ = '{';
if (indent != -1) {
++indent;
}
for (object::const_iterator i = u_.object_->begin(); i != u_.object_->end(); ++i) {
if (i != u_.object_->begin()) {
*oi++ = ',';
}
if (indent != -1) {
_indent(oi, indent);
}
serialize_str(i->first, oi);
*oi++ = ':';
if (indent != -1) {
*oi++ = ' ';
}
i->second._serialize(oi, indent);
}
if (indent != -1) {
--indent;
if (!u_.object_->empty()) {
_indent(oi, indent);
}
}
*oi++ = '}';
break;
}
default:
copy(to_str(), oi);
break;
}
if (indent == 0) {
*oi++ = '\n';
}
}
inline std::string value::_serialize(int indent) const {
std::string s;
_serialize(std::back_inserter(s), indent);
return s;
}
template <typename Iter> class input {
protected:
Iter cur_, end_;
bool consumed_;
int line_;
public:
input(const Iter &first, const Iter &last) : cur_(first), end_(last), consumed_(false), line_(1) {
}
int getc() {
if (consumed_) {
if (*cur_ == '\n') {
++line_;
}
++cur_;
}
if (cur_ == end_) {
consumed_ = false;
return -1;
}
consumed_ = true;
return *cur_ & 0xff;
}
void ungetc() {
consumed_ = false;
}
Iter cur() const {
if (consumed_) {
input<Iter> *self = const_cast<input<Iter> *>(this);
self->consumed_ = false;
++self->cur_;
}
return cur_;
}
int line() const {
return line_;
}
void skip_ws() {
while (1) {
int ch = getc();
if (!(ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) {
ungetc();
break;
}
}
}
bool expect(const int expected) {
skip_ws();
if (getc() != expected) {
ungetc();
return false;
}
return true;
}
bool match(const std::string &pattern) {
for (std::string::const_iterator pi(pattern.begin()); pi != pattern.end(); ++pi) {
if (getc() != *pi) {
ungetc();
return false;
}
}
return true;
}
};
template <typename Iter> inline int _parse_quadhex(input<Iter> &in) {
int uni_ch = 0, hex;
for (int i = 0; i < 4; i++) {
if ((hex = in.getc()) == -1) {
return -1;
}
if ('0' <= hex && hex <= '9') {
hex -= '0';
} else if ('A' <= hex && hex <= 'F') {
hex -= 'A' - 0xa;
} else if ('a' <= hex && hex <= 'f') {
hex -= 'a' - 0xa;
} else {
in.ungetc();
return -1;
}
uni_ch = uni_ch * 16 + hex;
}
return uni_ch;
}
template <typename String, typename Iter> inline bool _parse_codepoint(String &out, input<Iter> &in) {
int uni_ch;
if ((uni_ch = _parse_quadhex(in)) == -1) {
return false;
}
if (0xd800 <= uni_ch && uni_ch <= 0xdfff) {
if (0xdc00 <= uni_ch) {
// a second 16-bit of a surrogate pair appeared
return false;
}
// first 16-bit of surrogate pair, get the next one
if (in.getc() != '\\' || in.getc() != 'u') {
in.ungetc();
return false;
}
int second = _parse_quadhex(in);
if (!(0xdc00 <= second && second <= 0xdfff)) {
return false;
}
uni_ch = ((uni_ch - 0xd800) << 10) | ((second - 0xdc00) & 0x3ff);
uni_ch += 0x10000;
}
if (uni_ch < 0x80) {
out.push_back(static_cast<char>(uni_ch));
} else {
if (uni_ch < 0x800) {
out.push_back(static_cast<char>(0xc0 | (uni_ch >> 6)));
} else {
if (uni_ch < 0x10000) {
out.push_back(static_cast<char>(0xe0 | (uni_ch >> 12)));
} else {
out.push_back(static_cast<char>(0xf0 | (uni_ch >> 18)));
out.push_back(static_cast<char>(0x80 | ((uni_ch >> 12) & 0x3f)));
}
out.push_back(static_cast<char>(0x80 | ((uni_ch >> 6) & 0x3f)));
}
out.push_back(static_cast<char>(0x80 | (uni_ch & 0x3f)));
}
return true;
}
template <typename String, typename Iter> inline bool _parse_string(String &out, input<Iter> &in) {
while (1) {
int ch = in.getc();
if (ch < ' ') {
in.ungetc();
return false;
} else if (ch == '"') {
return true;
} else if (ch == '\\') {
if ((ch = in.getc()) == -1) {
return false;
}
switch (ch) {
#define MAP(sym, val) \
case sym: \
out.push_back(val); \
break
MAP('"', '\"');
MAP('\\', '\\');
MAP('/', '/');
MAP('b', '\b');
MAP('f', '\f');
MAP('n', '\n');
MAP('r', '\r');
MAP('t', '\t');
#undef MAP
case 'u':
if (!_parse_codepoint(out, in)) {
return false;
}
break;
default:
return false;
}
} else {
out.push_back(static_cast<char>(ch));
}
}
return false;
}
template <typename Context, typename Iter> inline bool _parse_array(Context &ctx, input<Iter> &in) {
if (!ctx.parse_array_start()) {
return false;
}
size_t idx = 0;
if (in.expect(']')) {
return ctx.parse_array_stop(idx);
}
do {
if (!ctx.parse_array_item(in, idx)) {
return false;
}
idx++;
} while (in.expect(','));
return in.expect(']') && ctx.parse_array_stop(idx);
}
template <typename Context, typename Iter> inline bool _parse_object(Context &ctx, input<Iter> &in) {
if (!ctx.parse_object_start()) {
return false;
}
if (in.expect('}')) {
return ctx.parse_object_stop();
}
do {
std::string key;
if (!in.expect('"') || !_parse_string(key, in) || !in.expect(':')) {
return false;
}
if (!ctx.parse_object_item(in, key)) {
return false;
}
} while (in.expect(','));
return in.expect('}') && ctx.parse_object_stop();
}
template <typename Iter> inline std::string _parse_number(input<Iter> &in) {
std::string num_str;
while (1) {
int ch = in.getc();
if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == 'e' || ch == 'E') {
num_str.push_back(static_cast<char>(ch));
} else if (ch == '.') {
#if PICOJSON_USE_LOCALE
num_str += localeconv()->decimal_point;
#else
num_str.push_back('.');
#endif
} else {
in.ungetc();
break;
}
}
return num_str;
}
template <typename Context, typename Iter> inline bool _parse(Context &ctx, input<Iter> &in) {
in.skip_ws();
int ch = in.getc();
switch (ch) {
#define IS(ch, text, op) \
case ch: \
if (in.match(text) && op) { \
return true; \
} else { \
return false; \
}
IS('n', "ull", ctx.set_null());
IS('f', "alse", ctx.set_bool(false));
IS('t', "rue", ctx.set_bool(true));
#undef IS
case '"':
return ctx.parse_string(in);
case '[':
return _parse_array(ctx, in);
case '{':
return _parse_object(ctx, in);
default:
if (('0' <= ch && ch <= '9') || ch == '-') {
double f;
char *endp;
in.ungetc();
std::string num_str(_parse_number(in));
if (num_str.empty()) {
return false;
}
#ifdef PICOJSON_USE_INT64
{
errno = 0;
intmax_t ival = strtoimax(num_str.c_str(), &endp, 10);
if (errno == 0 && std::numeric_limits<int64_t>::min() <= ival && ival <= std::numeric_limits<int64_t>::max() &&
endp == num_str.c_str() + num_str.size()) {
ctx.set_int64(ival);
return true;
}
}
#endif
f = strtod(num_str.c_str(), &endp);
if (endp == num_str.c_str() + num_str.size()) {
ctx.set_number(f);
return true;
}
return false;
}
break;
}
in.ungetc();
return false;
}
class deny_parse_context {
public:
bool set_null() {
return false;
}
bool set_bool(bool) {
return false;
}
#ifdef PICOJSON_USE_INT64
bool set_int64(int64_t) {
return false;
}
#endif
bool set_number(double) {
return false;
}
template <typename Iter> bool parse_string(input<Iter> &) {
return false;
}
bool parse_array_start() {
return false;
}
template <typename Iter> bool parse_array_item(input<Iter> &, size_t) {
return false;
}
bool parse_array_stop(size_t) {
return false;
}
bool parse_object_start() {
return false;
}
template <typename Iter> bool parse_object_item(input<Iter> &, const std::string &) {
return false;
}
};
class default_parse_context {
protected:
value *out_;
size_t depths_;
public:
default_parse_context(value *out, size_t depths = DEFAULT_MAX_DEPTHS) : out_(out), depths_(depths) {
}
bool set_null() {
*out_ = value();
return true;
}
bool set_bool(bool b) {
*out_ = value(b);
return true;
}
#ifdef PICOJSON_USE_INT64
bool set_int64(int64_t i) {
*out_ = value(i);
return true;
}
#endif
bool set_number(double f) {
*out_ = value(f);
return true;
}
template <typename Iter> bool parse_string(input<Iter> &in) {
*out_ = value(string_type, false);
return _parse_string(out_->get<std::string>(), in);
}
bool parse_array_start() {
if (depths_ == 0)
return false;
--depths_;
*out_ = value(array_type, false);
return true;
}
template <typename Iter> bool parse_array_item(input<Iter> &in, size_t) {
array &a = out_->get<array>();
a.push_back(value());
default_parse_context ctx(&a.back(), depths_);
return _parse(ctx, in);
}
bool parse_array_stop(size_t) {
++depths_;
return true;
}
bool parse_object_start() {
if (depths_ == 0)
return false;
*out_ = value(object_type, false);
return true;
}
template <typename Iter> bool parse_object_item(input<Iter> &in, const std::string &key) {
object &o = out_->get<object>();
default_parse_context ctx(&o[key], depths_);
return _parse(ctx, in);
}
bool parse_object_stop() {
++depths_;
return true;
}
private:
default_parse_context(const default_parse_context &);
default_parse_context &operator=(const default_parse_context &);
};
class null_parse_context {
protected:
size_t depths_;
public:
struct dummy_str {
void push_back(int) {
}
};
public:
null_parse_context(size_t depths = DEFAULT_MAX_DEPTHS) : depths_(depths) {
}
bool set_null() {
return true;
}
bool set_bool(bool) {
return true;
}
#ifdef PICOJSON_USE_INT64
bool set_int64(int64_t) {
return true;
}
#endif
bool set_number(double) {
return true;
}
template <typename Iter> bool parse_string(input<Iter> &in) {
dummy_str s;
return _parse_string(s, in);
}
bool parse_array_start() {
if (depths_ == 0)
return false;
--depths_;
return true;
}
template <typename Iter> bool parse_array_item(input<Iter> &in, size_t) {
return _parse(*this, in);
}
bool parse_array_stop(size_t) {
++depths_;
return true;
}
bool parse_object_start() {
if (depths_ == 0)
return false;
--depths_;
return true;
}
template <typename Iter> bool parse_object_item(input<Iter> &in, const std::string &) {
++depths_;
return _parse(*this, in);
}
bool parse_object_stop() {
return true;
}
private:
null_parse_context(const null_parse_context &);
null_parse_context &operator=(const null_parse_context &);
};
// obsolete, use the version below
template <typename Iter> inline std::string parse(value &out, Iter &pos, const Iter &last) {
std::string err;
pos = parse(out, pos, last, &err);
return err;
}
template <typename Context, typename Iter> inline Iter _parse(Context &ctx, const Iter &first, const Iter &last, std::string *err) {
input<Iter> in(first, last);
if (!_parse(ctx, in) && err != NULL) {
char buf[64];
SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line());
*err = buf;
while (1) {
int ch = in.getc();
if (ch == -1 || ch == '\n') {
break;
} else if (ch >= ' ') {
err->push_back(static_cast<char>(ch));
}
}
}
return in.cur();
}
template <typename Iter> inline Iter parse(value &out, const Iter &first, const Iter &last, std::string *err) {
default_parse_context ctx(&out);
return _parse(ctx, first, last, err);
}
inline std::string parse(value &out, const std::string &s) {
std::string err;
parse(out, s.begin(), s.end(), &err);
return err;
}
inline std::string parse(value &out, std::istream &is) {
std::string err;
parse(out, std::istreambuf_iterator<char>(is.rdbuf()), std::istreambuf_iterator<char>(), &err);
return err;
}
template <typename T> struct last_error_t { static std::string s; };
template <typename T> std::string last_error_t<T>::s;
inline void set_last_error(const std::string &s) {
last_error_t<bool>::s = s;
}
inline const std::string &get_last_error() {
return last_error_t<bool>::s;
}
inline bool operator==(const value &x, const value &y) {
if (x.is<null>())
return y.is<null>();
#define PICOJSON_CMP(type) \
if (x.is<type>()) \
return y.is<type>() && x.get<type>() == y.get<type>()
PICOJSON_CMP(bool);
PICOJSON_CMP(double);
PICOJSON_CMP(std::string);
PICOJSON_CMP(array);
PICOJSON_CMP(object);
#undef PICOJSON_CMP
PICOJSON_ASSERT(0);
#ifdef _MSC_VER
__assume(0);
#endif
return false;
}
inline bool operator!=(const value &x, const value &y) {
return !(x == y);
}
}
#if !PICOJSON_USE_RVALUE_REFERENCE
namespace std {
template <> inline void swap(picojson::value &x, picojson::value &y) {
x.swap(y);
}
}
#endif
inline std::istream &operator>>(std::istream &is, picojson::value &x) {
picojson::set_last_error(std::string());
const std::string err(picojson::parse(x, is));
if (!err.empty()) {
picojson::set_last_error(err);
is.setstate(std::ios::failbit);
}
return is;
}
inline std::ostream &operator<<(std::ostream &os, const picojson::value &x) {
x.serialize(std::ostream_iterator<char>(os));
return os;
}
#ifdef _MSC_VER
#pragma warning(pop)
#endif
#endif
\ No newline at end of file