...
 
Commits (57)
- 869f4267 2022-12-16 YangZhou: [speechx] Speechx directory refactor (#2746)
  * refactor directory
- f8caaf46 2022-12-27 Hui Zhang: refactor cmake, rm absl/linsndfile, add strings unittest (#2765)
- 5046d8ee 2022-12-27 YangZhou: [Speechx] add nnet prob cache && make 2 thread decode work (#2769)
  * add nnet cache && make 2 thread work
  * do not compile websocket
- acf1d272 2022-12-30 YangZhou: [speechx] rm ds2 && rm boost (#2786)
  * fix openfst download error; add acknowledgments of openfst; refactor directory; clean ctc_decoders dir; add nnet cache && make 2 thread work; do not compile websocket; rm ds2 && rm boost; rm ds2 example
- c1b1ae05 2023-01-04 YangZhou: [speechx] add kaldi-native-fbank && refactor frontend (#2794)
  * replace kaldi-fbank with kaldi-native-fbank
  * make kaldi-native-fbank work
- ee7c266f 2023-01-11 YangZhou: [speechx] rm openblas && refactor kaldi-matrix, kaldi-vector (#2824)
- 8a225b17 2023-01-18 YangZhou: [speechx] thread decode (#2839)
  * fix nnet thread crash && rescore cost time
  * add nnet thread main
- 5042a168 2023-02-01 YangZhou: [speechx] add batch recognizer decode (#2866)
  * add recognizer_batch
- 21183d48 2023-02-07 YangZhou: add wfst decoder (#2886)
- 8e1b4cd5 2023-02-08 YangZhou: [engine] rename speechx (#2892)
  * rename speechx
  * fix wfst decode error
  * replace reset with make_unique
- 2f8aad95 2023-02-08 TianYuan: Update .mergify.yml
- 78e29c8e 2023-02-20 masimeng1994: add cls engine (#2923)
- b35fc01a 2023-02-28 Hui Zhang: opt to compile asr, cls, vad; add vad; format code (#2968)
- e9da7e0e 2023-03-05 Hui Zhang: [runtime] add logging module, build on linux and android, normalize option na...
  * rename option with WITH_; add logging module to replace with glog
  * refactor logging module; build pass on linux and andorid
- 2beb7ffc 2023-03-06 Hui Zhang: fix asr compile bug (#2993)
- bf914a9c 2023-03-13 Hui Zhang: [runtime] optimization compile and add vad interface (#3026)
  * vad recipe ok; refactor vad, add vad conf, vad inerface, vad recipe; install vad lib/bin/inc; using cpack; add vad doc, fix vad state name; refactor fastdeploy download; add vad jni; add timer; compute vad rtf; vad add beam param; andorid find library; fix log; fix glog; fix BUILD_TYPE bug; update doc; rm jni
- f0ef6f1c 2023-03-13 Hui Zhang: [runtime] vad jni demo (#3027)
  * android vad jni demo; update src; rename
- b9bdeca6 2023-03-14 jlqian98: add text blank preprocess, test=asr (#3025)
- f34d4ad4 2023-03-17 masimeng1994: [runtime] fix vad and cls cmake (#3050)
- 704e363a 2023-03-21 masimeng1994: fix asr cmake (#3071)
- ab4217c2 2023-03-22 jlqian98: [Engine] add TN/ITN functions (#3047)
  * add AddBlk, ReverseFrac function
  * rename text processing functions
- 2be7e572 2023-03-22 YangZhou: [engine] fix asr compile (#3078)
  * fix asr compile
  * add pybind
- 767f6dd4 2023-03-27 YangZhou: [engine] add recognizer_controller && fix build bugs (#3086)
  * fix asr compile
- 9e5a39ca 2023-03-28 masimeng1994: [runtime] support onnx runtime && support ios compile (#3101)
  * support vad ios compile; support onnx model recognize; add build ios sh
- 591b957b 2023-03-30 masimeng1994: [runtime] fix linux && android cmake bug (#3112)
- f35a87ab 2023-04-06 YangZhou: [Engine] recognizer controller refactor (#3139)
  * refactor recognizer_controller; clean frontend file
- d03ebe87 2023-04-11 masimeng1994: add vad interface GetVadResult (#3140)
  * add vad interface GetVadResult; fix comment
- 11ce08b2 2023-04-12 masimeng1994: [engine] replace onnx with fastdeploy (#3150)
  * onnxruntime change to fastdeploy
- fbd27aab 2023-04-17 zxcd: add amp for U2 conformer
- 2f4414a5 2023-04-17 zxcd: fix scaler save
- 7399d560 2023-04-17 zxcd: fix scaler save and load
- a1e5f270 2023-04-17 zxcd: mv scaler.unscale_ blow grad_clip
- b05ead51 2023-04-18 YangZhou: [engine] add recognizer api && clean params && make a shared decoder resource ...
- 414de374 2023-04-20 WongLaw: VITS learning rate revised, test=tts
- 47e31f46 2023-04-20 WongLaw: VITS learning rate revised, test=tts
- 5e2251af 2023-04-21 masimeng1994: [Engine] rename cls && add cls && vad android demo (#3188)
  * rename cls to AudioClassification; add android && vad demo
- fc670339 2023-04-24 TianYuan: [TTS] Fix losses of StarGAN v2 VC (#3184)
- fdeb9b88 2023-04-24 WongLaw: VITS learning rate revised, test=tts
- 305375c3 2023-04-24 WongLaw: VITS learning rate revised, test=tts
- 8c2196ea 2023-04-24 YangZhou: [engine] add wfst recognizer in example (#3173)
  * update wfst script
  * add skip blank
- 6e0044be 2023-04-24 YangZhou: [engine] merge develop into speechx (#3198)
- 500f283d 2023-04-24 YangZhou: Revert "[engine] merge develop into speechx (#3198)" (#3199); reverts commit 6e0044be
- a9027f18 2023-04-24 YangZhou: merge dev
- ce4af0e7 2023-04-24 YangZhou: Merge branch 'speechx' of github.com:PaddlePaddle/PaddleSpeech into speechx
- 7cab869d 2023-04-25 Hui Zhang: Merge pull request #3197 from PaddlePaddle/speechx ([engine] merge speechx)
- 9d8660b2 2023-04-25 zxcd: add new aishell model for better CER
- f3d567f9 2023-04-25 zxcd: add readme
- bc365cbb 2023-04-25 zxcd: Merge branch 'develop' into amp
- 2e62ac8b 2023-04-25 YangZhou: Update README.md in engine
- e3dcfa88 2023-04-25 Hui Zhang: Merge pull request #3186 from PaddlePaddle/vits_pr ([TTS] update lr schedulers from per iter to per epoch for VITS)
- 225737d4 2023-04-25 Hui Zhang: [s2t] fix cli args to config (#3194)
  * fix cli args to config
  * fix train cli
- dd71e9a0 2023-04-25 YangZhou: fix copyright, test=doc
- 8371d14f 2023-04-25 Hui Zhang: Merge pull request #3167 from zxcd/amp ([ASR] add amp for U2 conformer)
- 19a9d41c 2023-04-25 Hui Zhang: Merge pull request #3200 from PaddlePaddle/SmileGoat-patch-1 ([engine] Update README.md in engine)
- ee2c691a 2023-04-25 Hui Zhang: Merge pull request #3201 from SmileGoat/fix_cp ([engine] fix copyright, test=doc)
- 46c943f0 2023-04-25 Hui Zhang: Update README.md
- 8205343c 2023-04-25 Hui Zhang: Merge pull request #3202 from PaddlePaddle/zh794390558-patch-1 (Update README.md)
......@@ -136,7 +136,7 @@ pull_request_rules:
add: ["Docker"]
- name: "auto add label=Deployment"
conditions:
- files~=^speechx/
- files~=^runtime/
actions:
label:
add: ["Deployment"]
......@@ -3,8 +3,12 @@ repos:
rev: v0.16.0
hooks:
- id: yapf
files: \.py$
exclude: (?=third_party).*(\.py)$
name: yapf
language: python
entry: yapf
args: [-i, -vv]
types: [python]
exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: a11d9314b22d8f8c7556443875b731ef05965464
......@@ -31,7 +35,7 @@ repos:
- --ignore=E501,E228,E226,E261,E266,E128,E402,W503
- --builtins=G,request
- --jobs=1
exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$
exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$
- repo : https://github.com/Lucas-C/pre-commit-hooks
rev: v1.0.1
......@@ -53,16 +57,16 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party/kaldi-native-fbank/csrc|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
- id: cpplint
name: cpplint
description: Static code analysis of C/C++ files
language: python
files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party/kaldi-native-fbank/csrc|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
exclude: (?=runtime/engine/kaldi|runtime/engine/common/matrix|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$
entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:
- id: reorder-python-imports
exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$
exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$
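For reference, these lookahead-based `exclude` patterns can be exercised directly with Python's `re` module (pre-commit matches them with `re.search`); a minimal sketch with made-up paths. Note the patterns above write `\.h\.hpp`, which matches only a literal `.h.hpp` suffix, so the sketch assumes the intended `\.h|\.hpp`:
```python
import re

# Sample paths are hypothetical; the pattern mirrors the clang-format hook above.
exclude = re.compile(
    r'(?=runtime/engine/kaldi|audio/paddleaudio/src|third_party)'
    r'.*(\.cpp|\.cc|\.h|\.hpp|\.py)$')

for path in [
        'runtime/engine/kaldi/base/io.cc',      # skipped: vendored kaldi tree
        'runtime/engine/asr/decoder.cc',        # linted: not in the exclude list
        'third_party/ctc_decoders/scorer.cpp',  # skipped
]:
    print(path, '->', 'skip' if exclude.search(path) else 'lint')
```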
......@@ -178,6 +178,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).
### Recent Update
- 👑 2023.04.25: Add [AMP for U2 conformer](https://github.com/PaddlePaddle/PaddleSpeech/pull/3167).
- 🔥 2023.04.06: Add [subtitle file (.srt format) generation example](./demos/streaming_asr_server).
- 🔥 2023.03.14: Add SVS (Singing Voice Synthesis) examples with the Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1), [PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5); the effect is continuously being optimized.
- 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3).
......@@ -193,7 +194,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 👑 2022.11.18: Add [Whisper CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), supporting multi-language recognition and translation.
- 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](./demos/speech_ssl), supporting ASR and feature extraction.
- 🎉 2022.11.17: Add [male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660).
- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](./speechx/examples/u2pp_ol/wenetspeech).
- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/runtime/examples/u2pp_ol/wenetspeech).
- 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3).
- 🔥 2022.10.26: Add [Prosody Prediction](./examples/other/rhy) for TTS.
- 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.
......@@ -897,7 +898,16 @@ The Text-to-Speech module is originally called [Parakeet](https://github.com/Pad
## Citation
To cite PaddleSpeech for research, please use the following format.
```text
@inproceedings{zhang2022paddlespeech,
title = {PaddleSpeech: An Easy-to-Use All-in-One Speech Toolkit},
author = {Hui Zhang, Tian Yuan, Junkun Chen, Xintong Li, Renjie Zheng, Yuxin Huang, Xiaojie Chen, Enlei Gong, Zeyu Chen, Xiaoguang Hu, dianhai yu, Yanjun Ma, Liang Huang},
booktitle = {Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies: Demonstrations},
year = {2022},
publisher = {Association for Computational Linguistics},
}
@InProceedings{pmlr-v162-bai22d,
title = {{A}$^3${T}: Alignment-Aware Acoustic and Text Pretraining for Speech Synthesis and Editing},
author = {Bai, He and Zheng, Renjie and Chen, Junkun and Ma, Mingbo and Li, Xintong and Huang, Liang},
......@@ -912,14 +922,6 @@ To cite PaddleSpeech for research, please use the following format.
url = {https://proceedings.mlr.press/v162/bai22d.html},
}
@inproceedings{zhang2022paddlespeech,
title = {PaddleSpeech: An Easy-to-Use All-in-One Speech Toolkit},
author = {Hui Zhang, Tian Yuan, Junkun Chen, Xintong Li, Renjie Zheng, Yuxin Huang, Xiaojie Chen, Enlei Gong, Zeyu Chen, Xiaoguang Hu, dianhai yu, Yanjun Ma, Liang Huang},
booktitle = {Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies: Demonstrations},
year = {2022},
publisher = {Association for Computational Linguistics},
}
@inproceedings{zheng2021fused,
title={Fused acoustic and text encoding for multimodal bilingual pretraining and speech translation},
author={Zheng, Renjie and Chen, Junkun and Ma, Mingbo and Huang, Liang},
......
......@@ -183,6 +183,7 @@
- 🧩 Cascaded model applications: as an extension of traditional speech tasks, we combine natural language processing, computer vision and other tasks to deliver industrial-grade applications closer to real-world needs.
### Recent Updates
- 👑 2023.04.25: Add [AMP training for U2 conformer](https://github.com/PaddlePaddle/PaddleSpeech/pull/3167).
- 👑 2023.04.06: Add [subtitle (.srt) file generation](./demos/streaming_asr_server).
- 🔥 2023.03.14: Add SVS (Singing Voice Synthesis) examples based on the Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1), [PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5); the effect is continuously being optimized.
- 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3).
......
......@@ -10,7 +10,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |-|
[Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |-|
[Conformer U2PP Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz) | WenetSpeech Dataset | Char-based | 540 MB | Encoder:Conformer, Decoder:BiTransformer, Decoding method: Attention rescoring| 0.047198 (aishell test\_-1) 0.059212 (aishell test\_16) |-| 10000 h |- | python |[FP32](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz) </br>[INT8](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model.tar.gz) |
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |-|
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.5.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.051968 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |-|
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.0.1.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0460 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |-|
[Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |-|
[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0) | inference/python |-|
......
......@@ -13,15 +13,15 @@ paddlespeech version: 1.0.1
## Conformer Streaming
paddle version: 2.2.2
paddlespeech version: 0.2.0
paddlespeech version: 1.4.1
Need to set `decoding.decoding_chunk_size=16` when decoding.
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention | 16, -1 | - | 0.0551 |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 0.0629 |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 0.0629 |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.0544 |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention | 16, -1 | - | 0.056102 |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 0.058160 |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 0.058160 |
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.051968 |
## Transformer
......
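The chunk-size note above can also be applied from Python; a hedged sketch using yacs, where the key names follow the note and the table's "Chunk Size & Left Chunks" column (16, -1) and are assumptions, not the repo's exact API:
```python
from yacs.config import CfgNode

decode_conf = CfgNode(new_allowed=True)
decode_conf.decoding_chunk_size = 16       # decode in 16-frame chunks
decode_conf.num_decoding_left_chunks = -1  # unlimited left context
print(decode_conf)
```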
......@@ -179,7 +179,7 @@ generator_first: False # whether to start updating generator first
# OTHER TRAINING SETTING #
##########################################################
num_snapshots: 10 # max number of snapshots to keep while training
train_max_steps: 350000 # Number of training steps. == total_iters / ngpus, total_iters = 1000000
save_interval_steps: 1000 # Interval steps to save checkpoint.
eval_interval_steps: 250 # Interval steps to evaluate the network.
max_epoch: 1000 # Number of training epochs.
save_interval_epochs: 1 # Interval epochs to save checkpoint.
eval_interval_epochs: 1 # Interval steps to evaluate the network.
seed: 777 # random seed number
......@@ -8,4 +8,5 @@ python3 ${BIN_DIR}/train.py \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1
--ngpu=1 \
--speaker-dict=dump/speaker_id_map.txt
......@@ -74,6 +74,9 @@ def build_vocab(manifest_paths="",
spm_vocab_size=0,
spm_model_prefix="",
spm_character_coverage=0.9995):
manifest_paths = [manifest_paths] if isinstance(manifest_paths,
str) else manifest_paths
fout = open(vocab_path, 'w', encoding='utf-8')
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
fout.write(UNK + '\n') # <unk> must be 1
......
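The normalization added above is a common accept-str-or-list idiom; a standalone sketch (the helper name is hypothetical):
```python
from typing import List, Union

def as_path_list(paths: Union[str, List[str]]) -> List[str]:
    """Accept a single manifest path or a list of paths; always return a list."""
    return [paths] if isinstance(paths, str) else paths

assert as_path_list("data/manifest.train") == ["data/manifest.train"]
assert as_path_list(["a.jsonl", "b.jsonl"]) == ["a.jsonl", "b.jsonl"]
```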
......@@ -58,6 +58,7 @@ def format_data(
unit_type="char",
vocab_path="examples/librispeech/data/vocab.txt",
spm_model_prefix=""):
manifest_paths = [manifest_paths] if isinstance(manifest_paths, str) else manifest_paths
fout = open(output_path, 'w', encoding='utf-8')
......
......@@ -228,6 +228,16 @@ asr_dynamic_pretrained_models = {
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_30',
},
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.5.0.model.tar.gz',
'md5':
'a0adb2b204902982718bc1d8917f7038',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_30',
},
},
"transformer_librispeech-en-16k": {
'1.0': {
......
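A hedged sketch of how a version-keyed registry shaped like `asr_dynamic_pretrained_models` (tag, then version, then metadata) is typically resolved; the tag name below is illustrative, while the '1.4' entry mirrors the hunk above:
```python
# Illustrative registry; the tag string is an assumption.
registry = {
    'conformer_online_aishell-zh-16k': {
        '1.4': {
            'url': ('https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/'
                    'asr1_conformer_aishell_ckpt_1.5.0.model.tar.gz'),
            'md5': 'a0adb2b204902982718bc1d8917f7038',
            'cfg_path': 'model.yaml',
            'ckpt_path': 'exp/chunk_conformer/checkpoints/avg_30',
        },
    },
}

def resolve(tag: str, version: str) -> dict:
    """Pick the metadata (url/md5/cfg/ckpt) for a model tag at a version."""
    return registry[tag][version]

meta = resolve('conformer_online_aishell-zh-16k', '1.4')
print(meta['url'], meta['md5'])
```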
......@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alignment for U2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments
......@@ -32,26 +32,10 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
maybe_dump_config(args.dump_config, config)
main(config, args)
......@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export for U2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments
......@@ -32,22 +32,10 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save jit model to
parser.add_argument(
"--export_path", type=str, help="path of the jit model to save")
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
maybe_dump_config(args.dump_config, config)
main(config, args)
......@@ -15,14 +15,15 @@
import paddle
from kaldiio import ReadHelper
from paddleslim import PTQ
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
......@@ -173,32 +174,7 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_scp", type=str, help="path of the input audio file")
parser.add_argument(
"--num_utts",
type=int,
default=200,
help="num utts for quant calibrition.")
parser.add_argument(
"--export_path",
type=str,
default='export.jit.quant',
help="path of the input audio file")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
main(config, args)
......@@ -14,10 +14,10 @@
"""Evaluation for U2 model."""
import cProfile
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments
......@@ -34,27 +34,12 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
maybe_dump_config(args.dump_config, config)
# Setting for profiling
pr = cProfile.Profile()
......
......@@ -16,15 +16,14 @@ import os
import sys
from pathlib import Path
import distutils
import numpy as np
import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
......@@ -125,27 +124,7 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="for debug.")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
main(config, args)
......@@ -15,14 +15,12 @@
import cProfile
import os
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer
from paddlespeech.s2t.training.cli import config_from_args
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.training.cli import maybe_dump_config
from paddlespeech.utils.argparse import print_arguments
# from paddlespeech.s2t.exps.u2.trainer import U2Trainer as Trainer
def main_sp(config, args):
exp = Trainer(config, args)
......@@ -39,17 +37,9 @@ if __name__ == "__main__":
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
config = config_from_args(args)
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
maybe_dump_config(args.dump_path, config)
# Setting for profiling
pr = cProfile.Profile()
......
......@@ -23,6 +23,7 @@ import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.nn.utils import clip_grad_norm_
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
......@@ -47,14 +48,16 @@ class U2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
def train_batch(self, batch_index, batch_data, msg):
def train_batch(self, batch_index, batch_data, scaler, msg):
train_conf = self.config
start = time.time()
# forward
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
with paddle.amp.auto_cast(
level=self.amp_level, enable=True if scaler else False):
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
......@@ -77,12 +80,26 @@ class U2Trainer(Trainer):
# processes.
context = nullcontext
with context():
loss.backward()
if scaler:
scaler.scale(loss).backward()
else:
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
# do global grad clip
if train_conf.global_grad_clip != 0:
if scaler:
scaler.unscale_(self.optimizer)
# need paddlepaddle==develop or paddlepaddle>=2.5
clip_grad_norm_(self.model.parameters(),
train_conf.global_grad_clip)
if scaler:
scaler.step(self.optimizer)
scaler.update()
else:
self.optimizer.step()
self.optimizer.clear_grad()
self.lr_scheduler.step()
self.iteration += 1
......@@ -173,7 +190,8 @@ class U2Trainer(Trainer):
report("epoch", self.epoch)
report('step', self.iteration)
report("lr", self.lr_scheduler())
self.train_batch(batch_index, batch, msg)
self.train_batch(batch_index, batch, self.scaler,
msg)
self.after_train_batch()
report('iter', batch_index + 1)
if not self.use_streamdata:
......@@ -253,6 +271,19 @@ class U2Trainer(Trainer):
model_conf.output_dim = self.test_loader.vocab_size
model = U2Model.from_config(model_conf)
# For Mixed Precision Training
self.use_amp = self.config.get("use_amp", True)
self.amp_level = self.config.get("amp_level", "O1")
if self.train and self.use_amp:
self.scaler = paddle.amp.GradScaler(
init_loss_scaling=self.config.get(
"scale_loss", 32768.0)) #amp default num 32768.0
#Set amp_level
if self.amp_level == 'O2':
model = paddle.amp.decorate(models=model, level=self.amp_level)
else:
self.scaler = None
if self.parallel:
model = paddle.DataParallel(model)
......@@ -290,7 +321,6 @@ class U2Trainer(Trainer):
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
return {
"grad_clip": train_config.global_grad_clip,
"weight_decay": optim_conf.weight_decay,
"learning_rate": lr_scheduler
if lr_scheduler else optim_conf.lr,
......
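The AMP changes above follow Paddle's standard `GradScaler` recipe: run the forward pass under `auto_cast`, scale the loss before backward, unscale before gradient clipping, then step and update. A condensed sketch of one training step, assuming `model`, `optimizer`, and the batch tensors already exist:
```python
import paddle
from paddle.nn.utils import clip_grad_norm_  # needs paddlepaddle >= 2.5

scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)  # amp default

def amp_train_step(model, optimizer, audio, audio_len, text, text_len,
                   grad_clip=5.0, amp_level='O1'):
    with paddle.amp.auto_cast(level=amp_level):
        loss, att_loss, ctc_loss = model(audio, audio_len, text, text_len)
    scaler.scale(loss).backward()      # scaled backward pass
    if grad_clip != 0:
        scaler.unscale_(optimizer)     # unscale first, then clip true grads
        clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)             # skips the step on inf/nan grads
    scaler.update()
    optimizer.clear_grad()
    return float(loss)
```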
......@@ -13,6 +13,9 @@
# limitations under the License.
import argparse
import distutils
from yacs.config import CfgNode
class ExtendAction(argparse.Action):
"""
......@@ -68,7 +71,15 @@ def default_argument_parser(parser=None):
parser.register('action', 'extend', ExtendAction)
parser.add_argument(
'--conf', type=open, action=LoadFromFile, help="config file.")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="logging with debug mode.")
parser.add_argument(
"--dump_path", type=str, default=None, help="path to dump config file.")
# train group
train_group = parser.add_argument_group(
title='Train Options', description=None)
train_group.add_argument(
......@@ -103,14 +114,35 @@ def default_argument_parser(parser=None):
train_group.add_argument(
"--dump-config", metavar="FILE", help="dump config to `this` file.")
# test group
test_group = parser.add_argument_group(
title='Test Options', description=None)
test_group.add_argument(
"--decode_cfg",
metavar="DECODE_CONFIG_FILE",
help="decode config file.")
test_group.add_argument(
"--result_file", type=str, help="path of save the asr result")
test_group.add_argument(
"--audio_file", type=str, help="path of the input audio file")
# quant & export
quant_group = parser.add_argument_group(
title='Quant Options', description=None)
quant_group.add_argument(
"--audio_scp", type=str, help="path of the input audio scp file")
quant_group.add_argument(
"--num_utts",
type=int,
default=200,
help="num utts for quant calibrition.")
quant_group.add_argument(
"--export_path",
type=str,
default='export.jit.quant',
help="path of the jit model to save")
# profile group
profile_group = parser.add_argument_group(
title='Benchmark Options', description=None)
profile_group.add_argument(
......@@ -131,3 +163,28 @@ def default_argument_parser(parser=None):
help='max iteration for benchmark.')
return parser
def config_from_args(args):
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
return config
def maybe_dump_config(dump_path, config):
if dump_path:
with open(dump_path, 'w') as f:
print(config, file=f)
print(f"save config to {dump_path}")
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
from collections import OrderedDict
......@@ -110,6 +111,7 @@ class Trainer():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self._train = True
self.scaler = None
# print deps version
all_version()
......@@ -187,8 +189,13 @@ class Trainer():
infos.update({
"step": self.iteration,
"epoch": self.epoch,
"lr": self.optimizer.get_lr()
"lr": self.optimizer.get_lr(),
})
if self.scaler:
scaler_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.scaler'
paddle.save(self.scaler.state_dict(), scaler_path)
self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
......@@ -211,6 +218,13 @@ class Trainer():
# lr will restore from optimizer ckpt
self.iteration = infos["step"]
self.epoch = infos["epoch"]
scaler_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.scaler'
if os.path.exists(scaler_path):
scaler_state_dict = paddle.load(scaler_path)
self.scaler.load_state_dict(scaler_state_dict)
scratch = False
logger.info(
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
......
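The scaler state round-trips through its own `<epoch>.scaler` file next to the regular checkpoint; a minimal sketch with a hypothetical checkpoint directory:
```python
import os
import paddle

ckpt_dir = "exp/conformer/checkpoints"   # hypothetical
epoch = 10
scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)

# save: one `<epoch>.scaler` file alongside the epoch checkpoint
scaler_path = os.path.join(ckpt_dir, "{}.scaler".format(epoch))
paddle.save(scaler.state_dict(), scaler_path)

# resume: restore only when a scaler snapshot exists for that epoch
if os.path.exists(scaler_path):
    scaler.load_state_dict(paddle.load(scaler_path))
```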
......@@ -820,12 +820,13 @@ class StarGANv2VCCollateFn:
self.max_mel_length = max_mel_length
def random_clip(self, mel: np.array):
# [80, T]
mel_length = mel.shape[1]
# [T, 80]
mel_length = mel.shape[0]
if mel_length > self.max_mel_length:
random_start = np.random.randint(0,
mel_length - self.max_mel_length)
mel = mel[:, random_start:random_start + self.max_mel_length]
mel = mel[random_start:random_start + self.max_mel_length, :]
return mel
def __call__(self, examples):
......@@ -843,7 +844,6 @@ class StarGANv2VCCollateFn:
mel = [self.random_clip(item["mel"]) for item in examples]
ref_mel = [self.random_clip(item["ref_mel"]) for item in examples]
ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples]
mel = batch_sequences(mel)
ref_mel = batch_sequences(ref_mel)
ref_mel_2 = batch_sequences(ref_mel_2)
......
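The fix above assumes mels are stored time-major with shape [T, 80], so both the length check and the crop move to axis 0; a small sketch of the corrected clip with a hypothetical maximum length:
```python
import numpy as np

def random_clip(mel: np.ndarray, max_mel_length: int = 192) -> np.ndarray:
    """Randomly crop a time-major [T, 80] mel to at most max_mel_length frames."""
    mel_length = mel.shape[0]
    if mel_length > max_mel_length:
        random_start = np.random.randint(0, mel_length - max_mel_length)
        mel = mel[random_start:random_start + max_mel_length, :]
    return mel

assert random_clip(np.zeros((300, 80))).shape == (192, 80)
```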
......@@ -113,6 +113,16 @@ def train_sp(args, config):
model_version = '1.0'
uncompress_path = download_and_decompress(StarGANv2VC_source[model_version],
MODEL_HOME)
# Adjust num_domains according to the number of speakers
# The pretrained model from the original repo and default.yaml both default to 20
if args.speaker_dict is not None:
with open(args.speaker_dict, 'rt', encoding='utf-8') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
print("spk_num:", spk_num)
config['mapping_network_params']['num_domains'] = spk_num
config['style_encoder_params']['num_domains'] = spk_num
config['discriminator_params']['num_domains'] = spk_num
generator = Generator(**config['generator_params'])
mapping_network = MappingNetwork(**config['mapping_network_params'])
......@@ -123,7 +133,7 @@ def train_sp(args, config):
jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
asr_model_dir = os.path.join(uncompress_path, 'asr.pdz')
F0_model = JDCNet(num_class=1, seq_len=192)
F0_model = JDCNet(num_class=1, seq_len=config['max_mel_length'])
F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params'])
F0_model.eval()
......@@ -234,6 +244,11 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument(
"--speaker-dict",
type=str,
default=None,
help="speaker id map file for multiple speaker model.")
args = parser.parse_args()
......
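The parsing added in `train_sp` assumes one `name id` pair per line of `speaker_id_map.txt`; a standalone sketch with made-up contents:
```python
# Hypothetical dump/speaker_id_map.txt contents:
#   SSB0005 0
#   SSB0009 1
with open("dump/speaker_id_map.txt", "rt", encoding="utf-8") as f:
    spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
# spk_num then overrides num_domains for the mapping network,
# style encoder and discriminator, as in the hunk above.
```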
......@@ -230,17 +230,15 @@ def train_sp(args, config):
output_dir=output_dir)
trainer = Trainer(
updater,
stop_trigger=(config.train_max_steps, "iteration"),
out=output_dir)
updater, stop_trigger=(config.max_epoch, 'epoch'), out=output_dir)
if dist.get_rank() == 0:
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
evaluator, trigger=(config.eval_interval_epochs, 'epoch'))
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
trigger=(config.save_interval_epochs, 'epoch'))
print("Trainer Done!")
trainer.run()
......
......@@ -19,35 +19,38 @@ import paddle.nn.functional as F
from .transforms import build_transforms
# These should all be moved into the updater
def compute_d_loss(nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
use_r1_reg: bool=True,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):
def compute_d_loss(
nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
# TODO: should be True here, but r1_reg has some bug now
use_r1_reg: bool=False,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):
assert (z_trg is None) != (x_ref is None)
# with real audios
x_real.stop_gradient = False
out = nets['discriminator'](x_real, y_org)
loss_real = adv_loss(out, 1)
# R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
if use_r1_reg:
loss_reg = r1_reg(out, x_real)
else:
loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
# loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
loss_reg = paddle.zeros([1])
# consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
loss_con_reg = paddle.zeros([1])
if use_con_reg:
t = build_transforms()
out_aug = nets['discriminator'](t(x_real).detach(), y_org)
......@@ -118,9 +121,10 @@ def compute_g_loss(nets: Dict[str, Any],
s_trg = nets['style_encoder'](x_ref, y_trg)
# compute ASR/F0 features (real)
with paddle.no_grad():
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real)
# The original code does not call .eval() but wraps this in no_grad()
# We call .eval() instead; enabling `with paddle.no_grad()` here raises an error
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real)
# adversarial loss
x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)
......
......@@ -259,7 +259,7 @@ class StarGANv2VCEvaluator(StandardEvaluator):
y_org=y_org,
y_trg=y_trg,
z_trg=z_trg,
use_r1_reg=False,
use_r1_reg=self.use_r1_reg,
use_adv_cls=use_adv_cls,
**self.d_loss_params)
......@@ -269,7 +269,7 @@ class StarGANv2VCEvaluator(StandardEvaluator):
y_org=y_org,
y_trg=y_trg,
x_ref=x_ref,
use_r1_reg=False,
use_r1_reg=self.use_r1_reg,
use_adv_cls=use_adv_cls,
**self.d_loss_params)
......
......@@ -166,7 +166,9 @@ class VITSUpdater(StandardUpdater):
gen_loss.backward()
self.optimizer_g.step()
self.scheduler_g.step()
# learning rate updates on each epoch.
if self.state.iteration % self.updates_per_epoch == 0:
self.scheduler_g.step()
# reset cache
if self.model.reuse_cache_gen or not self.model.training:
......@@ -202,7 +204,9 @@ class VITSUpdater(StandardUpdater):
dis_loss.backward()
self.optimizer_d.step()
self.scheduler_d.step()
# learning rate updates on each epoch.
if self.state.iteration % self.updates_per_epoch == 0:
self.scheduler_d.step()
# reset cache
if self.model.reuse_cache_dis or not self.model.training:
......
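The guard above turns the per-iteration scheduler hook into a per-epoch schedule; a minimal sketch, where `updates_per_epoch` is the number of updater iterations in one epoch:
```python
def step_schedulers_per_epoch(schedulers, iteration: int,
                              updates_per_epoch: int) -> None:
    """Step epoch-based LR schedulers only once per full epoch of updates."""
    if iteration % updates_per_epoch == 0:
        for sched in schedulers:
            sched.step()
```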
engine/common/base/flags.h
engine/common/base/log.h
tools/valgrind*
*log
fc_patch/*
test
# >=3.17 support -DCMAKE_FIND_DEBUG_MODE=ON
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0077.cmake")
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
include(system)
project(paddlespeech VERSION 0.1)
set(PPS_VERSION_MAJOR 1)
set(PPS_VERSION_MINOR 0)
set(PPS_VERSION_PATCH 0)
set(PPS_VERSION "${PPS_VERSION_MAJOR}.${PPS_VERSION_MINOR}.${PPS_VERSION_PATCH}")
# compiler option
# Keep the same with openfst, -fPIC or -fpic
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_FIND_DEBUG_MODE OFF)
set(PPS_CXX_STANDARD 14)
# set std-14
set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
# Ninja Generator will set CMAKE_BUILD_TYPE to Debug
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" FORCE)
endif()
# find_* e.g. find_library work when Cross-Compiling
if(ANDROID)
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
endif()
if(BUILD_IN_MACOS)
add_definitions("-DOS_MACOSX")
endif()
# install dir into `build/install`
set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/install)
include(FetchContent)
include(ExternalProject)
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
###############################################################################
# Option Configurations
###############################################################################
# https://github.com/google/brotli/pull/655
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_PPS_DEBUG "debug option" OFF)
if (WITH_PPS_DEBUG)
add_definitions("-DPPS_DEBUG")
endif()
option(WITH_ASR "build asr" ON)
option(WITH_CLS "build cls" ON)
option(WITH_VAD "build vad" ON)
option(WITH_GPU "NNet using GPU." OFF)
option(WITH_PROFILING "enable c++ profling" OFF)
option(WITH_TESTING "unit test" ON)
option(WITH_ONNX "u2 support onnx runtime" OFF)
###############################################################################
# Include Third Party
###############################################################################
include(gflags)
include(glog)
include(pybind)
#onnx
if(WITH_ONNX)
add_definitions(-DUSE_ONNX)
endif()
# gtest
if(WITH_TESTING)
include(gtest) # download, build, install gtest
endif()
# fastdeploy
include(fastdeploy)
if(WITH_ASR)
# openfst
include(openfst)
add_dependencies(openfst gflags extern_glog)
endif()
###############################################################################
# Find Package
###############################################################################
# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207
find_package(Threads REQUIRED)
if(WITH_ASR)
# https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3
find_package(Python3 COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG)
if(Python3_FOUND)
message(STATUS "Python3_FOUND = ${Python3_FOUND}")
message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}")
message(STATUS "Python3_LIBRARIES = ${Python3_LIBRARIES}")
message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}")
message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}")
set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE)
set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE)
endif()
message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}")
include_directories(${PYTHON_INCLUDE_DIR})
if(pybind11_FOUND)
message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}")
message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}")
message(STATUS "pybind11_DEFINITIONS=${pybind11_DEFINITIONS}")
endif()
# paddle libpaddle.so
# paddle include and link option
# -L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so
set(EXECUTE_COMMAND "import os"
"import paddle"
"include_dir = paddle.sysconfig.get_include()"
"paddle_dir=os.path.split(include_dir)[0]"
"libs_dir=os.path.join(paddle_dir, 'libs')"
"fluid_dir=os.path.join(paddle_dir, 'fluid')"
"out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir])"
"out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out)"
)
execute_process(
COMMAND python -c "${EXECUTE_COMMAND}"
OUTPUT_VARIABLE PADDLE_LINK_FLAGS
RESULT_VARIABLE SUCESS)
message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS})
string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS)
# paddle compile option
# -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include
set(EXECUTE_COMMAND "import paddle"
"include_dir = paddle.sysconfig.get_include()"
"print(f\"-I{include_dir}\")"
)
execute_process(
COMMAND python -c "${EXECUTE_COMMAND}"
OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS)
message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS})
string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
# for LD_LIBRARY_PATH
# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
set(EXECUTE_COMMAND "import os"
"import paddle"
"include_dir=paddle.sysconfig.get_include()"
"paddle_dir=os.path.split(include_dir)[0]"
"libs_dir=os.path.join(paddle_dir, 'libs')"
"fluid_dir=os.path.join(paddle_dir, 'fluid')"
"out=':'.join([libs_dir, fluid_dir]); print(out)"
)
execute_process(
COMMAND python -c "${EXECUTE_COMMAND}"
OUTPUT_VARIABLE PADDLE_LIB_DIRS)
message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS})
endif()
include(summary)
###############################################################################
# Add local library
###############################################################################
set(ENGINE_ROOT ${CMAKE_SOURCE_DIR}/engine)
add_subdirectory(engine)
###############################################################################
# CPack library
###############################################################################
# build a CPack driven installer package
include (InstallRequiredSystemLibraries)
set(CPACK_PACKAGE_NAME "paddlespeech_library")
set(CPACK_PACKAGE_VENDOR "paddlespeech")
set(CPACK_PACKAGE_VERSION_MAJOR 1)
set(CPACK_PACKAGE_VERSION_MINOR 0)
set(CPACK_PACKAGE_VERSION_PATCH 0)
set(CPACK_PACKAGE_DESCRIPTION "paddlespeech library")
set(CPACK_PACKAGE_CONTACT "paddlespeech@baidu.com")
set(CPACK_SOURCE_GENERATOR "TGZ")
include (CPack)
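The Python snippets embedded in the `execute_process` calls above amount to this standalone query (run inside the build `venv`; the three outputs are consumed as `PADDLE_LINK_FLAGS`, `PADDLE_COMPILE_FLAGS`, and `PADDLE_LIB_DIRS`):
```python
import os
import paddle

include_dir = paddle.sysconfig.get_include()
paddle_dir = os.path.split(include_dir)[0]
libs_dir = os.path.join(paddle_dir, 'libs')
fluid_dir = os.path.join(paddle_dir, 'fluid')

# PADDLE_LINK_FLAGS: library search paths plus the exact sonames to link
print('-L' + libs_dir, '-L' + fluid_dir,
      '-l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so')
# PADDLE_COMPILE_FLAGS: paddle headers
print('-I' + include_dir)
# PADDLE_LIB_DIRS: colon-separated runtime path for LD_LIBRARY_PATH
print(':'.join([libs_dir, fluid_dir]))
```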
# SpeechX -- All in One Speech Task Inference
## Environment
......@@ -9,7 +8,7 @@ We develop under:
* gcc/g++/gfortran - 8.2.0
* cmake - 3.16.0
> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build speechx.
> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build engine.
> We make sure all things work fine under docker, and recommend using it to develop and deploy.
......@@ -33,13 +32,13 @@ docker run --privileged --net=host --ipc=host -it --rm -v /path/to/paddlespeech
bash tools/venv.sh
```
2. Build `speechx` and `examples`.
2. Build `engine` and `examples`.
For now we are using features from the `develop` branch of paddle, so we need to install a `paddlepaddle` nightly build version.
For example:
```
source venv/bin/activate
python -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html
python -m pip install paddlepaddle==2.4.2 -i https://mirror.baidu.com/pypi/simple
./build.sh
```
......@@ -113,3 +112,11 @@ apt-get install gfortran-8
4. `Undefined reference to '_gfortran_concat_string'`
using gcc 8.2, gfortran 8.2.
5. `./boost/python/detail/wrap_python.hpp:57:11: fatal error: pyconfig.h: No such file or directory`
```
apt-get install python3-dev
```
for more info please see [here](https://github.com/okfn/piati/issues/65).
#!/usr/bin/env bash
set -xe
BUILD_ROOT=build/Linux
BUILD_DIR=${BUILD_ROOT}/x86_64
mkdir -p ${BUILD_DIR}
BUILD_TYPE=Release
#BUILD_TYPE=Debug
BUILD_SO=OFF
BUILD_ONNX=ON
BUILD_ASR=ON
BUILD_CLS=ON
BUILD_VAD=ON
PPS_DEBUG=OFF
FASTDEPLOY_INSTALL_DIR=""
# the build script had verified in the paddlepaddle docker image.
# please follow the instruction below to install PaddlePaddle image.
# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
#cmake -B build -DBUILD_SHARED_LIBS=OFF -DWITH_ASR=OFF -DWITH_CLS=OFF -DWITH_VAD=ON -DFASTDEPLOY_INSTALL_DIR=/workspace/zhanghui/paddle/FastDeploy/build/Android/arm64-v8a-api-21/install
cmake -B ${BUILD_DIR} \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DBUILD_SHARED_LIBS=${BUILD_SO} \
-DWITH_ONNX=${BUILD_ONNX} \
-DWITH_ASR=${BUILD_ASR} \
-DWITH_CLS=${BUILD_CLS} \
-DWITH_VAD=${BUILD_VAD} \
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
-DWITH_PPS_DEBUG=${PPS_DEBUG}
cmake --build ${BUILD_DIR} -j
#!/bin/bash
set -ex
ANDROID_NDK=/mnt/masimeng/workspace/software/android-ndk-r25b/
# Setting up the Android toolchain
ANDROID_ABI=arm64-v8a # 'arm64-v8a', 'armeabi-v7a'
ANDROID_PLATFORM="android-21" # API >= 21
ANDROID_STL=c++_shared # 'c++_shared', 'c++_static'
ANDROID_TOOLCHAIN=clang # 'clang' only
TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
# Create build directory
BUILD_ROOT=build/Android
BUILD_DIR=${BUILD_ROOT}/${ANDROID_ABI}-api-21
FASTDEPLOY_INSTALL_DIR="/mnt/masimeng/workspace/FastDeploy/build/Android/arm64-v8a-api-21/install"
mkdir -p ${BUILD_DIR}
cd ${BUILD_DIR}
# CMake configuration with Android toolchain
cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DANDROID_ABI=${ANDROID_ABI} \
-DANDROID_NDK=${ANDROID_NDK} \
-DANDROID_PLATFORM=${ANDROID_PLATFORM} \
-DANDROID_STL=${ANDROID_STL} \
-DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \
-DBUILD_SHARED_LIBS=OFF \
-DWITH_ASR=OFF \
-DWITH_CLS=OFF \
-DWITH_VAD=ON \
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
-DCMAKE_FIND_DEBUG_MODE=OFF \
-Wno-dev ../../..
# Build FastDeploy Android C++ SDK
make
# https://www.jianshu.com/p/33672fb819f5
PATH="/Applications/CMake.app/Contents/bin":"$PATH"
tools_dir=$1
ios_toolchain_cmake=${tools_dir}/"/ios-cmake-4.2.0/ios.toolchain.cmake"
fastdeploy_dir=${tools_dir}"/fastdeploy-ort-mac-build/"
build_targets=("OS64")
build_type_array=("Release")
#static_name="libocr"
#lib_name="libocr"
# Switch to workpath
current_path=`cd $(dirname $0);pwd`
work_path=${current_path}/
build_path=${current_path}/build/
output_path=${current_path}/output/
cd ${work_path}
# Clean
rm -rf ${build_path}
rm -rf ${output_path}
if [ "$1"x = "clean"x ]; then
exit 0
fi
# Build Every Target
for target in "${build_targets[@]}"
do
for build_type in "${build_type_array[@]}"
do
echo -e "\033[1;36;40mBuilding ${build_type} ${target} ... \033[0m"
target_build_path=${build_path}/${target}/${build_type}/
mkdir -p ${target_build_path}
cd ${target_build_path}
if [ $? -ne 0 ];then
echo -e "\033[1;31;40mcd ${target_build_path} failed \033[0m"
exit -1
fi
if [ ${target} == "OS64" ];then
fastdeploy_install_dir=${fastdeploy_dir}/arm64
else
fastdeploy_install_dir=""
echo "fastdeploy_install_dir is null"
exit -1
fi
cmake -DCMAKE_TOOLCHAIN_FILE=${ios_toolchain_cmake} \
-DBUILD_IN_MACOS=ON \
-DBUILD_SHARED_LIBS=OFF \
-DWITH_ASR=OFF \
-DWITH_CLS=OFF \
-DWITH_VAD=ON \
-DFASTDEPLOY_INSTALL_DIR=${fastdeploy_install_dir} \
-DPLATFORM=${target} ../../../
cmake --build . --config ${build_type}
mkdir output
cp engine/vad/interface/libpps_vad_interface.a output
cp engine/vad/interface/vad_interface_main.app/vad_interface_main output
cp ${fastdeploy_install_dir}/lib/libfastdeploy.dylib output
cp ${fastdeploy_install_dir}/third_libs/install/onnxruntime/lib/libonnxruntime.dylib output
done
done
## combine all ios libraries
#DEVROOT=/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/
#LIPO_TOOL=${DEVROOT}/usr/bin/lipo
#LIBRARY_PATH=${build_path}
#LIBRARY_OUTPUT_PATH=${output_path}/IOS
#mkdir -p ${LIBRARY_OUTPUT_PATH}
#
#${LIPO_TOOL} \
# -arch i386 ${LIBRARY_PATH}/ios_x86/Release/${lib_name}.a \
# -arch x86_64 ${LIBRARY_PATH}/ios_x86_64/Release/${lib_name}.a \
# -arch armv7 ${LIBRARY_PATH}/ios_armv7/Release/${lib_name}.a \
# -arch armv7s ${LIBRARY_PATH}/ios_armv7s/Release/${lib_name}.a \
# -arch arm64 ${LIBRARY_PATH}/ios_armv8/Release/${lib_name}.a \
# -output ${LIBRARY_OUTPUT_PATH}/${lib_name}.a -create
#
#cp ${work_path}/lib/houyi/lib/ios/libhouyi_score.a ${LIBRARY_OUTPUT_PATH}/
#cp ${work_path}/interface/ocr-interface.h ${output_path}
#cp ${work_path}/version/release.v ${output_path}
#
#echo -e "\033[1;36;40mBuild All Target Success At:\n${output_path}\033[0m"
#exit 0
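For reference, this script expects its first argument to be a tools directory containing `ios-cmake-4.2.0/` and `fastdeploy-ort-mac-build/arm64/`; the script name and path below are hypothetical:
```
# Hypothetical invocation; /opt/ios-tools is a placeholder layout:
#   /opt/ios-tools/ios-cmake-4.2.0/ios.toolchain.cmake
#   /opt/ios-tools/fastdeploy-ort-mac-build/arm64/
./build_ios.sh /opt/ios-tools

# Passing "clean" as the first argument only wipes build/ and output/.
./build_ios.sh clean
```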
cmake_policy(SET CMP0077 NEW)
include(FetchContent)
set(EXTERNAL_PROJECT_LOG_ARGS
LOG_DOWNLOAD 1 # Wrap download in script to log output
LOG_UPDATE 1 # Wrap update in script to log output
LOG_PATCH 1
LOG_CONFIGURE 1 # Wrap configure in script to log output
LOG_BUILD 1 # Wrap build in script to log output
LOG_INSTALL 1
LOG_TEST 1 # Wrap test in script to log output
LOG_MERGED_STDOUTERR 1
LOG_OUTPUT_ON_FAILURE 1
)
if(NOT FASTDEPLOY_INSTALL_DIR)
if(ANDROID)
FetchContent_Declare(
fastdeploy
URL https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.4-shared.tgz
URL_HASH MD5=2a15301158e9eb157a4f11283689e7ba
${EXTERNAL_PROJECT_LOG_ARGS}
)
add_definitions("-DUSE_PADDLE_LITE_BAKEND")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
else() # Linux
FetchContent_Declare(
fastdeploy
URL https://paddlespeech.bj.bcebos.com/speechx/fastdeploy/fastdeploy-1.0.5-x86_64-onnx.tar.gz
URL_HASH MD5=33900d986ea71aa78635e52f0733227c
${EXTERNAL_PROJECT_LOG_ARGS}
)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3")
endif()
FetchContent_MakeAvailable(fastdeploy)
set(FASTDEPLOY_INSTALL_DIR ${fc_patch}/fastdeploy-src)
endif()
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# Fix compiler flag conflicts, since fastdeploy uses c++11 for the project.
# This line must come after `include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)`.
set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
include_directories(${FASTDEPLOY_INCS})
# install fastdeploy and dependent libs
# install_fastdeploy_libraries(${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
# No dynamic libs need to be installed when using the
# FastDeploy static lib.
if(ANDROID AND WITH_ANDROID_STATIC_LIB)
return()
endif()
set(DYN_LIB_SUFFIX "*.so*")
if(WIN32)
set(DYN_LIB_SUFFIX "*.dll")
elseif(APPLE)
set(DYN_LIB_SUFFIX "*.dylib*")
endif()
if(FastDeploy_DIR)
set(DYN_SEARCH_DIR ${FastDeploy_DIR})
elseif(FASTDEPLOY_INSTALL_DIR)
set(DYN_SEARCH_DIR ${FASTDEPLOY_INSTALL_DIR})
else()
message(FATAL_ERROR "Please set FastDeploy_DIR/FASTDEPLOY_INSTALL_DIR before calling install_fastdeploy_libraries.")
endif()
file(GLOB_RECURSE ALL_NEED_DYN_LIBS ${DYN_SEARCH_DIR}/lib/${DYN_LIB_SUFFIX})
file(GLOB_RECURSE ALL_DEPS_DYN_LIBS ${DYN_SEARCH_DIR}/third_libs/${DYN_LIB_SUFFIX})
if(ENABLE_VISION)
# OpenCV
if(ANDROID)
file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${DYN_LIB_SUFFIX})
else()
file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_DIR}/../../${DYN_LIB_SUFFIX})
endif()
list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_OPENCV_DYN_LIBS})
if(WIN32)
file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/x64/vc15/bin/${DYN_LIB_SUFFIX})
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
elseif(ANDROID AND (NOT WITH_ANDROID_OPENCV_STATIC))
file(GLOB OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${ANDROID_ABI}/${DYN_LIB_SUFFIX})
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
else() # linux/mac
file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/lib/${DYN_LIB_SUFFIX})
install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
endif()
# FlyCV
if(ENABLE_FLYCV)
file(GLOB_RECURSE ALL_FLYCV_DYN_LIBS ${FLYCV_LIB_DIR}/${DYN_LIB_SUFFIX})
list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_FLYCV_DYN_LIBS})
if(ANDROID AND (NOT WITH_ANDROID_FLYCV_STATIC))
install(FILES ${ALL_FLYCV_DYN_LIBS} DESTINATION lib)
endif()
endif()
endif()
if(ENABLE_OPENVINO_BACKEND)
# need plugins.xml for openvino backend
set(OPENVINO_RUNTIME_BIN_DIR ${OPENVINO_DIR}/bin)
file(GLOB OPENVINO_PLUGIN_XML ${OPENVINO_RUNTIME_BIN_DIR}/*.xml)
install(FILES ${OPENVINO_PLUGIN_XML} DESTINATION lib)
endif()
# Install other libraries
install(FILES ${ALL_NEED_DYN_LIBS} DESTINATION lib)
install(FILES ${ALL_DEPS_DYN_LIBS} DESTINATION lib)
......@@ -2,10 +2,13 @@ include(FetchContent)
FetchContent_Declare(
gflags
URL https://paddleaudio.bj.bcebos.com/build/gflag-2.2.2.zip
URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
)
FetchContent_MakeAvailable(gflags)
# openfst need
include_directories(${gflags_BINARY_DIR}/include)
link_directories(${gflags_BINARY_DIR})
#install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib)
include(FetchContent)
if(ANDROID)
else() # UNIX
add_definitions(-DWITH_GLOG)
FetchContent_Declare(
glog
URL https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DWITH_GFLAGS=OFF
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
)
set(BUILD_TESTING OFF)
FetchContent_MakeAvailable(glog)
include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
endif()
if(ANDROID)
add_library(extern_glog INTERFACE)
add_dependencies(extern_glog gflags)
else() # UNIX
add_library(extern_glog ALIAS glog)
add_dependencies(glog gflags)
endif()
include(FetchContent)
if(ANDROID)
else() # UNIX
FetchContent_Declare(
gtest
URL https://paddleaudio.bj.bcebos.com/build/gtest-release-1.11.0.zip
URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
)
FetchContent_MakeAvailable(gtest)
include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src)
endif()
if(ANDROID)
add_library(extern_gtest INTERFACE)
else() # UNIX
add_dependencies(gtest gflags extern_glog)
add_library(extern_gtest ALIAS gtest)
endif()
if(WITH_TESTING)
enable_testing()
endif()
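With `WITH_TESTING=ON`, the gtest targets registered above become runnable through CTest. A minimal sketch (build directory name assumed):
```
cmake -B build -DWITH_TESTING=ON
cmake --build build -j
cd build && ctest --output-on-failure
```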
include(FetchContent)
set(openfst_PREFIX_DIR ${fc_patch}/openfst)
set(openfst_SOURCE_DIR ${fc_patch}/openfst-src)
set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
include(ExternalProject)
# openfst Acknowledgments:
#Cyril Allauzen, Michael Riley, Johan Schalkwyk, Wojciech Skut and Mehryar Mohri,
#"OpenFst: A General and Efficient Weighted Finite-State Transducer Library",
......@@ -10,18 +10,33 @@ set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
#Application of Automata, (CIAA 2007), volume 4783 of Lecture Notes in
#Computer Science, pages 11-23. Springer, 2007. http://www.openfst.org.
set(EXTERNAL_PROJECT_LOG_ARGS
LOG_DOWNLOAD 1 # Wrap download in script to log output
LOG_UPDATE 1 # Wrap update in script to log output
LOG_CONFIGURE 1 # Wrap configure in script to log output
LOG_BUILD 1 # Wrap build in script to log output
LOG_TEST 1 # Wrap test in script to log output
LOG_INSTALL 1 # Wrap install in script to log output
)
ExternalProject_Add(openfst
URL https://paddleaudio.bj.bcebos.com/build/openfst_1.7.2.zip
URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${openfst_PREFIX_DIR}
SOURCE_DIR ${openfst_SOURCE_DIR}
BINARY_DIR ${openfst_BINARY_DIR}
BUILD_ALWAYS 0
CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
"CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
"LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
"LIBS=-lgflags_nothreads -lglog -lpthread"
"LIBS=-lgflags_nothreads -lglog -lpthread -fPIC"
COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
BUILD_COMMAND make -j 4
)
link_directories(${openfst_PREFIX_DIR}/lib)
include_directories(${openfst_PREFIX_DIR}/include)
message(STATUS "OpenFST inc dir: ${openfst_PREFIX_DIR}/include")
message(STATUS "OpenFST lib dir: ${openfst_PREFIX_DIR}/lib")
# pybind11 is from: https://github.com/pybind/pybind11
# Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>, All rights reserved.
SET(PYBIND_ZIP "v2.10.0.zip")
SET(LOCAL_PYBIND_ZIP ${FETCHCONTENT_BASE_DIR}/${PYBIND_ZIP})
SET(PYBIND_SRC ${FETCHCONTENT_BASE_DIR}/pybind11)
SET(DOWNLOAD_URL "https://paddleaudio.bj.bcebos.com/build/v2.10.0.zip")
SET(PYBIND_TIMEOUT 600 CACHE STRING "Timeout in seconds when downloading pybind.")
IF(NOT EXISTS ${LOCAL_PYBIND_ZIP})
FILE(DOWNLOAD ${DOWNLOAD_URL}
${LOCAL_PYBIND_ZIP}
TIMEOUT ${PYBIND_TIMEOUT}
STATUS ERR
SHOW_PROGRESS
)
IF(ERR EQUAL 0)
MESSAGE(STATUS "download pybind success")
ELSE()
MESSAGE(FATAL_ERROR "download pybind failed")
ENDIF()
ENDIF()
IF(NOT EXISTS ${PYBIND_SRC})
EXECUTE_PROCESS(
COMMAND ${CMAKE_COMMAND} -E tar xfz ${LOCAL_PYBIND_ZIP}
WORKING_DIRECTORY ${FETCHCONTENT_BASE_DIR}
RESULT_VARIABLE tar_result
)
file(RENAME ${FETCHCONTENT_BASE_DIR}/pybind11-2.10.0 ${PYBIND_SRC})
IF (tar_result MATCHES 0)
MESSAGE(STATUS "unzip pybind success")
ELSE()
MESSAGE(FATAL_ERROR "unzip pybind failed")
ENDIF()
ENDIF()
include_directories(${PYBIND_SRC}/include)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(pps_summary)
message(STATUS "")
message(STATUS "*************PaddleSpeech Building Summary**********")
message(STATUS " PPS_VERSION : ${PPS_VERSION}")
message(STATUS " CMake version : ${CMAKE_VERSION}")
message(STATUS " CMake command : ${CMAKE_COMMAND}")
message(STATUS " UNIX : ${UNIX}")
message(STATUS " ANDROID : ${ANDROID}")
message(STATUS " System : ${CMAKE_SYSTEM_NAME}")
message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}")
message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}")
message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}")
message(STATUS " Build type : ${CMAKE_BUILD_TYPE}")
message(STATUS " BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}")
get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS)
message(STATUS " Compile definitions : ${tmp}")
message(STATUS " CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}")
message(STATUS " CMAKE_CURRENT_BINARY_DIR : ${CMAKE_CURRENT_BINARY_DIR}")
message(STATUS " CMAKE_INSTALL_PREFIX : ${CMAKE_INSTALL_PREFIX}")
message(STATUS " CMAKE_INSTALL_LIBDIR : ${CMAKE_INSTALL_LIBDIR}")
message(STATUS " CMAKE_MODULE_PATH : ${CMAKE_MODULE_PATH}")
message(STATUS " CMAKE_SYSTEM_NAME : ${CMAKE_SYSTEM_NAME}")
message(STATUS "")
message(STATUS " WITH_ASR : ${WITH_ASR}")
message(STATUS " WITH_CLS : ${WITH_CLS}")
message(STATUS " WITH_VAD : ${WITH_VAD}")
message(STATUS " WITH_GPU : ${WITH_GPU}")
message(STATUS " WITH_TESTING : ${WITH_TESTING}")
message(STATUS " WITH_PROFILING : ${WITH_PROFILING}")
message(STATUS " FASTDEPLOY_INSTALL_DIR : ${FASTDEPLOY_INSTALL_DIR}")
message(STATUS " FASTDEPLOY_INCS : ${FASTDEPLOY_INCS}")
message(STATUS " FASTDEPLOY_LIBS : ${FASTDEPLOY_LIBS}")
if(WITH_GPU)
message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}")
endif()
if(ANDROID)
message(STATUS " ANDROID_ABI : ${ANDROID_ABI}")
message(STATUS " ANDROID_PLATFORM : ${ANDROID_PLATFORM}")
message(STATUS " ANDROID_NDK : ${ANDROID_NDK}")
message(STATUS " ANDROID_NDK_VERSION : ${CMAKE_ANDROID_NDK_VERSION}")
endif()
if (WITH_ASR)
message(STATUS " Python executable : ${PYTHON_EXECUTABLE}")
message(STATUS " Python includes : ${PYTHON_INCLUDE_DIR}")
endif()
endfunction()
pps_summary()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Detects the OS and sets appropriate variables.
# CMAKE_SYSTEM_NAME only gives us a coarse-grained name of the OS CMake is
# building for, but a finer-grained name like centos is necessary
# in some scenarios to distinguish systems for customization.
#
# for instance, protobuf libs path is <install_dir>/lib64
# on CentOS, but <install_dir>/lib on other systems.
if(UNIX AND NOT APPLE)
# exclude Apple from the *nix OS family
set(LINUX TRUE)
endif()
if(WIN32)
set(HOST_SYSTEM "win32")
else()
if(APPLE)
set(HOST_SYSTEM "macosx")
exec_program(
sw_vers ARGS
-productVersion
OUTPUT_VARIABLE HOST_SYSTEM_VERSION)
string(REGEX MATCH "[0-9]+.[0-9]+" MACOS_VERSION "${HOST_SYSTEM_VERSION}")
if(NOT DEFINED ENV{MACOSX_DEPLOYMENT_TARGET})
# Set cache variable - end user may change this during ccmake or cmake-gui configure.
set(CMAKE_OSX_DEPLOYMENT_TARGET
${MACOS_VERSION}
CACHE
STRING
"Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value."
)
endif()
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
else()
if(EXISTS "/etc/issue")
file(READ "/etc/issue" LINUX_ISSUE)
if(LINUX_ISSUE MATCHES "CentOS")
set(HOST_SYSTEM "centos")
elseif(LINUX_ISSUE MATCHES "Debian")
set(HOST_SYSTEM "debian")
elseif(LINUX_ISSUE MATCHES "Ubuntu")
set(HOST_SYSTEM "ubuntu")
elseif(LINUX_ISSUE MATCHES "Red Hat")
set(HOST_SYSTEM "redhat")
elseif(LINUX_ISSUE MATCHES "Fedora")
set(HOST_SYSTEM "fedora")
endif()
string(REGEX MATCH "(([0-9]+)\\.)+([0-9]+)" HOST_SYSTEM_VERSION
"${LINUX_ISSUE}")
endif()
if(EXISTS "/etc/redhat-release")
file(READ "/etc/redhat-release" LINUX_ISSUE)
if(LINUX_ISSUE MATCHES "CentOS")
set(HOST_SYSTEM "centos")
endif()
endif()
if(NOT HOST_SYSTEM)
set(HOST_SYSTEM ${CMAKE_SYSTEM_NAME})
endif()
endif()
endif()
# query number of logical cores
cmake_host_system_information(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES)
mark_as_advanced(HOST_SYSTEM CPU_CORES)
message(
STATUS
"Found Paddle host system: ${HOST_SYSTEM}, version: ${HOST_SYSTEM_VERSION}")
message(STATUS "Found Paddle host system's CPU: ${CPU_CORES} cores")
# external dependencies log output
set(EXTERNAL_PROJECT_LOG_ARGS
LOG_DOWNLOAD
0 # Wrap download in script to log output
LOG_UPDATE
1 # Wrap update in script to log output
LOG_CONFIGURE
1 # Wrap configure in script to log output
LOG_BUILD
0 # Wrap build in script to log output
LOG_TEST
1 # Wrap test in script to log output
LOG_INSTALL
0 # Wrap install in script to log output
)
project(speechx LANGUAGES CXX)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kaldi)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common)
add_subdirectory(kaldi)
add_subdirectory(common)
if(WITH_ASR)
add_subdirectory(asr)
endif()
if(WITH_CLS)
add_subdirectory(audio_classification)
endif()
if(WITH_VAD)
add_subdirectory(vad)
endif()
add_subdirectory(codelab)
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(ASR LANGUAGES CXX)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/server)
add_subdirectory(decoder)
add_subdirectory(recognizer)
add_subdirectory(nnet)
add_subdirectory(server)
set(srcs decodable.cc)
if(USING_DS2)
list(APPEND srcs ds2_nnet.cc)
endif()
if(USING_U2)
list(APPEND srcs u2_nnet.cc)
endif()
add_library(nnet STATIC ${srcs})
target_link_libraries(nnet absl::strings)
if(USING_U2)
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
endif()
set(srcs)
list(APPEND srcs
ctc_prefix_beam_search_decoder.cc
ctc_tlg_decoder.cc
)
add_library(decoder STATIC ${srcs})
target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder)
# test
set(TEST_BINS
ctc_prefix_beam_search_decoder_main
ctc_tlg_decoder_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
target_link_libraries(${bin_name} ${DEPS})
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)
endforeach()
......@@ -22,51 +22,22 @@ namespace ppspeech {
struct CTCBeamSearchOptions {
// common
int blank;
std::string word_symbol_table;
// u2
int first_beam_size;
int second_beam_size;
CTCBeamSearchOptions()
: blank(0),
word_symbol_table("vocab.txt"),
first_beam_size(10),
second_beam_size(10) {}
void Register(kaldi::OptionsItf* opts) {
std::string module = "Ds2BeamSearchConfig: ";
opts->Register("dict", &dict_file, module + "vocab file path.");
opts->Register(
"lm-path", &lm_path, module + "ngram language model path.");
opts->Register("alpha", &alpha, module + "alpha");
opts->Register("beta", &beta, module + "beta");
opts->Register("beam-size",
&beam_size,
module + "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");
std::string module = "CTCBeamSearchOptions: ";
opts->Register("word_symbol_table", &word_symbol_table, module + "vocab file path.");
opts->Register("blank", &blank, "blank id, default is 0.");
module = "U2BeamSearchConfig: ";
opts->Register(
"first-beam-size", &first_beam_size, module + "first beam size.");
opts->Register("second-beam-size",
......
......@@ -17,13 +17,12 @@
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "absl/strings/str_join.h"
#include "base/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "utils/math.h"
#ifdef WITH_PROFILING
#include "paddle/fluid/platform/profiler.h"
using paddle::platform::RecordEvent;
using paddle::platform::TracerEventType;
......@@ -31,11 +30,10 @@ using paddle::platform::TracerEventType;
namespace ppspeech {
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts) {
unit_table_ = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(opts.word_symbol_table));
CHECK(unit_table_ != nullptr);
Reset();
......@@ -66,7 +64,6 @@ void CTCPrefixBeamSearch::Reset() {
void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
double search_cost = 0.0;
......@@ -78,21 +75,21 @@ void CTCPrefixBeamSearch::AdvanceDecode(
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
feat_nnet_cost += timer.Elapsed();
if (flag == false) {
VLOG(2) << "decoder advance decode exit." << frame_prob.size();
break;
}
timer.Reset();
std::vector<std::vector<kaldi::BaseFloat>> likelihood;
likelihood.push_back(std::move(frame_prob));
AdvanceDecoding(likelihood);
search_cost += timer.Elapsed();
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_;
}
VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
VLOG(2) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
<< " sec.";
VLOG(1) << "AdvanceDecode search cost: " << search_cost << " sec.";
VLOG(2) << "AdvanceDecode search cost: " << search_cost << " sec.";
}
static bool PrefixScoreCompare(
......@@ -105,7 +102,7 @@ static bool PrefixScoreCompare(
void CTCPrefixBeamSearch::AdvanceDecoding(
const std::vector<std::vector<kaldi::BaseFloat>>& logp) {
#ifdef WITH_PROFILING
RecordEvent event("CtcPrefixBeamSearch::AdvanceDecoding",
TracerEventType::UserDefined,
1);
......
......@@ -27,8 +27,7 @@ namespace ppspeech {
class ContextGraph;
class CTCPrefixBeamSearch : public DecoderBase {
public:
CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
~CTCPrefixBeamSearch() {}
SearchType Type() const { return SearchType::kPrefixBeamSearch; }
......@@ -45,7 +44,7 @@ class CTCPrefixBeamSearch : public DecoderBase {
void FinalizeSearch();
const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
return unit_table_;
}
......@@ -57,7 +56,6 @@ class CTCPrefixBeamSearch : public DecoderBase {
}
const std::vector<std::vector<int>>& Times() const { return times_; }
protected:
std::string GetBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath() override;
......
......@@ -12,18 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/str_split.h"
#include "base/common.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "frontend/data_cache.h"
#include "fst/symbol-table.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/nnet_producer.h"
#include "nnet/u2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(word_symbol_table, "", "vocab path");
DEFINE_string(model_path, "", "paddle nnet model");
......@@ -40,7 +40,7 @@ using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test u2 online decoder by feeding speech feature
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
......@@ -52,10 +52,10 @@ int main(int argc, char* argv[]) {
CHECK_NE(FLAGS_result_wspecifier, "");
CHECK_NE(FLAGS_feature_rspecifier, "");
CHECK_NE(FLAGS_word_symbol_table, "");
CHECK_NE(FLAGS_model_path, "");
LOG(INFO) << "model path: " << FLAGS_model_path;
LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
LOG(INFO) << "Reading vocab table " << FLAGS_word_symbol_table;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
......@@ -70,15 +70,18 @@ int main(int argc, char* argv[]) {
// decodeable
std::shared_ptr<ppspeech::DataCache> raw_data =
std::make_shared<ppspeech::DataCache>();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nnet, raw_data, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable =
std::make_shared<ppspeech::Decodable>(nnet_producer);
// decoder
ppspeech::CTCBeamSearchOptions opts;
opts.blank = 0;
opts.first_beam_size = 10;
opts.second_beam_size = 10;
opts.word_symbol_table = FLAGS_word_symbol_table;
ppspeech::CTCPrefixBeamSearch decoder(opts);
int32 chunk_size = FLAGS_receptive_field_length +
......@@ -122,15 +125,14 @@ int main(int argc, char* argv[]) {
}
std::vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
feat_dim);
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < this_chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start);
std::memcpy(feature_chunk.data() + row_id * feat_dim,
feat_row.Data(),
feat_dim * sizeof(kaldi::BaseFloat));
++start;
}
......
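Collecting the flags above, an offline run of this test binary looks roughly like the following; the binary location, model file, and Kaldi ark paths are placeholders:
```
./ctc_prefix_beam_search_decoder_main \
    --model_path=avg_1.jit.pdmodel \
    --word_symbol_table=words.txt \
    --feature_rspecifier=ark:feats.ark \
    --result_wspecifier=ark,t:result.txt
```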
......@@ -13,12 +13,14 @@
// limitations under the License.
#include "decoder/ctc_tlg_decoder.h"
namespace ppspeech {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) {
fst_ = opts.fst_ptr;
CHECK(fst_ != nullptr);
CHECK(!opts.word_symbol_table.empty());
word_symbol_table_.reset(
fst::SymbolTable::ReadText(opts.word_symbol_table));
......@@ -29,6 +31,11 @@ TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
void TLGDecoder::Reset() {
decoder_->InitDecoding();
hypotheses_.clear();
likelihood_.clear();
olabels_.clear();
times_.clear();
num_frame_decoded_ = 0;
return;
}
......@@ -68,14 +75,52 @@ std::string TLGDecoder::GetPartialResult() {
return words;
}
void TLGDecoder::FinalizeSearch() {
decoder_->FinalizeDecoding();
kaldi::CompactLattice clat;
decoder_->GetLattice(&clat, true);
kaldi::Lattice lat, nbest_lat;
fst::ConvertLattice(clat, &lat);
fst::ShortestPath(lat, &nbest_lat, opts_.nbest);
std::vector<kaldi::Lattice> nbest_lats;
fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
hypotheses_.clear();
hypotheses_.reserve(nbest_lats.size());
likelihood_.clear();
likelihood_.reserve(nbest_lats.size());
times_.clear();
times_.reserve(nbest_lats.size());
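    // Convert each n-best lattice into token ids (Inputs), word ids
    // (Outputs), per-token times, and a total likelihood.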
for (auto lat : nbest_lats) {
kaldi::LatticeWeight weight;
std::vector<int> hypothese;
std::vector<int> time;
std::vector<int> alignment;
std::vector<int> words_id;
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
int idx = 0;
for (; idx < alignment.size() - 1; ++idx) {
if (alignment[idx] == 0) continue;
if (alignment[idx] != alignment[idx + 1]) {
hypothese.push_back(alignment[idx] - 1);
time.push_back(idx); // fake time, todo later
}
}
hypothese.push_back(alignment[idx] - 1);
time.push_back(idx); // fake time, todo later
hypotheses_.push_back(hypothese);
times_.push_back(time);
olabels_.push_back(words_id);
likelihood_.push_back(-(weight.Value2() + weight.Value1()));
}
}
std::string TLGDecoder::GetFinalBestPath() {
if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
return std::string("");
}
decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;
......
......@@ -18,13 +18,14 @@
#include "decoder/decoder_itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
#include "utils/file_utils.h"
DECLARE_string(word_symbol_table);
DECLARE_string(graph_path);
DECLARE_int32(max_active);
DECLARE_double(beam);
DECLARE_double(lattice_beam);
DECLARE_int32(nbest);
namespace ppspeech {
......@@ -33,17 +34,27 @@ struct TLGDecoderOptions {
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_ptr;
int nbest;
TLGDecoderOptions() : word_symbol_table(""), fst_path(""), fst_ptr(nullptr), nbest(10) {}
static TLGDecoderOptions InitFromFlags() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbole table: " << decoder_opts.word_symbol_table;
LOG(INFO) << "symbole table: " << decoder_opts.word_symbol_table;
if (!decoder_opts.fst_path.empty()) {
CHECK(FileExists(decoder_opts.fst_path));
decoder_opts.fst_ptr.reset(fst::Fst<fst::StdArc>::Read(FLAGS_graph_path));
}
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
decoder_opts.nbest = FLAGS_nbest;
LOG(INFO) << "LatticeFasterDecoder max active: "
<< decoder_opts.opts.max_active;
LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam;
......@@ -59,20 +70,38 @@ class TLGDecoder : public DecoderBase {
explicit TLGDecoder(TLGDecoderOptions opts);
~TLGDecoder() = default;
void InitDecoder() override;
void Reset() override;
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) override;
void Decode();
std::string GetFinalBestPath() override;
std::string GetPartialResult() override;
const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
return word_symbol_table_;
}
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
const std::vector<std::string>& nbest_words);
void FinalizeSearch() override;
const std::vector<std::vector<int>>& Inputs() const override {
return hypotheses_;
}
const std::vector<std::vector<int>>& Outputs() const override {
return olabels_;
}
const std::vector<float>& Likelihood() const override {
return likelihood_;
}
const std::vector<std::vector<int>>& Times() const override {
return times_;
}
protected:
std::string GetBestPath() override {
CHECK(false);
......@@ -90,10 +119,17 @@ class TLGDecoder : public DecoderBase {
private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable);
int num_frame_decoded_;
std::vector<std::vector<int>> hypotheses_;
std::vector<std::vector<int>> olabels_;
std::vector<float> likelihood_;
std::vector<std::vector<int>> times_;
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
TLGDecoderOptions opts_;
};
} // namespace ppspeech
......@@ -14,21 +14,24 @@
// todo refactor, replace with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "base/common.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/param.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/nnet_producer.h"
DEFINE_string(nnet_prob_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test TLG decoder by feeding speech feature.
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
......@@ -36,41 +39,51 @@ int main(int argc, char* argv[]) {
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader nnet_prob_reader(
FLAGS_nnet_prob_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts =
ppspeech::TLGDecoderOptions::InitFromFlags();
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nullptr, nullptr, 1.0);
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale));
decoder.InitDecoder();
kaldi::Timer timer;
for (; !nnet_prob_reader.Done(); nnet_prob_reader.Next()) {
string utt = nnet_prob_reader.Key();
kaldi::Matrix<BaseFloat> prob = nnet_prob_reader.Value();
decodable->Acceptlikelihood(prob);
decoder.AdvanceDecode(decodable);
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
......
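Analogously, a hedged example invocation of the TLG test binary (all paths are placeholders):
```
./ctc_tlg_decoder_main \
    --nnet_prob_rspecifier=ark:nnet_prob.ark \
    --graph_path=TLG.fst \
    --word_symbol_table=words.txt \
    --result_wspecifier=ark,t:result.txt
```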
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -16,6 +15,7 @@
#pragma once
#include "base/common.h"
#include "fst/symbol-table.h"
#include "kaldi/decoder/decodable-itf.h"
namespace ppspeech {
......@@ -41,6 +41,14 @@ class DecoderInterface {
virtual std::string GetPartialResult() = 0;
virtual const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const = 0;
virtual void FinalizeSearch() = 0;
virtual const std::vector<std::vector<int>>& Inputs() const = 0;
virtual const std::vector<std::vector<int>>& Outputs() const = 0;
virtual const std::vector<float>& Likelihood() const = 0;
virtual const std::vector<std::vector<int>>& Times() const = 0;
protected:
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
......
......@@ -15,8 +15,6 @@
#pragma once
#include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
// feature
DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
......@@ -37,36 +35,22 @@ DEFINE_int32(subsampling_rate,
"two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
#ifdef USE_ONNX
DEFINE_bool(with_onnx_model, false, "True mean the model path is onnx model path");
#endif
// decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "", "decoder graph");
DEFINE_string(word_symbol_table, "", "word symbol table");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_double(blank_threshold, 0.98, "blank skip threshold");
// DecodeOptions flags
// DEFINE_int32(chunk_size, -1, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight,
0.5,
......
set(srcs decodable.cc nnet_producer.cc)
list(APPEND srcs u2_nnet.cc)
if(WITH_ONNX)
list(APPEND srcs u2_onnx_nnet.cc)
endif()
add_library(nnet STATIC ${srcs})
target_link_libraries(nnet utils)
if(WITH_ONNX)
target_link_libraries(nnet ${FASTDEPLOY_LIBS})
endif()
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
# test bin
#set(bin_name u2_nnet_main)
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
......@@ -21,29 +21,25 @@ using kaldi::Matrix;
using kaldi::Vector;
using std::vector;
Decodable::Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
kaldi::BaseFloat acoustic_scale)
: nnet_producer_(nnet_producer),
frame_offset_(0),
frames_ready_(0),
acoustic_scale_(acoustic_scale) {}
// for debug
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_producer_->Acceptlikelihood(likelihood);
}
// Return the number of frames that have been computed.
int32 Decodable::NumFramesReady() const { return frames_ready_; }
// frame idx ranges from 0 to frames_ready_ - 1;
bool Decodable::IsLastFrame(int32 frame) {
EnsureFrameHaveComputed(frame);
return frame >= frames_ready_;
}
......@@ -64,32 +60,10 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) {
bool Decodable::AdvanceChunk() {
kaldi::Timer timer;
bool flag = nnet_producer_->Read(&framelikelihood_);
if (flag == false) return false;
frame_offset_ = frames_ready_;
frames_ready_ += 1;
VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed()
<< " sec.";
return true;
......@@ -101,17 +75,17 @@ bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
return false;
}
if (framelikelihood_.empty()) {
LOG(WARNING) << "No new nnet out in cache.";
return false;
}
size_t dim = framelikelihood_.size();
logprobs->Resize(framelikelihood_.size());
std::memcpy(logprobs->Data(),
framelikelihood_.data(),
dim * sizeof(kaldi::BaseFloat));
*vocab_dim = framelikelihood_.size();
return true;
}
......@@ -122,19 +96,8 @@ bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
return false;
}
CHECK_EQ(1, (frames_ready_ - frame_offset_));
*likelihood = framelikelihood_;
return true;
}
......@@ -143,37 +106,31 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
return false;
}
CHECK_LE(index, framelikelihood_.size());
CHECK_LE(frame, frames_ready_);
// the index - 1, because the ilabel
BaseFloat logprob = 0.0;
int32 frame_idx = frame - frame_offset_;
CHECK_EQ(frame_idx, 0);
logprob = framelikelihood_[TokenId2NnetId(index)];
return acoustic_scale_ * logprob;
}
void Decodable::Reset() {
if (nnet_producer_ != nullptr) nnet_producer_->Reset();
frame_offset_ = 0;
frames_ready_ = 0;
framelikelihood_.clear();
}
void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
kaldi::Timer timer;
nnet_producer_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec.";
}
} // namespace ppspeech
......@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/decoder/decodable-itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "nnet/nnet_producer.h"
namespace ppspeech {
......@@ -24,12 +26,9 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
public:
explicit Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
kaldi::BaseFloat acoustic_scale = 1.0);
// void Init(DecodableOpts config);
// nnet logprob output, used by wfst
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
......@@ -57,23 +56,17 @@ class Decodable : public kaldi::DecodableInterface {
void Reset();
bool IsInputFinished() const { return nnet_producer_->IsFinished(); }
bool EnsureFrameHaveComputed(int32 frame);
int32 TokenId2NnetId(int32 token_id);
// for offline test
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
private:
std::shared_ptr<NnetProducer> nnet_producer_;
// the frame is nnet prob frame rather than audio feature frame
// nnet frame subsample the feature frame
......@@ -85,6 +78,7 @@ class Decodable : public kaldi::DecodableInterface {
// so use subsampled_frame
int32 current_log_post_subsampled_offset_;
int32 num_chunk_computed_;
std::vector<kaldi::BaseFloat> framelikelihood_;
kaldi::BaseFloat acoustic_scale_;
};
......
......@@ -15,7 +15,6 @@
#include "base/basic_types.h"
#include "kaldi/base/kaldi-types.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
DECLARE_int32(subsampling_rate);
......@@ -25,26 +24,20 @@ DECLARE_string(model_input_names);
DECLARE_string(model_output_names);
DECLARE_string(model_cache_names);
DECLARE_string(model_cache_shapes);
#ifdef USE_ONNX
DECLARE_bool(with_onnx_model);
#endif
namespace ppspeech {
struct ModelOptions {
// common
int subsample_rate{1};
int thread_num{1}; // predictor thread pool size for ds2;
bool use_gpu{false};
std::string model_path;
std::string param_path;
#ifdef USE_ONNX
bool with_onnx_model{false};
#endif
static ModelOptions InitFromFlags() {
ModelOptions opts;
......@@ -52,26 +45,17 @@ struct ModelOptions {
LOG(INFO) << "subsampling rate: " << opts.subsample_rate;
opts.model_path = FLAGS_model_path;
LOG(INFO) << "model path: " << opts.model_path;
opts.param_path = FLAGS_param_path;
LOG(INFO) << "param path: " << opts.param_path;
LOG(INFO) << "DS2 param: ";
opts.cache_names = FLAGS_model_cache_names;
LOG(INFO) << " cache names: " << opts.cache_names;
opts.cache_shape = FLAGS_model_cache_shapes;
LOG(INFO) << " cache shape: " << opts.cache_shape;
opts.input_names = FLAGS_model_input_names;
LOG(INFO) << " input names: " << opts.input_names;
opts.output_names = FLAGS_model_output_names;
LOG(INFO) << " output names: " << opts.output_names;
#ifdef USE_ONNX
opts.with_onnx_model = FLAGS_with_onnx_model;
LOG(INFO) << "with onnx model: " << opts.with_onnx_model;
#endif
return opts;
}
};
struct NnetOut {
// nnet out: maybe logprob or prob. Most of the time this is logprob.
std::vector<kaldi::BaseFloat> logprobs;
int32 vocab_dim;
// nnet state. Only using in Attention model.
......@@ -89,7 +73,7 @@ class NnetInterface {
// nnet does not cache feats; feats are cached by the frontend.
// nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache,
// frame_offset.
virtual void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) = 0;
......@@ -105,14 +89,14 @@ class NnetInterface {
// using to get encoder outs. e.g. seq2seq with Attention model.
virtual void EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const = 0;
};
class NnetBase : public NnetInterface {
public:
int SubsamplingRate() const { return subsampling_rate_; }
virtual std::shared_ptr<NnetBase> Clone() const = 0;
protected:
int subsampling_rate_{1};
};
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/nnet_producer.h"
#include "matrix/kaldi-matrix.h"
namespace ppspeech {
using kaldi::BaseFloat;
using std::vector;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend,
float blank_threshold)
: nnet_(nnet), frontend_(frontend), blank_threshold_(blank_threshold) {
Reset();
}
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
frontend_->Accept(inputs);
}
void NnetProducer::Acceptlikelihood(
const kaldi::Matrix<BaseFloat>& likelihood) {
std::vector<BaseFloat> prob;
prob.resize(likelihood.NumCols());
for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) {
for (size_t col = 0; col < likelihood.NumCols(); ++col) {
prob[col] = likelihood(idx, col);
}
cache_.push_back(prob);
}
}
bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
bool flag = cache_.pop(nnet_prob);
return flag;
}
bool NnetProducer::Compute() {
vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
if (frontend_->IsFinished() == true) {
finished_ = true;
}
return false;
}
CHECK_GE(frontend_->Dim(), 0);
VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
NnetOut out;
nnet_->FeedForward(features, frontend_->Dim(), &out);
int32& vocab_dim = out.vocab_dim;
size_t nframes = out.logprobs.size() / vocab_dim;
VLOG(1) << "Forward out " << nframes << " decoder frames.";
for (size_t idx = 0; idx < nframes; ++idx) {
std::vector<BaseFloat> logprob(
out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim);
// process blank prob
float blank_prob = std::exp(logprob[0]);
if (blank_prob > blank_threshold_) {
last_frame_logprob_ = logprob;
is_last_frame_skip_ = true;
continue;
} else {
            int cur_max = std::max_element(logprob.begin(), logprob.end()) - logprob.begin();
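            // A repeated max token right after skipped blank frames marks a
            // new token boundary; restore the last skipped frame so it is
            // not lost.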
if (cur_max == last_max_elem_ && cur_max != 0 && is_last_frame_skip_) {
cache_.push_back(last_frame_logprob_);
last_max_elem_ = cur_max;
}
last_max_elem_ = cur_max;
is_last_frame_skip_ = false;
cache_.push_back(logprob);
}
}
return true;
}
void NnetProducer::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
}
} // namespace ppspeech
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "base/safe_queue.h"
#include "frontend/frontend_itf.h"
#include "nnet/nnet_itf.h"
namespace ppspeech {
class NnetProducer {
public:
explicit NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend,
float blank_threshold);
// Feed feats or waves
void Accept(const std::vector<kaldi::BaseFloat>& inputs);
void Acceptlikelihood(const kaldi::Matrix<BaseFloat>& likelihood);
// nnet
bool Read(std::vector<kaldi::BaseFloat>* nnet_prob);
bool Empty() const { return cache_.empty(); }
void SetInputFinished() {
LOG(INFO) << "set finished";
frontend_->SetFinished();
}
// whether the compute thread has finished
bool IsFinished() const {
return (frontend_->IsFinished() && finished_);
}
~NnetProducer() {}
void Reset() {
if (frontend_ != NULL) frontend_->Reset();
if (nnet_ != NULL) nnet_->Reset();
cache_.clear();
finished_ = false;
}
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score);
bool Compute();
private:
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetBase> nnet_;
SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
std::vector<BaseFloat> last_frame_logprob_;
bool is_last_frame_skip_ = false;
int last_max_elem_ = -1;
float blank_threshold_ = 0.0;
bool finished_;
DISALLOW_COPY_AND_ASSIGN(NnetProducer);
};
} // namespace ppspeech
......@@ -17,12 +17,13 @@
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.cc
#include "nnet/u2_nnet.h"
#include <type_traits>
#ifdef WITH_PROFILING
#include "paddle/fluid/platform/profiler.h"
using paddle::platform::RecordEvent;
using paddle::platform::TracerEventType;
#endif // end WITH_PROFILING
namespace ppspeech {
......@@ -30,7 +31,7 @@ namespace ppspeech {
void U2Nnet::LoadModel(const std::string& model_path_w_prefix) {
paddle::jit::utils::InitKernelSignatureMap();
#ifdef WITH_GPU
dev_ = phi::GPUPlace();
#else
dev_ = phi::CPUPlace();
......@@ -62,12 +63,12 @@ void U2Nnet::LoadModel(const std::string& model_path_w_prefix) {
}
void U2Nnet::Warmup() {
#ifdef WITH_PROFILING
RecordEvent event("warmup", TracerEventType::UserDefined, 1);
#endif
{
#ifdef WITH_PROFILING
RecordEvent event(
"warmup-encoder-ctc", TracerEventType::UserDefined, 1);
#endif
......@@ -91,7 +92,7 @@ void U2Nnet::Warmup() {
}
{
#ifdef WITH_PROFILING
RecordEvent event("warmup-decoder", TracerEventType::UserDefined, 1);
#endif
auto hyps =
......@@ -101,10 +102,10 @@ void U2Nnet::Warmup() {
auto encoder_out = paddle::ones(
{1, 20, 512}, paddle::DataType::FLOAT32, phi::CPUPlace());
std::vector<paddle::Tensor> inputs{
hyps, hyps_lens, encoder_out};
std::vector<paddle::Tensor> outputs =
forward_attention_decoder_(inputs);
}
......@@ -118,27 +119,46 @@ U2Nnet::U2Nnet(const ModelOptions& opts) : opts_(opts) {
// shallow copy
U2Nnet::U2Nnet(const U2Nnet& other) {
// copy meta
right_context_ = other.right_context_;
subsampling_rate_ = other.subsampling_rate_;
sos_ = other.sos_;
eos_ = other.eos_;
is_bidecoder_ = other.is_bidecoder_;
chunk_size_ = other.chunk_size_;
num_left_chunks_ = other.num_left_chunks_;
offset_ = other.offset_;
// copy model ptr
// model_ = other.model_;
// model_ = other.model_->Clone();
// hack, fix later
#ifdef WITH_GPU
dev_ = phi::GPUPlace();
#else
dev_ = phi::CPUPlace();
#endif
paddle::jit::Layer model = paddle::jit::Load(other.opts_.model_path, dev_);
model_ = std::make_shared<paddle::jit::Layer>(std::move(model));
subsampling_rate_ = model_->Attribute<int>("subsampling_rate");
right_context_ = model_->Attribute<int>("right_context");
sos_ = model_->Attribute<int>("sos_symbol");
eos_ = model_->Attribute<int>("eos_symbol");
is_bidecoder_ = model_->Attribute<int>("is_bidirectional_decoder");
forward_encoder_chunk_ = model_->Function("forward_encoder_chunk");
forward_attention_decoder_ = model_->Function("forward_attention_decoder");
ctc_activation_ = model_->Function("ctc_activation");
CHECK(forward_encoder_chunk_.IsValid());
CHECK(forward_attention_decoder_.IsValid());
CHECK(ctc_activation_.IsValid());
LOG(INFO) << "Paddle Model Info: ";
LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_;
LOG(INFO) << "\tright context " << right_context_;
LOG(INFO) << "\tsos " << sos_;
LOG(INFO) << "\teos " << eos_;
LOG(INFO) << "\tis bidecoder " << is_bidecoder_ << std::endl;
// ignore inner states
}
std::shared_ptr<NnetBase> U2Nnet::Clone() const {
auto asr_model = std::make_shared<U2Nnet>(*this);
// reset inner state for new decoding
asr_model->Reset();
......@@ -154,6 +174,7 @@ void U2Nnet::Reset() {
std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
encoder_outs_.clear();
VLOG(1) << "FeedForward cost: " << cost_time_ << " sec. ";
VLOG(3) << "u2nnet reset";
}
......@@ -165,23 +186,18 @@ void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) {
}
void U2Nnet::FeedForward(const std::vector<BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) {
kaldi::Timer timer;
ForwardEncoderChunkImpl(
features, feature_dim, &out->logprobs, &out->vocab_dim);
float forward_chunk_time = timer.Elapsed();
VLOG(1) << "FeedForward cost: " << forward_chunk_time << " sec. "
<< features.size() / feature_dim << " frames.";
cost_time_ += forward_chunk_time;
}
......@@ -190,7 +206,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
const int32& feat_dim,
std::vector<kaldi::BaseFloat>* out_prob,
int32* vocab_dim) {
#ifdef WITH_PROFILING
RecordEvent event(
"ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1);
#endif
......@@ -210,7 +226,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
// not cache feature in nnet
CHECK_EQ(cached_feats_.size(), 0);
CHECK_EQ((std::is_same<float, kaldi::BaseFloat>::value), true);
std::memcpy(feats_ptr,
chunk_feats.data(),
chunk_feats.size() * sizeof(kaldi::BaseFloat));
......@@ -218,7 +234,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
VLOG(3) << "feats shape: " << feats.shape()[0] << ", " << feats.shape()[1]
<< ", " << feats.shape()[2];
#ifdef PPS_DEBUG
{
std::stringstream path("feat", std::ios_base::app | std::ios_base::out);
path << offset_;
......@@ -237,7 +253,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
#endif
// Encoder chunk forward
#ifdef WITH_GPU
feats = feats.copy_to(paddle::GPUPlace(), /*blocking*/ false);
att_cache_ = att_cache_.copy_to(paddle::GPUPlace(), /*blocking*/ false);
cnn_cache_ = cnn_cache_.copy_to(paddle::GPUPlace(), /*blocking*/ false);
......@@ -254,7 +270,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
std::vector<paddle::Tensor> outputs = forward_encoder_chunk_(inputs);
CHECK_EQ(outputs.size(), 3);
#ifdef WITH_GPU
paddle::Tensor chunk_out = outputs[0].copy_to(paddle::CPUPlace());
att_cache_ = outputs[1].copy_to(paddle::CPUPlace());
cnn_cache_ = outputs[2].copy_to(paddle::CPUPlace());
......@@ -264,7 +280,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
cnn_cache_ = outputs[2];
#endif
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_logits",
std::ios_base::app | std::ios_base::out);
......@@ -294,7 +310,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
encoder_outs_.push_back(chunk_out);
VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size();
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_logits_list",
std::ios_base::app | std::ios_base::out);
......@@ -313,7 +329,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
}
#endif  // end PPS_DEBUG
#ifdef WITH_GPU
#error "Not implementation."
......@@ -327,7 +343,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
CHECK_EQ(outputs.size(), 1);
paddle::Tensor ctc_log_probs = outputs[0];
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_logprob",
std::ios_base::app | std::ios_base::out);
......@@ -349,7 +365,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
}
#endif  // end PPS_DEBUG
#endif // end WITH_GPU
// Copy to output, (B=1,T,D)
std::vector<int64_t> ctc_log_probs_shape = ctc_log_probs.shape();
......@@ -366,7 +382,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
std::memcpy(
out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat));
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_logits_list_ctc",
std::ios_base::app | std::ios_base::out);
......@@ -415,7 +431,7 @@ float U2Nnet::ComputePathScore(const paddle::Tensor& prob,
void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
#ifdef WITH_PROFILING
RecordEvent event("AttentionRescoring", TracerEventType::UserDefined, 1);
#endif
CHECK(rescoring_score != nullptr);
......@@ -457,7 +473,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
}
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_logits_concat",
std::ios_base::app | std::ios_base::out);
......@@ -481,7 +497,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
paddle::Tensor encoder_out = paddle::concat(encoder_outs_, 1);
VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size();
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_out0",
std::ios_base::app | std::ios_base::out);
......@@ -500,7 +516,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
#endif  // end PPS_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("encoder_out",
std::ios_base::app | std::ios_base::out);
......@@ -519,7 +535,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
#endif  // end PPS_DEBUG
std::vector<paddle::Tensor> inputs{
hyps_tensor, hyps_lens, encoder_out};
std::vector<paddle::Tensor> outputs = forward_attention_decoder_(inputs);
CHECK_EQ(outputs.size(), 2);
......@@ -531,7 +547,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
CHECK_EQ(probs_shape[0], num_hyps);
CHECK_EQ(probs_shape[1], max_hyps_len);
#ifdef PPS_DEBUG
{
std::stringstream path("decoder_logprob",
std::ios_base::app | std::ios_base::out);
......@@ -549,7 +565,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
#endif  // end PPS_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("hyps_lens",
std::ios_base::app | std::ios_base::out);
......@@ -565,7 +581,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
#endif  // end PPS_DEBUG
#ifdef PPS_DEBUG
{
std::stringstream path("hyps_tensor",
std::ios_base::app | std::ios_base::out);
......@@ -590,7 +606,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
} else {
// dump r_probs
CHECK_EQ(r_probs_shape.size(), 1);
//CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0];
}
// compute rescoring score
......@@ -600,15 +616,15 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
VLOG(2) << "split prob: " << probs_v.size() << " "
<< probs_v[0].shape().size() << " 0: " << probs_v[0].shape()[0]
<< ", " << probs_v[0].shape()[1] << ", " << probs_v[0].shape()[2];
//CHECK(static_cast<int>(probs_v.size()) == num_hyps)
// << ": is " << probs_v.size() << " expect: " << num_hyps;
std::vector<paddle::Tensor> r_probs_v;
if (is_bidecoder_ && reverse_weight > 0) {
r_probs_v = paddle::experimental::split_with_num(r_probs, num_hyps, 0);
//CHECK(static_cast<int>(r_probs_v.size()) == num_hyps)
// << "r_probs_v size: is " << r_probs_v.size()
// << " expect: " << num_hyps;
}
for (int i = 0; i < num_hyps; ++i) {
......@@ -638,7 +654,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
void U2Nnet::EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const {
// list of (B=1,T,D)
int size = encoder_outs_.size();
VLOG(3) << "encoder_outs_ size: " << size;
......@@ -650,18 +666,18 @@ void U2Nnet::EncoderOuts(
const int& B = shape[0];
const int& T = shape[1];
const int& D = shape[2];
//CHECK(B == 1) << "Only support batch one.";
VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << ","
<< D << ")";
const float* this_tensor_ptr = item.data<float>();
for (int j = 0; j < T; j++) {
const float* cur = this_tensor_ptr + j * D;
std::vector<kaldi::BaseFloat> out(D);
std::memcpy(out.data(), cur, D * sizeof(kaldi::BaseFloat));
encoder_out->emplace_back(out);
}
}
}
} // namespace ppspeech
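After this refactor, FeedForward hands back the CTC posteriors as one flat std::vector: T frames of vocab_dim values each. A hypothetical accessor, assuming NnetOut exposes the logprobs buffer and vocab_dim exactly as used above:

// Row-major (T, D) layout, D == out.vocab_dim.
float LogProbAt(const ppspeech::NnetOut& out, int t, int k) {
    return out.logprobs[t * out.vocab_dim + k];
}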
......@@ -18,7 +18,7 @@
#pragma once
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "paddle/extension.h"
#include "paddle/jit/all.h"
......@@ -42,7 +42,7 @@ class U2NnetBase : public NnetBase {
num_left_chunks_ = num_left_chunks;
}
virtual std::shared_ptr<NnetBase> Clone() const = 0;
protected:
virtual void ForwardEncoderChunkImpl(
......@@ -76,7 +76,7 @@ class U2Nnet : public U2NnetBase {
explicit U2Nnet(const ModelOptions& opts);
U2Nnet(const U2Nnet& other);
void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) override;
......@@ -91,7 +91,7 @@ class U2Nnet : public U2NnetBase {
std::shared_ptr<paddle::jit::Layer> model() const { return model_; }
std::shared_ptr<NnetBase> Clone() const override;
void ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats,
......@@ -111,10 +111,10 @@ class U2Nnet : public U2NnetBase {
void FeedEncoderOuts(const paddle::Tensor& encoder_out);
void EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;
ModelOptions opts_; // hack, fix later
private:
phi::Place dev_;
std::shared_ptr<paddle::jit::Layer> model_{nullptr};
......@@ -127,6 +127,7 @@ class U2Nnet : public U2NnetBase {
paddle::jit::Function forward_encoder_chunk_;
paddle::jit::Function forward_attention_decoder_;
paddle::jit::Function ctc_activation_;
float cost_time_ = 0.0;
};
} // namespace ppspeech
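Clone() exists so each decoding thread can own an independent nnet: the copy constructor above deliberately reloads the jit Layer rather than sharing the pointer. A sketch of the intended use (ModelOptions construction elided; this mirrors how RecognizerControllerImpl later calls resource.nnet->Clone()):

std::shared_ptr<ppspeech::NnetBase> main_nnet =
    std::make_shared<ppspeech::U2Nnet>(opts);
// One fresh, state-reset copy per worker thread:
std::shared_ptr<ppspeech::NnetBase> worker_nnet = main_nnet->Clone();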
......@@ -15,8 +15,8 @@
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h"
#include "frontend/assembler.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/u2_nnet.h"
......
......@@ -12,16 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef USE_ONNX
#include "nnet/u2_nnet.h"
#else
#include "nnet/u2_onnx_nnet.h"
#endif
#include "base/common.h"
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "frontend/feature_pipeline.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "recognizer/recognizer.h"
#include "nnet/decodable.h"
#include "nnet/nnet_producer.h"
#include "nnet/u2_nnet.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(wav_rspecifier, "", "test wav rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet prob wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
......@@ -30,76 +42,104 @@ int main(int argc, char* argv[]) {
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
CHECK_GT(FLAGS_model_path.size(), 0);
LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier;
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
LOG(INFO) << "model path: " << FLAGS_model_path;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
ppspeech::FeaturePipelineOptions feature_opts =
ppspeech::FeaturePipelineOptions::InitFromFlags();
feature_opts.assembler_opts.fill_zero = false;
#ifndef USE_ONNX
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
#else
std::shared_ptr<ppspeech::U2OnnxNnet> nnet(new ppspeech::U2OnnxNnet(model_opts));
#endif
std::shared_ptr<ppspeech::FeaturePipeline> feature_pipeline(
new ppspeech::FeaturePipeline(feature_opts));
std::shared_ptr<ppspeech::NnetProducer> nnet_producer(
new ppspeech::NnetProducer(nnet, feature_pipeline));
kaldi::Timer timer;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
nnet_producer->Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
nnet_producer->SetInputFinished();
}
// no overlap
sample_offset += cur_chunk_size;
}
CHECK(sample_offset == tot_samples);
std::vector<std::vector<kaldi::BaseFloat>> prob_vec;
while (1) {
std::vector<kaldi::BaseFloat> logprobs;
bool isok = nnet_producer->Read(&logprobs);
if (nnet_producer->IsFinished()) break;
if (isok == false) continue;
prob_vec.push_back(logprobs);
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
{
// write nnet output
kaldi::MatrixIndexT nrow = prob_vec.size();
kaldi::MatrixIndexT ncol = prob_vec[0].size();
LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx];
}
}
nnet_out_writer.Write(utt, nnet_out);
}
nnet_producer->Reset();
}
nnet_producer->Wait();
double elapsed = timer.Elapsed();
KALDI_LOG << "Done " << num_done << " out of " << (num_err + num_done);
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s";
KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration;
LOG(INFO) << "Program cost:" << elapsed << " sec";
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
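With the default flags, each Accept() call above pushes 0.36 s of audio, and it is the final, shorter chunk that triggers SetInputFinished(). The same arithmetic in isolation (values are the flag defaults, not requirements):

int sample_rate = 16000;        // FLAGS_sample_rate default
float streaming_chunk = 0.36f;  // FLAGS_streaming_chunk default
int chunk_sample_size = streaming_chunk * sample_rate;  // 5760 samples per chunk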
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.cc
#include "nnet/u2_onnx_nnet.h"
#include "common/base/config.h"
namespace ppspeech {
void U2OnnxNnet::LoadModel(const std::string& model_dir) {
std::string encoder_onnx_path = model_dir + "/encoder.onnx";
std::string rescore_onnx_path = model_dir + "/decoder.onnx";
std::string ctc_onnx_path = model_dir + "/ctc.onnx";
std::string param_path = model_dir + "/param.onnx";
// 1. Load sessions
try {
encoder_ = std::make_shared<fastdeploy::Runtime>();
ctc_ = std::make_shared<fastdeploy::Runtime>();
rescore_ = std::make_shared<fastdeploy::Runtime>();
fastdeploy::RuntimeOption runtime_option;
runtime_option.UseOrtBackend();
runtime_option.UseCpu();
runtime_option.SetCpuThreadNum(1);
runtime_option.SetModelPath(encoder_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
CHECK(encoder_->Init(runtime_option));  // assert() would be compiled out under NDEBUG
runtime_option.SetModelPath(rescore_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
CHECK(rescore_->Init(runtime_option));
runtime_option.SetModelPath(ctc_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX);
CHECK(ctc_->Init(runtime_option));
} catch (std::exception const& e) {
LOG(ERROR) << "error when loading onnx model: " << e.what();
exit(1);
}
// 2. Load metadata from param.onnx
Config conf(param_path);
encoder_output_size_ = conf.Read("output_size", encoder_output_size_);
num_blocks_ = conf.Read("num_blocks", num_blocks_);
head_ = conf.Read("head", head_);
cnn_module_kernel_ = conf.Read("cnn_module_kernel", cnn_module_kernel_);
subsampling_rate_ = conf.Read("subsampling_rate", subsampling_rate_);
right_context_ = conf.Read("right_context", right_context_);
sos_= conf.Read("sos_symbol", sos_);
eos_= conf.Read("eos_symbol", eos_);
is_bidecoder_= conf.Read("is_bidirectional_decoder", is_bidecoder_);
chunk_size_= conf.Read("chunk_size", chunk_size_);
num_left_chunks_ = conf.Read("left_chunks", num_left_chunks_);
LOG(INFO) << "Onnx Model Info:";
LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
LOG(INFO) << "\tnum_blocks " << num_blocks_;
LOG(INFO) << "\thead " << head_;
LOG(INFO) << "\tcnn_module_kernel " << cnn_module_kernel_;
LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_;
LOG(INFO) << "\tright_context " << right_context_;
LOG(INFO) << "\tsos " << sos_;
LOG(INFO) << "\teos " << eos_;
LOG(INFO) << "\tis bidirectional decoder " << is_bidecoder_;
LOG(INFO) << "\tchunk_size " << chunk_size_;
LOG(INFO) << "\tnum_left_chunks " << num_left_chunks_;
// 3. Read model nodes
LOG(INFO) << "Onnx Encoder:";
GetInputOutputInfo(encoder_, &encoder_in_names_, &encoder_out_names_);
LOG(INFO) << "Onnx CTC:";
GetInputOutputInfo(ctc_, &ctc_in_names_, &ctc_out_names_);
LOG(INFO) << "Onnx Rescore:";
GetInputOutputInfo(rescore_, &rescore_in_names_, &rescore_out_names_);
}
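param.onnx is read as a flat key/value table: each Read(key, default) call above returns the stored value, or falls back to the current member value when the key is absent. A hedged sketch of the same pattern (the path is a placeholder):

Config conf("exp/onnx_model/param.onnx");      // hypothetical model dir
int sos = conf.Read("sos_symbol", -1);          // -1 if the key is missing
int chunk_size = conf.Read("chunk_size", 16);   // keeps 16 as the fallback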
U2OnnxNnet::U2OnnxNnet(const ModelOptions& opts) : opts_(opts) {
LoadModel(opts_.model_path);
}
// shallow copy
U2OnnxNnet::U2OnnxNnet(const U2OnnxNnet& other) {
// metadatas
encoder_output_size_ = other.encoder_output_size_;
num_blocks_ = other.num_blocks_;
head_ = other.head_;
cnn_module_kernel_ = other.cnn_module_kernel_;
right_context_ = other.right_context_;
subsampling_rate_ = other.subsampling_rate_;
sos_ = other.sos_;
eos_ = other.eos_;
is_bidecoder_ = other.is_bidecoder_;
chunk_size_ = other.chunk_size_;
num_left_chunks_ = other.num_left_chunks_;
offset_ = other.offset_;
// session
encoder_ = other.encoder_;
ctc_ = other.ctc_;
rescore_ = other.rescore_;
// node names
encoder_in_names_ = other.encoder_in_names_;
encoder_out_names_ = other.encoder_out_names_;
ctc_in_names_ = other.ctc_in_names_;
ctc_out_names_ = other.ctc_out_names_;
rescore_in_names_ = other.rescore_in_names_;
rescore_out_names_ = other.rescore_out_names_;
}
void U2OnnxNnet::GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
std::vector<std::string>* in_names, std::vector<std::string>* out_names) {
std::vector<fastdeploy::TensorInfo> inputs_info = runtime->GetInputInfos();
(*in_names).resize(inputs_info.size());
for (int i = 0; i < inputs_info.size(); ++i){
fastdeploy::TensorInfo info = inputs_info[i];
std::stringstream shape;
for(int j = 0; j < info.shape.size(); ++j){
shape << info.shape[j];
shape << " ";
}
LOG(INFO) << "\tInput " << i << " : name=" << info.name << " type=" << info.dtype
<< " dims=" << shape.str();
(*in_names)[i] = info.name;
}
std::vector<fastdeploy::TensorInfo> outputs_info = runtime->GetOutputInfos();
(*out_names).resize(outputs_info.size());
for (int i = 0; i < outputs_info.size(); ++i){
fastdeploy::TensorInfo info = outputs_info[i];
std::stringstream shape;
for(int j = 0; j < info.shape.size(); ++j){
shape << info.shape[j];
shape << " ";
}
LOG(INFO) << "\tOutput " << i << " : name=" << info.name << " type=" << info.dtype
<< " dims=" << shape.str();
(*out_names)[i] = info.name;
}
}
std::shared_ptr<NnetBase> U2OnnxNnet::Clone() const {
auto asr_model = std::make_shared<U2OnnxNnet>(*this);
// reset inner state for new decoding
asr_model->Reset();
return asr_model;
}
void U2OnnxNnet::Reset() {
offset_ = 0;
encoder_outs_.clear();
cached_feats_.clear();
// Reset att_cache
if (num_left_chunks_ > 0) {
int required_cache_size = chunk_size_ * num_left_chunks_;
offset_ = required_cache_size;
att_cache_.resize(num_blocks_ * head_ * required_cache_size *
encoder_output_size_ / head_ * 2,
0.0);
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, required_cache_size,
encoder_output_size_ / head_ * 2};
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
} else {
att_cache_.resize(0, 0.0);
const std::vector<int64_t> att_cache_shape = {num_blocks_, head_, 0,
encoder_output_size_ / head_ * 2};
att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data());
}
// Reset cnn_cache
cnn_cache_.resize(
num_blocks_ * encoder_output_size_ * (cnn_module_kernel_ - 1), 0.0);
const std::vector<int64_t> cnn_cache_shape = {num_blocks_, 1, encoder_output_size_,
cnn_module_kernel_ - 1};
cnn_cache_ort_.SetExternalData(cnn_cache_shape, fastdeploy::FDDataType::FP32, cnn_cache_.data());
}
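Reset() sizes the attention cache as num_blocks x head x required_cache_size x (encoder_output_size / head x 2); the trailing x2 packs the key and value halves together. Worked through with illustrative values only (num_blocks=12, head=4, encoder_output_size=256, chunk_size=16, left_chunks=16):

int required_cache_size = 16 * 16;  // chunk_size_ * num_left_chunks_ = 256
int att_cache_len = 12 * 4 * required_cache_size * (256 / 4) * 2;
// shape {12, 4, 256, 128} -> 1,572,864 floats in total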
void U2OnnxNnet::FeedForward(const std::vector<BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) {
kaldi::Timer timer;
std::vector<kaldi::BaseFloat> ctc_probs;
ForwardEncoderChunkImpl(
features, feature_dim, &out->logprobs, &out->vocab_dim);
VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
<< features.size() / feature_dim << " frames.";
}
void U2OnnxNnet::ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats,
const int32& feat_dim,
std::vector<kaldi::BaseFloat>* out_prob,
int32* vocab_dim) {
// 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
// chunk
int num_frames = chunk_feats.size() / feat_dim;
VLOG(3) << "num_frames: " << num_frames;
VLOG(3) << "feat_dim: " << feat_dim;
const int feature_dim = feat_dim;
std::vector<float> feats;
feats.insert(feats.end(), chunk_feats.begin(), chunk_feats.end());
fastdeploy::FDTensor feats_ort;
const std::vector<int64_t> feats_shape = {1, num_frames, feature_dim};
feats_ort.SetExternalData(feats_shape, fastdeploy::FDDataType::FP32, feats.data());
// offset
int64_t offset_int64 = static_cast<int64_t>(offset_);
fastdeploy::FDTensor offset_ort;
offset_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &offset_int64);
// required_cache_size
int64_t required_cache_size = chunk_size_ * num_left_chunks_;
fastdeploy::FDTensor required_cache_size_ort("");
required_cache_size_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &required_cache_size);
// att_mask
fastdeploy::FDTensor att_mask_ort;
std::vector<uint8_t> att_mask(required_cache_size + chunk_size_, 1);
if (num_left_chunks_ > 0) {
int chunk_idx = offset_ / chunk_size_ - num_left_chunks_;
if (chunk_idx < num_left_chunks_) {
for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) {
att_mask[i] = 0;
}
}
const std::vector<int64_t> att_mask_shape = {1, 1, required_cache_size + chunk_size_};
att_mask_ort.SetExternalData(att_mask_shape, fastdeploy::FDDataType::BOOL, reinterpret_cast<bool*>(att_mask.data()));
}
// 2. Encoder chunk forward
std::vector<fastdeploy::FDTensor> inputs(encoder_in_names_.size());
for (int i = 0; i < encoder_in_names_.size(); ++i) {
std::string name = encoder_in_names_[i];
if (!strcmp(name.data(), "chunk")) {
inputs[i] = std::move(feats_ort);
inputs[i].name = "chunk";
} else if (!strcmp(name.data(), "offset")) {
inputs[i] = std::move(offset_ort);
inputs[i].name = "offset";
} else if (!strcmp(name.data(), "required_cache_size")) {
inputs[i] = std::move(required_cache_size_ort);
inputs[i].name = "required_cache_size";
} else if (!strcmp(name.data(), "att_cache")) {
inputs[i] = std::move(att_cache_ort_);
inputs[i].name = "att_cache";
} else if (!strcmp(name.data(), "cnn_cache")) {
inputs[i] = std::move(cnn_cache_ort_);
inputs[i].name = "cnn_cache";
} else if (!strcmp(name.data(), "att_mask")) {
inputs[i] = std::move(att_mask_ort);
inputs[i].name = "att_mask";
}
}
std::vector<fastdeploy::FDTensor> ort_outputs;
CHECK(encoder_->Infer(inputs, &ort_outputs));
offset_ += static_cast<int>(ort_outputs[0].shape[1]);
att_cache_ort_ = std::move(ort_outputs[1]);
cnn_cache_ort_ = std::move(ort_outputs[2]);
std::vector<fastdeploy::FDTensor> ctc_inputs;
ctc_inputs.emplace_back(std::move(ort_outputs[0]));
// ctc_inputs[0] = std::move(ort_outputs[0]);
ctc_inputs[0].name = ctc_in_names_[0];
std::vector<fastdeploy::FDTensor> ctc_ort_outputs;
CHECK(ctc_->Infer(ctc_inputs, &ctc_ort_outputs));
encoder_outs_.emplace_back(std::move(ctc_inputs[0])); // *****
float* logp_data = reinterpret_cast<float*>(ctc_ort_outputs[0].Data());
// Copy to output, (B=1,T,D)
std::vector<int64_t> ctc_log_probs_shape = ctc_ort_outputs[0].shape;
CHECK_EQ(ctc_log_probs_shape.size(), 3);
int B = ctc_log_probs_shape[0];
CHECK_EQ(B, 1);
int T = ctc_log_probs_shape[1];
int D = ctc_log_probs_shape[2];
*vocab_dim = D;
out_prob->resize(T * D);
std::memcpy(
out_prob->data(), logp_data, T * D * sizeof(kaldi::BaseFloat));
return;
}
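The att_mask block above hides cache positions that have not been filled yet. A worked example with assumed values chunk_size_=16 and num_left_chunks_=4, so Reset() starts offset_ at required_cache_size=64:

int chunk_size = 16, num_left_chunks = 4;                 // assumed values
int required_cache_size = chunk_size * num_left_chunks;   // 64
int offset = required_cache_size;                         // first chunk after Reset()
int chunk_idx = offset / chunk_size - num_left_chunks;    // 0
int masked = (num_left_chunks - chunk_idx) * chunk_size;  // 64: whole cache hidden
// The mask spans required_cache_size + chunk_size = 80 positions; on the
// first chunk only the 16 current frames stay visible.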
float U2OnnxNnet::ComputeAttentionScore(const float* prob,
const std::vector<int>& hyp, int eos,
int decode_out_len) {
float score = 0.0f;
for (size_t j = 0; j < hyp.size(); ++j) {
score += *(prob + j * decode_out_len + hyp[j]);
}
score += *(prob + hyp.size() * decode_out_len + eos);
return score;
}
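ComputeAttentionScore sums, position by position, the log-probability the decoder assigned to each token of the hypothesis, then adds the <eos> term; AttentionRescoring below blends the forward and reverse passes as score * (1 - reverse_weight) + r_score * reverse_weight. A tiny worked example (hypothetical 4-token vocabulary, decode_out_len = 4, eos = 3):

std::vector<float> prob = {
    -0.1f, -2.0f, -3.0f, -4.0f,  // step 0: log P(token)
    -2.5f, -0.2f, -3.5f, -4.5f,  // step 1
    -3.0f, -2.8f, -2.9f, -0.3f,  // step 2 (eos term read here)
};
std::vector<int> hyp = {0, 1};
// score = prob[0*4+0] + prob[1*4+1] + prob[2*4+3]
//       = -0.1 + -0.2 + -0.3 = -0.6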
void U2OnnxNnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
CHECK(rescoring_score != nullptr);
int num_hyps = hyps.size();
rescoring_score->resize(num_hyps, 0.0f);
if (num_hyps == 0) {
return;
}
// No encoder output
if (encoder_outs_.size() == 0) {
return;
}
std::vector<int64_t> hyps_lens;
int max_hyps_len = 0;
for (size_t i = 0; i < num_hyps; ++i) {
int length = hyps[i].size() + 1;
max_hyps_len = std::max(length, max_hyps_len);
hyps_lens.emplace_back(static_cast<int64_t>(length));
}
std::vector<float> rescore_input;
int encoder_len = 0;
for (int i = 0; i < encoder_outs_.size(); i++) {
float* encoder_outs_data = reinterpret_cast<float*>(encoder_outs_[i].Data());
for (int j = 0; j < encoder_outs_[i].Numel(); j++) {
rescore_input.emplace_back(encoder_outs_data[j]);
}
encoder_len += encoder_outs_[i].shape[1];
}
std::vector<int64_t> hyps_pad;
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
hyps_pad.emplace_back(sos_);
size_t j = 0;
for (; j < hyp.size(); ++j) {
hyps_pad.emplace_back(hyp[j]);
}
if (j == max_hyps_len - 1) {
continue;
}
for (; j < max_hyps_len - 1; ++j) {
hyps_pad.emplace_back(0);
}
}
const std::vector<int64_t> hyps_pad_shape = {num_hyps, max_hyps_len};
const std::vector<int64_t> hyps_lens_shape = {num_hyps};
const std::vector<int64_t> decode_input_shape = {1, encoder_len, encoder_output_size_};
fastdeploy::FDTensor hyps_pad_tensor_;
hyps_pad_tensor_.SetExternalData(hyps_pad_shape, fastdeploy::FDDataType::INT64, hyps_pad.data());
fastdeploy::FDTensor hyps_lens_tensor_;
hyps_lens_tensor_.SetExternalData(hyps_lens_shape, fastdeploy::FDDataType::INT64, hyps_lens.data());
fastdeploy::FDTensor decode_input_tensor_;
decode_input_tensor_.SetExternalData(decode_input_shape, fastdeploy::FDDataType::FP32, rescore_input.data());
std::vector<fastdeploy::FDTensor> rescore_inputs(3);
rescore_inputs[0] = std::move(hyps_pad_tensor_);
rescore_inputs[0].name = rescore_in_names_[0];
rescore_inputs[1] = std::move(hyps_lens_tensor_);
rescore_inputs[1].name = rescore_in_names_[1];
rescore_inputs[2] = std::move(decode_input_tensor_);
rescore_inputs[2].name = rescore_in_names_[2];
std::vector<fastdeploy::FDTensor> rescore_outputs;
CHECK(rescore_->Infer(rescore_inputs, &rescore_outputs));
float* decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[0].Data());
float* r_decoder_outs_data = reinterpret_cast<float*>(rescore_outputs[1].Data());
int decode_out_len = rescore_outputs[0].shape[2];
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
float score = 0.0f;
// left to right decoder score
score = ComputeAttentionScore(
decoder_outs_data + max_hyps_len * decode_out_len * i, hyp, eos_,
decode_out_len);
// Optional: Used for right to left score
float r_score = 0.0f;
if (is_bidecoder_ && reverse_weight > 0) {
std::vector<int> r_hyp(hyp.size());
std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
// right to left decoder score
r_score = ComputeAttentionScore(
r_decoder_outs_data + max_hyps_len * decode_out_len * i, r_hyp, eos_,
decode_out_len);
}
// combined left-to-right and right-to-left score
(*rescoring_score)[i] =
score * (1 - reverse_weight) + r_score * reverse_weight;
}
}
void U2OnnxNnet::EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const {
}
}  // namespace ppspeech
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -11,87 +12,86 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.h
#pragma once
#include <numeric>
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
#include "paddle_inference_api.h"
#include "nnet/u2_nnet.h"
#include "fastdeploy/runtime.h"
namespace ppspeech {
class U2OnnxNnet : public U2NnetBase {
public:
explicit U2OnnxNnet(const ModelOptions& opts);
U2OnnxNnet(const U2OnnxNnet& other);
void FeedForward(const std::vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) override;
bool IsLogProb() override { return true; }
void Reset() override;
void LoadModel(const std::string& model_dir);
std::shared_ptr<NnetBase> Clone() const override;
void ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats,
const int32& feat_dim,
std::vector<kaldi::BaseFloat>* ctc_probs,
int32* vocab_dim) override;
float ComputeAttentionScore(const float* prob, const std::vector<int>& hyp,
int eos, int decode_out_len);
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) override;
void EncoderOuts(
std::vector<std::vector<kaldi::BaseFloat>>* encoder_out) const;
void GetInputOutputInfo(const std::shared_ptr<fastdeploy::Runtime>& runtime,
std::vector<std::string>* in_names,
std::vector<std::string>* out_names);
private:
int encoder_output_size_ = 0;
int num_blocks_ = 0;
int cnn_module_kernel_ = 0;
int head_ = 0;
ModelOptions opts_;
// sessions
std::shared_ptr<fastdeploy::Runtime> encoder_ = nullptr;
std::shared_ptr<fastdeploy::Runtime> rescore_ = nullptr;
std::shared_ptr<fastdeploy::Runtime> ctc_ = nullptr;
// node names
std::vector<std::string> encoder_in_names_, encoder_out_names_;
std::vector<std::string> ctc_in_names_, ctc_out_names_;
std::vector<std::string> rescore_in_names_, rescore_out_names_;
// caches
fastdeploy::FDTensor att_cache_ort_;
fastdeploy::FDTensor cnn_cache_ort_;
std::vector<fastdeploy::FDTensor> encoder_outs_;
std::vector<float> att_cache_;
std::vector<float> cnn_cache_;
};
} // namespace ppspeech
set(srcs)
list(APPEND srcs
recognizer_controller.cc
recognizer_controller_impl.cc
recognizer_instance.cc
recognizer.cc
)
add_library(recognizer STATIC ${srcs})
target_link_libraries(recognizer PUBLIC decoder)
set(TEST_BINS
recognizer_batch_main
recognizer_batch_main2
recognizer_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl)
endforeach()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -13,58 +13,34 @@
// limitations under the License.
#include "recognizer/recognizer.h"
#include "recognizer/recognizer_instance.h"
bool InitRecognizer(const std::string& model_file,
const std::string& word_symbol_table_file,
const std::string& fst_file,
int num_instance) {
return ppspeech::RecognizerInstance::GetInstance().Init(model_file,
word_symbol_table_file,
fst_file,
num_instance);
}
int GetRecognizerInstanceId() {
return ppspeech::RecognizerInstance::GetInstance().GetRecognizerInstanceId();
}
void InitDecoder(int instance_id) {
return ppspeech::RecognizerInstance::GetInstance().InitDecoder(instance_id);
}
void AcceptData(const std::vector<float>& waves, int instance_id) {
return ppspeech::RecognizerInstance::GetInstance().Accept(waves, instance_id);
}
void SetInputFinished(int instance_id) {
return ppspeech::RecognizerInstance::GetInstance().SetInputFinished(instance_id);
}
std::string GetFinalResult(int instance_id) {
return ppspeech::RecognizerInstance::GetInstance().GetResult(instance_id);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
bool InitRecognizer(const std::string& model_file,
const std::string& word_symbol_table_file,
const std::string& fst_file,
int num_instance);
int GetRecognizerInstanceId();
void InitDecoder(int instance_id);
void AcceptData(const std::vector<float>& waves, int instance_id);
void SetInputFinished(int instance_id);
std::string GetFinalResult(int instance_id);
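A hedged end-to-end sketch of this C-style API, mirroring recognizer_batch_main2.cc further below (the paths are placeholders and the chunking loop is elided to a hypothetical chunks container):

InitRecognizer("model_dir", "words.txt", "TLG.fst", /*num_instance=*/3);
int id = -1;
while (id == -1) id = GetRecognizerInstanceId();  // poll for a free worker
InitDecoder(id);
for (const auto& wav_chunk : chunks) AcceptData(wav_chunk, id);
SetInputFinished(id);
std::string text = GetFinalResult(id);  // also returns the worker to the pool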
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/recognizer_controller.h"
DEFINE_string(wav_rspecifier, "", "test wav list file (utt and wav path per line)");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(njob, 3, "njob");
using std::string;
using std::vector;
void SplitUtt(string wavlist_file,
vector<vector<string>>* uttlists,
vector<vector<string>>* wavlists,
int njob) {
vector<string> wavlist;
wavlists->resize(njob);
uttlists->resize(njob);
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
string utt_str = wavlist[idx];
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
LOG(INFO) << utt_wav[0];
CHECK_EQ(utt_wav.size(), size_t(2));
uttlists->at(idx % njob).push_back(utt_wav[0]);
wavlists->at(idx % njob).push_back(utt_wav[1]);
}
}
void recognizer_func(ppspeech::RecognizerController* recognizer_controller,
std::vector<string> wavlist,
std::vector<string> uttlist,
std::vector<string>* results) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
if (wavlist.empty()) return;
results->reserve(wavlist.size());
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
std::string utt = uttlist[idx];
std::string wav_file = wavlist[idx];
std::ifstream infile;
infile.open(wav_file, std::ifstream::in);
kaldi::WaveData wave_data;
wave_data.Read(infile);
int32 recog_id = -1;
while (recog_id == -1) {
recog_id = recognizer_controller->GetRecognizerInstanceId();
}
recognizer_controller->InitDecoder(recog_id);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
recognizer_controller->Accept(wav_chunk, recog_id);
// no overlap
sample_offset += cur_chunk_size;
}
recognizer_controller->SetInputFinished(recog_id);
CHECK(sample_offset == tot_samples);
std::string result = recognizer_controller->GetFinalResult(recog_id);
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
result = " ";
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
results->push_back(result);
++num_done;
}
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int njob = FLAGS_njob;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
ppspeech::RecognizerResource resource =
ppspeech::RecognizerResource::InitFromFlags();
ppspeech::RecognizerController recognizer_controller(njob, resource);
ThreadPool threadpool(njob);
vector<vector<string>> wavlist;
vector<vector<string>> uttlist;
vector<vector<string>> resultlist(njob);
vector<std::future<void>> futurelist;
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
for (size_t i = 0; i < njob; ++i) {
std::future<void> f = threadpool.enqueue(recognizer_func,
&recognizer_controller,
wavlist[i],
uttlist[i],
&resultlist[i]);
futurelist.push_back(std::move(f));
}
for (size_t i = 0; i < njob; ++i) {
futurelist[i].get();
}
for (size_t idx = 0; idx < njob; ++idx) {
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
string utt = uttlist[idx][utt_idx];
string result = resultlist[idx][utt_idx];
result_writer.Write(utt, result);
}
}
return 0;
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common/base/thread_pool.h"
#include "common/utils/file_utils.h"
#include "common/utils/strings.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "kaldi/util/table-types.h"
#include "nnet/u2_nnet.h"
#include "recognizer/recognizer.h"
DEFINE_string(wav_rspecifier, "", "test wav list file (utt and wav path per line)");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(njob, 3, "njob");
using std::string;
using std::vector;
void SplitUtt(string wavlist_file,
vector<vector<string>>* uttlists,
vector<vector<string>>* wavlists,
int njob) {
vector<string> wavlist;
wavlists->resize(njob);
uttlists->resize(njob);
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
string utt_str = wavlist[idx];
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
LOG(INFO) << utt_wav[0];
CHECK_EQ(utt_wav.size(), size_t(2));
uttlists->at(idx % njob).push_back(utt_wav[0]);
wavlists->at(idx % njob).push_back(utt_wav[1]);
}
}
void recognizer_func(std::vector<string> wavlist,
std::vector<string> uttlist,
std::vector<string>* results) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
if (wavlist.empty()) return;
results->reserve(wavlist.size());
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
std::string utt = uttlist[idx];
std::string wav_file = wavlist[idx];
std::ifstream infile;
infile.open(wav_file, std::ifstream::in);
kaldi::WaveData wave_data;
wave_data.Read(infile);
int32 recog_id = -1;
while (recog_id == -1) {
recog_id = GetRecognizerInstanceId();
}
InitDecoder(recog_id);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
AcceptData(wav_chunk, recog_id);
// no overlap
sample_offset += cur_chunk_size;
}
SetInputFinished(recog_id);
CHECK(sample_offset == tot_samples);
std::string result = GetFinalResult(recog_id);
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
LOG(INFO) << " the result of " << utt << " is empty";
result = " ";
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
results->push_back(result);
++num_done;
}
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int njob = FLAGS_njob;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
InitRecognizer(FLAGS_model_path, FLAGS_word_symbol_table, FLAGS_graph_path, njob);
ThreadPool threadpool(njob);
vector<vector<string>> wavlist;
vector<vector<string>> uttlist;
vector<vector<string>> resultlist(njob);
vector<std::future<void>> futurelist;
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
for (size_t i = 0; i < njob; ++i) {
std::future<void> f = threadpool.enqueue(recognizer_func,
wavlist[i],
uttlist[i],
&resultlist[i]);
futurelist.push_back(std::move(f));
}
for (size_t i = 0; i < njob; ++i) {
futurelist[i].get();
}
for (size_t idx = 0; idx < njob; ++idx) {
for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
string utt = uttlist[idx][utt_idx];
string result = resultlist[idx][utt_idx];
result_writer.Write(utt, result);
}
}
return 0;
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "recognizer/recognizer_controller.h"
#include "nnet/u2_nnet.h"
namespace ppspeech {
RecognizerController::RecognizerController(int num_worker, RecognizerResource resource) {
recognizer_workers.resize(num_worker);
for (size_t i = 0; i < num_worker; ++i) {
recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource));
waiting_workers.push(i);
}
}
int RecognizerController::GetRecognizerInstanceId() {
int idx = -1;
{
// check and pop under the same lock, otherwise two callers can race
std::unique_lock<std::mutex> lock(mutex_);
if (waiting_workers.empty()) {
return -1;
}
idx = waiting_workers.front();
waiting_workers.pop();
}
return idx;
}
RecognizerController::~RecognizerController() {
for (size_t i = 0; i < recognizer_workers.size(); ++i) {
recognizer_workers[i]->WaitFinished();
}
}
void RecognizerController::InitDecoder(int idx) {
recognizer_workers[idx]->InitDecoder();
}
std::string RecognizerController::GetFinalResult(int idx) {
recognizer_workers[idx]->WaitDecoderFinished();
recognizer_workers[idx]->AttentionRescoring();
std::string result = recognizer_workers[idx]->GetFinalResult();
{
std::unique_lock<std::mutex> lock(mutex_);
waiting_workers.push(idx);
}
return result;
}
void RecognizerController::Accept(std::vector<float> data, int idx) {
recognizer_workers[idx]->Accept(data);
}
void RecognizerController::SetInputFinished(int idx) {
recognizer_workers[idx]->SetInputFinished();
}
}  // namespace ppspeech
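GetFinalResult() is the only place a worker index goes back onto waiting_workers, so callers must treat checkout and GetFinalResult as a pair. A hypothetical polling helper matching how the test binaries use the controller:

int CheckoutWorker(ppspeech::RecognizerController* controller) {
    int id = -1;
    while (id == -1) {  // spin until some worker finishes its utterance
        id = controller->GetRecognizerInstanceId();
    }
    return id;
}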
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <vector>

#include "recognizer/recognizer_controller_impl.h"

namespace ppspeech {

class RecognizerController {
  public:
    explicit RecognizerController(int num_worker, RecognizerResource resource);
    ~RecognizerController();
    int GetRecognizerInstanceId();
    void InitDecoder(int idx);
    void Accept(std::vector<float> data, int idx);
    void SetInputFinished(int idx);
    std::string GetFinalResult(int idx);

  private:
    std::queue<int> waiting_workers;
    std::mutex mutex_;
    std::vector<std::unique_ptr<ppspeech::RecognizerControllerImpl>> recognizer_workers;

    DISALLOW_COPY_AND_ASSIGN(RecognizerController);
};

}  // namespace ppspeech
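Taken together, the controller is a fixed-size worker pool keyed by integer ids. The following is a minimal usage sketch from a hypothetical caller; LoadResource() and ReadChunk() are assumed helpers standing in for real resource and audio plumbing, not part of the tree.

// Hypothetical caller: drives one utterance through the worker pool.
#include <string>
#include <vector>

#include "recognizer/recognizer_controller.h"

ppspeech::RecognizerResource LoadResource();  // assumed helper, defined elsewhere
std::vector<float> ReadChunk();               // assumed helper, defined elsewhere

int main() {
    ppspeech::RecognizerResource resource = LoadResource();
    ppspeech::RecognizerController controller(4, resource);  // 4 workers

    int id = controller.GetRecognizerInstanceId();  // -1 means no idle worker
    if (id < 0) return 1;

    controller.InitDecoder(id);              // starts this worker's decoder thread
    std::vector<float> chunk = ReadChunk();  // repeat Accept() per audio chunk
    controller.Accept(chunk, id);
    controller.SetInputFinished(id);         // no more audio for this utterance

    // Joins the decoder thread, runs attention rescoring, frees the worker.
    std::string text = controller.GetFinalResult(id);
    return 0;
}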
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
-//               2022 Binbin Zhang (binbzha@qq.com)
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,86 +14,180 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "recognizer/u2_recognizer.h"
+#include "recognizer/recognizer_controller_impl.h"
 
 #include "nnet/u2_nnet.h"
+#include "decoder/ctc_prefix_beam_search_decoder.h"
+#include "common/utils/strings.h"
 
 namespace ppspeech {
 
 using kaldi::BaseFloat;
 using kaldi::SubVector;
 using kaldi::Vector;
 using kaldi::VectorBase;
 using std::unique_ptr;
 using std::vector;
 
-U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
-    : opts_(resource) {
-    const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
-    feature_pipeline_.reset(new FeaturePipeline(feature_opts));
-
-    std::shared_ptr<NnetBase> nnet(new U2Nnet(resource.model_opts));
-
-    BaseFloat am_scale = resource.acoustic_scale;
-    decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));
-
-    CHECK_NE(resource.vocab_path, "");
-    decoder_.reset(new CTCPrefixBeamSearch(
-        resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
-
-    unit_table_ = decoder_->VocabTable();
-    symbol_table_ = unit_table_;
-
-    input_finished_ = false;
-
-    Reset();
-}
+RecognizerControllerImpl::RecognizerControllerImpl(
+    const RecognizerResource& resource)
+    : opts_(resource) {
+    BaseFloat am_scale = resource.acoustic_scale;
+    BaseFloat blank_threshold = resource.blank_threshold;
+    const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
+    std::shared_ptr<FeaturePipeline> feature_pipeline(
+        new FeaturePipeline(feature_opts));
+    std::shared_ptr<NnetBase> nnet;
+#ifndef USE_ONNX
+    nnet = resource.nnet->Clone();
+#else
+    if (resource.model_opts.with_onnx_model) {
+        nnet.reset(new U2OnnxNnet(resource.model_opts));
+    } else {
+        nnet = resource.nnet->Clone();
+    }
+#endif
+    nnet_producer_.reset(
+        new NnetProducer(nnet, feature_pipeline, blank_threshold));
+    nnet_thread_ = std::thread(RunNnetEvaluation, this);
+    decodable_.reset(new Decodable(nnet_producer_, am_scale));
+
+    if (resource.decoder_opts.tlg_decoder_opts.fst_path.empty()) {
+        LOG(INFO) << "Init PrefixBeamSearch Decoder";
+        decoder_ = std::make_unique<CTCPrefixBeamSearch>(
+            resource.decoder_opts.ctc_prefix_search_opts);
+    } else {
+        LOG(INFO) << "Init TLGDecoder";
+        decoder_ = std::make_unique<TLGDecoder>(
+            resource.decoder_opts.tlg_decoder_opts);
+    }
+
+    symbol_table_ = decoder_->WordSymbolTable();
+    global_frame_offset_ = 0;
+    input_finished_ = false;
+    num_frames_ = 0;
+    result_.clear();
+}
 
-void U2Recognizer::Reset() {
-    global_frame_offset_ = 0;
-    num_frames_ = 0;
-    result_.clear();
-
-    decodable_->Reset();
-    decoder_->Reset();
-}
+RecognizerControllerImpl::~RecognizerControllerImpl() { WaitFinished(); }
+
+void RecognizerControllerImpl::Reset() { nnet_producer_->Reset(); }
 
-void U2Recognizer::ResetContinuousDecoding() {
-    global_frame_offset_ = num_frames_;
+void RecognizerControllerImpl::RunDecoder(RecognizerControllerImpl* me) {
+    me->RunDecoderInternal();
+}
+
+void RecognizerControllerImpl::RunDecoderInternal() {
+    LOG(INFO) << "DecoderInternal begin";
+    while (!nnet_producer_->IsFinished()) {
+        nnet_condition_.notify_one();
+        decoder_->AdvanceDecode(decodable_);
+    }
+    decoder_->AdvanceDecode(decodable_);
+    UpdateResult(false);
+    LOG(INFO) << "DecoderInternal exit";
+}
+
+void RecognizerControllerImpl::WaitDecoderFinished() {
+    if (decoder_thread_.joinable()) decoder_thread_.join();
+}
+
+void RecognizerControllerImpl::RunNnetEvaluation(RecognizerControllerImpl* me) {
+    me->RunNnetEvaluationInternal();
+}
+
+void RecognizerControllerImpl::SetInputFinished() {
+    nnet_producer_->SetInputFinished();
+    nnet_condition_.notify_one();
+    LOG(INFO) << "Set Input Finished";
+}
+
+void RecognizerControllerImpl::WaitFinished() {
+    abort_ = true;
+    LOG(INFO) << "nnet wait finished";
+    nnet_condition_.notify_one();
+    if (nnet_thread_.joinable()) {
+        nnet_thread_.join();
+    }
+}
+
+void RecognizerControllerImpl::RunNnetEvaluationInternal() {
+    bool result = false;
+    LOG(INFO) << "NnetEvaluationInteral begin";
+    while (!abort_) {
+        std::unique_lock<std::mutex> lock(nnet_mutex_);
+        nnet_condition_.wait(lock);
+        do {
+            result = nnet_producer_->Compute();
+            decoder_condition_.notify_one();
+        } while (result);
+    }
+    LOG(INFO) << "NnetEvaluationInteral exit";
+}
 
-void U2Recognizer::Accept(const VectorBase<BaseFloat>& waves) {
-    kaldi::Timer timer;
-    feature_pipeline_->Accept(waves);
-    VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. "
-            << waves.Dim() << " samples.";
-}
-
-void U2Recognizer::Decode() {
-    decoder_->AdvanceDecode(decodable_);
-    UpdateResult(false);
-}
-
-void U2Recognizer::Rescoring() {
-    // Do attention Rescoring
-    AttentionRescoring();
-}
+void RecognizerControllerImpl::Accept(std::vector<float> data) {
+    nnet_producer_->Accept(data);
+    nnet_condition_.notify_one();
+}
+
+void RecognizerControllerImpl::InitDecoder() {
+    global_frame_offset_ = 0;
+    input_finished_ = false;
+    num_frames_ = 0;
+    result_.clear();
+
+    decodable_->Reset();
+    decoder_->Reset();
+    decoder_thread_ = std::thread(RunDecoder, this);
+}
+
+void RecognizerControllerImpl::AttentionRescoring() {
+    decoder_->FinalizeSearch();
+    UpdateResult(false);
+
+    // No need to do rescoring
+    if (0.0 == opts_.decoder_opts.rescoring_weight) {
+        LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
+        return;
+    }
+    LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
+
+    // Inputs() returns N-best input ids, which is the basic unit for rescoring
+    // In CtcPrefixBeamSearch, inputs are the same to outputs
+    const auto& hypotheses = decoder_->Inputs();
+    int num_hyps = hypotheses.size();
+    if (num_hyps <= 0) {
+        return;
+    }
+
+    std::vector<float> rescoring_score;
+    decodable_->AttentionRescoring(
+        hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
+
+    // combine ctc score and rescoring score
+    for (size_t i = 0; i < num_hyps; i++) {
+        VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
+                << " ctc_score: " << result_[i].score
+                << " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
+                << " ctc_weight: " << opts_.decoder_opts.ctc_weight;
+        result_[i].score =
+            opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
+            opts_.decoder_opts.ctc_weight * result_[i].score;
+        VLOG(3) << "hyp: " << result_[0].sentence
+                << " score: " << result_[0].score;
+    }
+    std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
+    VLOG(3) << "result: " << result_[0].sentence
+            << " score: " << result_[0].score;
+}
+
+std::string RecognizerControllerImpl::GetFinalResult() { return result_[0].sentence; }
+
+std::string RecognizerControllerImpl::GetPartialResult() { return result_[0].sentence; }
 
-void U2Recognizer::UpdateResult(bool finish) {
+void RecognizerControllerImpl::UpdateResult(bool finish) {
     const auto& hypotheses = decoder_->Outputs();
     const auto& inputs = decoder_->Inputs();
     const auto& likelihood = decoder_->Likelihood();
     const auto& times = decoder_->Times();
     result_.clear();
 
     CHECK_EQ(hypotheses.size(), likelihood.size());
     CHECK_EQ(inputs.size(), likelihood.size());
 
     for (size_t i = 0; i < hypotheses.size(); i++) {
         const std::vector<int>& hypothesis = hypotheses[i];
@@ -99,21 +195,16 @@ void U2Recognizer::UpdateResult(bool finish) {
         path.score = likelihood[i];
         for (size_t j = 0; j < hypothesis.size(); j++) {
             std::string word = symbol_table_->Find(hypothesis[j]);
-            // A detailed explanation of this if-else branch can be found in
-            // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
-            if (decoder_->Type() == kWfstBeamSearch) {
-                path.sentence += (" " + word);
-            } else {
-                path.sentence += (word);
-            }
+            path.sentence += (" " + word);
         }
+        path.sentence = DelBlank(path.sentence);
 
         // TimeStamp is only supported in final result
         // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
         // various FST operations when building the decoding graph. So here we
         // use time stamp of the input(e2e model unit), which is more accurate,
         // and it requires the symbol table of the e2e model used in training.
-        if (unit_table_ != nullptr && finish) {
+        if (symbol_table_ != nullptr && finish) {
             int offset = global_frame_offset_ * FrameShiftInMs();
             const std::vector<int>& input = inputs[i];
@@ -121,7 +212,7 @@ void U2Recognizer::UpdateResult(bool finish) {
             CHECK_EQ(input.size(), time_stamp.size());
             for (size_t j = 0; j < input.size(); j++) {
-                std::string word = unit_table_->Find(input[j]);
+                std::string word = symbol_table_->Find(input[j]);
                 int start =
                     time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
@@ -163,56 +254,4 @@ void U2Recognizer::UpdateResult(bool finish) {
     }
 }
 
-void U2Recognizer::AttentionRescoring() {
-    decoder_->FinalizeSearch();
-    UpdateResult(true);
-
-    // No need to do rescoring
-    if (0.0 == opts_.decoder_opts.rescoring_weight) {
-        LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
-        return;
-    }
-    LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
-
-    // Inputs() returns N-best input ids, which is the basic unit for rescoring
-    // In CtcPrefixBeamSearch, inputs are the same to outputs
-    const auto& hypotheses = decoder_->Inputs();
-    int num_hyps = hypotheses.size();
-    if (num_hyps <= 0) {
-        return;
-    }
-
-    std::vector<float> rescoring_score;
-    decodable_->AttentionRescoring(
-        hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
-
-    // combine ctc score and rescoring score
-    for (size_t i = 0; i < num_hyps; i++) {
-        VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
-                << " ctc_score: " << result_[i].score
-                << " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
-                << " ctc_weight: " << opts_.decoder_opts.ctc_weight;
-        result_[i].score =
-            opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
-            opts_.decoder_opts.ctc_weight * result_[i].score;
-        VLOG(3) << "hyp: " << result_[0].sentence
-                << " score: " << result_[0].score;
-    }
-    std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
-    VLOG(3) << "result: " << result_[0].sentence
-            << " score: " << result_[0].score;
-}
-
-std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; }
-
-std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; }
-
-void U2Recognizer::SetFinished() {
-    feature_pipeline_->SetFinished();
-    input_finished_ = true;
-}
-
-}  // namespace ppspeech
\ No newline at end of file
+}  // namespace ppspeech
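In the new file, RunNnetEvaluationInternal() and RunDecoderInternal() hand work to each other through nnet_condition_ and decoder_condition_. The sketch below isolates that notify/wait handshake with hypothetical names; unlike the code above, it waits on a predicate, so a notify_one() issued before the wait cannot be lost.

// Minimal, self-contained sketch (hypothetical names) of the handshake:
// Accept()/SetInputFinished() notify the nnet thread, which drains the
// input queue, mirroring NnetProducer::Compute() in a loop.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

int main() {
    std::mutex m;
    std::condition_variable cv;
    std::queue<std::vector<float>> chunks;  // stands in for the producer's input
    bool input_finished = false;

    std::thread nnet([&] {
        for (;;) {
            std::vector<float> chunk;
            {
                std::unique_lock<std::mutex> lock(m);
                // Predicate guards against lost and spurious wakeups.
                cv.wait(lock, [&] { return !chunks.empty() || input_finished; });
                if (chunks.empty()) break;  // finished and fully drained
                chunk = std::move(chunks.front());
                chunks.pop();
            }
            std::cout << "compute on " << chunk.size() << " samples\n";
        }
    });

    for (int i = 0; i < 3; ++i) {  // plays the role of Accept()
        { std::lock_guard<std::mutex> lock(m); chunks.push(std::vector<float>(160)); }
        cv.notify_one();
    }
    { std::lock_guard<std::mutex> lock(m); input_finished = true; }  // SetInputFinished()
    cv.notify_one();
    nnet.join();
    return 0;
}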
@@ -10,4 +10,4 @@ target_link_libraries(websocket_server_main PUBLIC fst websocket ${DEPS})
 add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
 target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
-target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS})
\ No newline at end of file
+target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS})
# add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND")
add_definitions("-DUSE_ORT_BACKEND")
add_subdirectory(nnet)