PaddlePaddle / DeepSpeech
Commit 41feecbd
Authored on Mar 07, 2022 by Hui Zhang
Commit message: format
Parent: 11dc485d
Showing 32 changed files with 2112 additions and 1653 deletions (+2112, -1653)
Changed files (32):
.pre-commit-config.yaml  (+2, -2)
paddlespeech/s2t/io/sampler.py  (+1, -1)
paddlespeech/t2s/modules/transformer/repeat.py  (+1, -1)
speechx/examples/README.md  (+0, -1)
speechx/examples/decoder/offline-decoder-main.cc  (+55, -38)
speechx/examples/feat/feature-mfcc-test.cc  (+645, -611)
speechx/examples/feat/linear-spectrogram-main.cc  (+231, -99)
speechx/examples/nnet/pp-model-test.cc  (+78, -49)
speechx/speechx/base/basic_types.h  (+27, -27)
speechx/speechx/base/common.h  (+5, -5)
speechx/speechx/base/macros.h  (+2, -2)
speechx/speechx/base/thread_pool.h  (+41, -51)
speechx/speechx/decoder/common.h  (+17, -3)
speechx/speechx/decoder/ctc_beam_search_decoder.cc  (+231, -219)
speechx/speechx/decoder/ctc_beam_search_decoder.h  (+57, -40)
speechx/speechx/frontend/fbank.h  (+5, -4)
speechx/speechx/frontend/feature_extractor_controller.h  (+14, -0)
speechx/speechx/frontend/feature_extractor_controller_impl.h  (+14, -0)
speechx/speechx/frontend/feature_extractor_interface.h  (+2, -1)
speechx/speechx/frontend/linear_spectrogram.cc  (+104, -104)
speechx/speechx/frontend/linear_spectrogram.h  (+24, -11)
speechx/speechx/frontend/normalizer.cc  (+143, -125)
speechx/speechx/frontend/normalizer.h  (+33, -13)
speechx/speechx/frontend/window.h  (+0, -1)
speechx/speechx/nnet/decodable-itf.h  (+72, -43)
speechx/speechx/nnet/decodable.cc  (+34, -28)
speechx/speechx/nnet/decodable.h  (+23, -6)
speechx/speechx/nnet/nnet_interface.h  (+15, -2)
speechx/speechx/nnet/paddle_nnet.cc  (+123, -102)
speechx/speechx/nnet/paddle_nnet.h  (+83, -60)
speechx/speechx/utils/file_utils.cc  (+14, -1)
speechx/speechx/utils/file_utils.h  (+16, -3)
.pre-commit-config.yaml @ 41feecbd

@@ -50,13 +50,13 @@ repos:
         entry: bash .pre-commit-hooks/clang-format.hook -i
         language: system
         files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
-        exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
+        exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
     - id: copyright_checker
       name: copyright_checker
       entry: python .pre-commit-hooks/copyright-check.hook
       language: system
       files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
-      exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
+      exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
 - repo: https://github.com/asottile/reorder_python_imports
   rev: v2.4.0
   hooks:
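The change extends both `exclude` lookahead patterns so that vendored Kaldi sources and the new `speechx/patch` directory skip clang-format and the copyright checker. A quick sketch of how such a pattern behaves; modeling pre-commit's path matching with `re.match` is an assumption here, and the helper name `is_excluded` is hypothetical:

```python
import re

# The updated exclude pattern: a lookahead that requires the path to start
# with one of the excluded prefixes, followed by an excluded extension.
EXCLUDE = re.compile(r"(?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$")

def is_excluded(path):
    # Hypothetical helper: True if the hook would skip this file.
    return bool(EXCLUDE.match(path))

print(is_excluded("speechx/speechx/kaldi/base/kaldi-math.cc"))  # True
print(is_excluded("speechx/patch/openfst/fix.h"))               # True
print(is_excluded("speechx/speechx/base/common.h"))             # False
```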
paddlespeech/s2t/io/sampler.py @ 41feecbd

@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
     """
     rng = np.random.RandomState(epoch)
     shift_len = rng.randint(0, batch_size - 1)
     batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
     rng.shuffle(batch_indices)
     batch_indices = [item for batch in batch_indices for item in batch]
     assert clipped is False
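The `zip(*[iter(...)] * batch_size)` expression in `_batch_shuffle` is the standard n-at-a-time grouper: it drops a random prefix, groups the remainder into full batches (any trailing partial batch is discarded), shuffles whole batches, and flattens. A minimal self-contained sketch of the same logic, assuming the `clipped=False` path:

```python
import numpy as np

def batch_shuffle(indices, batch_size, epoch):
    # Simplified sketch of _batch_shuffle above: drop a random prefix,
    # group into full batches, shuffle the batches, then flatten.
    rng = np.random.RandomState(epoch)
    shift_len = rng.randint(0, batch_size - 1)
    # zip(*[iter(xs)] * n) yields n-tuples; the trailing partial batch is dropped.
    batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
    rng.shuffle(batches)
    return [item for batch in batches for item in batch]

out = batch_shuffle(list(range(10)), batch_size=3, epoch=0)
print(len(out) % 3 == 0)  # only whole batches survive the grouping
```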
paddlespeech/t2s/modules/transformer/repeat.py @ 41feecbd

@@ -36,4 +36,4 @@ def repeat(N, fn):
     Returns:
         MultiSequential: Repeated model instance.
     """
     return MultiSequential(*[fn(n) for n in range(N)])
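`repeat(N, fn)` builds a container of N layers by calling `fn` with each layer index. A minimal stand-in illustrating the idiom; `MultiSequential` here is modeled as a plain callable chain, not the real paddle container:

```python
class MultiSequential:
    # Stand-in: applies the wrapped layers in order.
    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def repeat(N, fn):
    # fn receives the layer index n, matching the signature in the diff above.
    return MultiSequential(*[fn(n) for n in range(N)])

model = repeat(3, lambda n: (lambda x: x + n))  # layers add 0, 1, 2
print(model(0))  # 3
```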
speechx/examples/README.md @ 41feecbd

@@ -3,4 +3,3 @@
 * decoder - offline decoder
 * feat - mfcc, linear
 * nnet - ds2 nn
speechx/examples/decoder/offline-decoder-main.cc @ 41feecbd

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// todo refactor, repalce with gtest

#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"

DEFINE_string(feature_respecifier, "", "test nnet prob");

@@ -13,7 +27,7 @@ using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;

// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
//                   int32 chunk_size,
//                   std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {

@@ -23,7 +37,8 @@ int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialBaseFloatMatrixReader feature_reader(
        FLAGS_feature_respecifier);

    // test nnet_output --> decoder result
    int32 num_done = 0, num_err = 0;

@@ -32,11 +47,13 @@ int main(int argc, char* argv[]) {
    ppspeech::CTCBeamSearch decoder(opts);

    ppspeech::ModelOptions model_opts;
    std::shared_ptr<ppspeech::PaddleNnet> nnet(
        new ppspeech::PaddleNnet(model_opts));
    std::shared_ptr<ppspeech::Decodable> decodable(
        new ppspeech::Decodable(nnet));

    // int32 chunk_size = 35;
    decoder.InitDecoder();

    for (; !feature_reader.Done(); feature_reader.Next()) {
        string utt = feature_reader.Key();
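The `main()` above wires a per-utterance feature reader into a beam-search decoder and tallies successes and failures. A hypothetical Python sketch of that loop; `StubDecoder` stands in for `ppspeech::CTCBeamSearch`, and the greedy per-frame argmax is only an illustration, not the real beam search:

```python
class StubDecoder:
    # Stand-in for the C++ decoder; real beam search omitted.
    def init_decoder(self):
        self.result = ""

    def decode(self, feature):
        # Illustration only: pick the most probable symbol per frame.
        return "".join(max(frame, key=frame.get) for frame in feature)

def decode_all(reader, decoder):
    # Mirrors main()'s loop: per utterance, (re)init, decode, tally results.
    num_done, num_err = 0, 0
    results = {}
    for utt, feature in reader:
        try:
            decoder.init_decoder()
            results[utt] = decoder.decode(feature)
            num_done += 1
        except Exception:
            num_err += 1
    return results, num_done, num_err

frames = [{"a": 0.9, "b": 0.1}, {"a": 0.2, "b": 0.8}]
results, done, err = decode_all([("utt1", frames)], StubDecoder())
print(results, done, err)  # {'utt1': 'ab'} 1 0
```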
speechx/examples/feat/feature-mfcc-test.cc @ 41feecbd

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// feat/feature-mfcc-test.cc
// Copyright 2009-2011  Karel Vesely;  Petr Motlicek

@@ -20,17 +34,15 @@
#include <iostream>

#include "base/kaldi-math.h"
#include "feat/feature-mfcc.h"
#include "feat/wave-reader.h"
#include "matrix/kaldi-matrix-inl.h"

using namespace kaldi;

static void UnitTestReadWave() {
    std::cout << "=== UnitTestReadWave() ===\n";
    Vector<BaseFloat> v, v2;

@@ -47,15 +59,15 @@ static void UnitTestReadWave() {
        v.CopyFromVec(data.Row(0));
    }

    std::cout << "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
    std::ifstream input("test_data/test_matlab.ascii");
    KALDI_ASSERT(input.good());
    v2.Read(input, false);
    input.close();

    std::cout << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
    KALDI_ASSERT(v.Dim() == v2.Dim());
    for (int32 i = 0; i < v.Dim(); i++) {
        KALDI_ASSERT(v(i) == v2(i));

@@ -66,11 +78,9 @@ static void UnitTestReadWave() {
    // std::cout << v;
    std::cout << "Test passed :)\n\n";
}

/**
 */
static void UnitTestSimple() {

@@ -81,7 +91,7 @@ static void UnitTestSimple() {
    // init with noise
    for (int32 i = 0; i < v.Dim(); i++) {
        v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
    }

    std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
@@ -147,9 +157,7 @@ static void UnitTestHTKCompare1() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;

@@ -158,32 +166,36 @@ static void UnitTestHTKCompare1() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (i_old != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.1",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
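Each `UnitTestHTKCompare` variant runs the same core loop: skip ten rows at each end (delta handling differs from HTK there), compare cells with a deliberately loose tolerance of 1.0, and report mismatching cells. The logic can be sketched as:

```python
def compare_features(kaldi, htk, tol=1.0, edge=10):
    # Python sketch of the comparison loop above: ignore `edge` rows at each
    # end, flag any cell whose absolute difference exceeds the tolerance.
    assert len(kaldi) == len(htk)
    mismatches = []
    for i in range(edge, len(kaldi) - edge):
        for j in range(len(kaldi[i])):
            if abs(htk[i][j] - kaldi[i][j]) > tol:
                mismatches.append((i, j))
    return mismatches

a = [[0.0] * 2 for _ in range(25)]
b = [row[:] for row in a]
b[12][1] = 2.5   # differs by more than tol inside the compared band
b[0][0] = 99.0   # differs only in the ignored edge rows
print(compare_features(a, b))  # [(12, 1)]
```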
@@ -231,9 +243,7 @@ static void UnitTestHTKCompare2() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;

@@ -242,32 +252,36 @@ static void UnitTestHTKCompare2() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (i_old != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.2",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
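The `HtkHeader` literal written at the end of each test packs four fields: frame count, a frame period of 100000 × 100 ns (10 ms), bytes per frame (`sizeof(float) * NumCols()`), and the octal parameter kind 021406 (MFCC_D_A_0). A sketch of the equivalent 12-byte big-endian header; the struct layout follows the HTK feature-file format, and the field names here are descriptive, not from the source:

```python
import struct

def htk_header(num_rows, num_cols, frame_period_100ns=100000, parm_kind=0o21406):
    # 12-byte big-endian HTK header: nSamples (int32), sampPeriod in 100 ns
    # units (int32), sampSize in bytes (int16), parmKind (int16).
    # 0o21406 mirrors the octal 021406 (MFCC_D_A_0) in the C++ source.
    samp_size = 4 * num_cols  # sizeof(float) * NumCols()
    return struct.pack(">iihh", num_rows, frame_period_100ns, samp_size, parm_kind)

hdr = htk_header(100, 39)
print(len(hdr))  # 12
```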
@@ -305,7 +319,7 @@ static void UnitTestHTKCompare3() {
    op.htk_compat = true;
    op.use_energy = true;  // Use energy.
    op.mel_opts.low_freq = 20.0;
    // op.mel_opts.debug_mel = true;
    op.mel_opts.htk_mode = true;
    Mfcc mfcc(op);

@@ -316,9 +330,7 @@ static void UnitTestHTKCompare3() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;

@@ -327,32 +339,36 @@ static void UnitTestHTKCompare3() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (static_cast<int32>(i_old) != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.3",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
@@ -399,9 +415,7 @@ static void UnitTestHTKCompare4() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;

@@ -410,32 +424,36 @@ static void UnitTestHTKCompare4() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (static_cast<int32>(i_old) != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.4",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
@@ -476,7 +494,8 @@ static void UnitTestHTKCompare5() {
    op.mel_opts.vtln_high = 7500.0;
    op.mel_opts.htk_mode = true;
    BaseFloat vtln_warp = 1.1;  // our approach identical to htk for warp factor >1,
                                // differs slightly for higher mel bins if warp_factor <0.9
    Mfcc mfcc(op);

@@ -487,9 +506,7 @@ static void UnitTestHTKCompare5() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;

@@ -498,32 +515,36 @@ static void UnitTestHTKCompare5() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (static_cast<int32>(i_old) != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.5",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
@@ -572,9 +593,7 @@ static void UnitTestHTKCompare6() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;

@@ -583,32 +602,36 @@ static void UnitTestHTKCompare6() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (static_cast<int32>(i_old) != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.6",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
...
@@ -619,36 +642,51 @@ static void UnitTestHTKCompare6() {
void UnitTestVtln() {
    // Test the function VtlnWarpFreq.
    BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
              vtln_high_cutoff = 7400;

    for (size_t i = 0; i < 100; i++) {
        BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
        AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
                                           vtln_high_cutoff,
                                           low_freq,
                                           high_freq,
                                           warp_factor,
                                           freq),
                    freq / warp_factor);

        AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
                                           vtln_high_cutoff,
                                           low_freq,
                                           high_freq,
                                           warp_factor,
                                           low_freq),
                    low_freq);
        AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
                                           vtln_high_cutoff,
                                           low_freq,
                                           high_freq,
                                           warp_factor,
                                           high_freq),
                    high_freq);
        BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
                  freq3 = freq2 +
                          (high_freq - freq2) * RandUniform();  // freq3>=freq2
        BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
                                              vtln_high_cutoff,
                                              low_freq,
                                              high_freq,
                                              warp_factor,
                                              freq2);
        BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
                                              vtln_high_cutoff,
                                              low_freq,
                                              high_freq,
                                              warp_factor,
                                              freq3);
        KALDI_ASSERT(w3 >= w2);  // increasing function.
        BaseFloat w3dash = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
                                                  vtln_high_cutoff,
                                                  low_freq,
                                                  high_freq,
                                                  1.0,
                                                  freq3);
        AssertEqual(w3dash, freq3);
    }
}
...
@@ -670,11 +708,9 @@ static void UnitTestFeat() {
}

int main() {
    try {
        for (int i = 0; i < 5; i++) UnitTestFeat();
        std::cout << "Tests succeeded.\n";
        return 0;
    } catch (const std::exception& e) {
...
@@ -682,5 +718,3 @@ int main() {
        return 1;
    }
}
speechx/examples/feat/linear-spectrogram-main.cc  View file @ 41feecbd
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, replace with gtest

#include "base/flags.h"
#include "base/log.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
...
@@ -15,12 +29,120 @@ DEFINE_string(feature_check_wspecifier, "", "test wav ark");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark");
std::vector<float> mean_{
    -13730251.531853663, -12982852.199316509, -13673844.299583456,
    -13089406.559646806, -12673095.524938712, -12823859.223276224,
    -13590267.158903603, -14257618.467152044, -14374605.116185192,
    -14490009.21822485,  -14849827.158924166, -15354435.470563512,
    -15834149.206532761, -16172971.985514281, -16348740.496746974,
    -16423536.699409386, -16556246.263649225, -16744088.772748645,
    -16916184.08510357,  -17054034.840031497, -17165612.509455364,
    -17255955.470915023, -17322572.527648456, -17408943.862033736,
    -17521554.799865916, -17620623.254924215, -17699792.395918526,
    -17723364.411134344, -17741483.4433254,   -17747426.888704527,
    -17733315.928209435, -17748780.160905756, -17808336.883775543,
    -17895918.671983004, -18009812.59173023,  -18098188.66548325,
    -18195798.958462656, -18293617.62980999,  -18397432.92077201,
    -18505834.787318766, -18585451.8100908,   -18652438.235649142,
    -18700960.306275308, -18734944.58792185,  -18737426.313365128,
    -18735347.165987637, -18738813.444170244, -18737086.848890636,
    -18731576.2474336,   -18717405.44095871,  -18703089.25545657,
    -18691014.546456724, -18692460.568905357, -18702119.628629155,
    -18727710.621126678, -18761582.72034647,  -18806745.835547544,
    -18850674.8692112,   -18884431.510951452, -18919999.992506847,
    -18939303.799078144, -18952946.273760635, -18980289.22996379,
    -19011610.17803294,  -19040948.61805145,  -19061021.429847397,
    -19112055.53768819,  -19149667.414264943, -19201127.05091321,
    -19270250.82564605,  -19334606.883057203, -19390513.336589377,
    -19444176.259208687, -19502755.000038862, -19544333.014549147,
    -19612668.183176614, -19681902.19006569,  -19771969.951249883,
    -19873329.723376893, -19996752.59235844,  -20110031.131400537,
    -20231658.612529557, -20319378.894054495, -20378534.45718066,
    -20413332.089584175, -20438147.844177883, -20443710.248040095,
    -20465457.02238927,  -20488610.969337028, -20516295.16424432,
    -20541423.795738827, -20553192.874953747, -20573605.50701977,
    -20577871.61936797,  -20571807.008916274, -20556242.38912231,
    -20542199.30819195,  -20521239.063551214, -20519150.80004532,
    -20527204.80248933,  -20536933.769257784, -20543470.522332076,
    -20549700.089992985, -20551525.24958494,  -20554873.406493705,
    -20564277.65794227,  -20572211.740052115, -20574305.69550465,
    -20575494.450104576, -20567092.577932164, -20549302.929608088,
    -20545445.11878376,  -20546625.326603737, -20549190.03499401,
    -20554824.947828256, -20568341.378989458, -20577582.331383612,
    -20577980.519402675, -20566603.03458152,  -20560131.592262644,
    -20552166.469060015, -20549063.06763577,  -20544490.562339947,
    -20539817.82346569,  -20528747.715731595, -20518026.24576161,
    -20510977.844974525, -20506874.36087992,  -20506731.11977665,
    -20510482.133420516, -20507760.92101862,  -20494644.834457114,
    -20480107.89304893,  -20461312.091867123, -20442941.75080173,
    -20426123.02834838,  -20424607.675283,    -20426810.369107097,
    -20434024.50097819,  -20437404.75544205,  -20447688.63916367,
    -20460893.335563846, -20482922.735127095, -20503610.119434915,
    -20527062.76448319,  -20557830.035128627, -20593274.72068722,
    -20632528.452965066, -20673637.471334763, -20733106.97143075,
    -20842921.0447562,   -21054357.83621519,  -21416569.534189366,
    -21978460.272811692, -22753170.052172784, -23671344.10563395,
    -24613499.293358143, -25406477.12230188,  -25884377.82156489,
    -26049040.62791664,  -26996879.104431007};
std::vector<float> variance_{
    213747175.10846674, 188395815.34302503, 212706429.10966414,
    199109025.81461075, 189235901.23864496, 194901336.53253657,
    217481594.29306737, 238689869.12327808, 243977501.24115244,
    248479623.6431067,  259766741.47116545, 275516766.7790273,
    291271202.3691234,  302693239.8220509,  308627358.3997694,
    311143911.38788426, 315446105.07731867, 321705430.9341829,
    327458907.4659941,  332245072.43223983, 336251717.5935284,
    339694069.7639722,  342188204.4322228,  345587110.31313115,
    349903086.2875232,  353660214.20643026, 356700344.5270885,
    357665362.3529641,  358493352.05658793, 358857951.620328,
    358375239.52774596, 358899733.6342954,  361051818.3511561,
    364361716.05025816, 368750322.3771452,  372047800.6462831,
    375655861.1349018,  379358519.1980013,  383327605.3935181,
    387458599.282341,   390434692.3406868,  392994486.35057056,
    394874418.04603153, 396230525.79763395, 396365592.0414835,
    396334819.8242737,  396488353.19250053, 396438877.00744957,
    396197980.4459586,  395590921.6672991,  395001107.62072515,
    394528291.7318225,  394593110.424006,   395018405.59353715,
    396110577.5415993,  397506704.0371068,  399400197.4657644,
    401243568.2468382,  402687134.7805103,  404136047.2872507,
    404883170.001883,   405522253.219517,   406660365.3626476,
    407919346.0991902,  409045348.5384909,  409759588.7889818,
    411974821.8564483,  413489718.78201455, 415535392.56684107,
    418466481.97674364, 421104678.35678065, 423405392.5200779,
    425550570.40798235, 427929423.9579701,  429585274.253478,
    432368493.55181056, 435193587.13513297, 438886855.20476013,
    443058876.8633751,  448181232.5093362,  452883835.6332396,
    458056721.77926534, 461816531.22735566, 464363620.1970998,
    465886343.5057493,  466928872.0651,     467180536.42647296,
    468111848.70714295, 469138695.3071312,  470378429.6930793,
    471517958.7132626,  472109050.4262365,  473087417.0177867,
    473381322.04648733, 473220195.85483915, 472666071.8998819,
    472124669.87879956, 471298571.411737,   471251033.2902761,
    471672676.43128747, 472177147.2193172,  472572361.7711908,
    472968783.7751127,  473156295.4164052,  473398034.82676554,
    473897703.5203811,  474328271.33112127, 474452670.98002136,
    474549003.99284613, 474252887.13567275, 473557462.909069,
    473483385.85193115, 473609738.04855174, 473746944.82085115,
    474016729.91696435, 474617321.94138587, 475045097.237122,
    475125402.586558,   474664112.9824912,  474426247.5800283,
    474104075.42796475, 473978219.7273978,  473773171.7798875,
    473578534.69508696, 473102924.16904145, 472651240.5232615,
    472374383.1810912,  472209479.6956096,  472202298.8921673,
    472370090.76781124, 472220933.99374026, 471625467.37106377,
    470994646.51883453, 470182428.9637543,  469348211.5939578,
    468570387.4467277,  468540442.7225135,  468672018.90414184,
    468994346.9533251,  469138757.58201426, 469553915.95710236,
    470134523.38582784, 471082421.62055486, 471962316.51804745,
    472939745.1708408,  474250621.5944825,  475773933.43199486,
    477465399.71087736, 479218782.61382693, 481752299.7930922,
    486608947.8984568,  496119403.2067917,  512730085.5704984,
    539048915.2641417,  576285298.3548826,  621610270.2240586,
    669308196.4436442,  710656993.5957186,  736344437.3725077,
    745481288.0241544,  801121432.9925804};
int count_ = 912592;

void WriteMatrix() {
    kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
    for (size_t idx = 0; idx < mean_.size(); ++idx) {
        cmvn_stats(0, idx) = mean_[idx];
        cmvn_stats(1, idx) = variance_[idx];
...
@@ -33,12 +155,15 @@ int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
        FLAGS_wav_rspecifier);
    kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
    kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(
        FLAGS_feature_check_wspecifier);
    WriteMatrix();
    // test feature linear_spectrogram: wave --> decibel_normalizer --> hanning
    // window --> linear_spectrogram --> cmvn
    int32 num_done = 0, num_err = 0;
    ppspeech::LinearSpectrogramOptions opt;
    opt.frame_opts.frame_length_ms = 20;
...
@@ -46,7 +171,8 @@ int main(int argc, char* argv[]) {
    ppspeech::DecibelNormalizerOptions db_norm_opt;
    std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
        new ppspeech::DecibelNormalizer(db_norm_opt));
    ppspeech::LinearSpectrogram linear_spectrogram(
        opt, std::move(base_feature_extractor));

    ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
...
@@ -66,16 +192,18 @@ int main(int argc, char* argv[]) {
    for (; !wav_reader.Done(); wav_reader.Next()) {
        std::string utt = wav_reader.Key();
        const kaldi::WaveData& wave_data = wav_reader.Value();

        int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        int tot_samples = waveform.Dim();

        int sample_offset = 0;
        std::vector<kaldi::Matrix<BaseFloat>> feats;
        int feature_rows = 0;
        while (sample_offset < tot_samples) {
            int cur_chunk_size =
                std::min(chunk_sample_size, tot_samples - sample_offset);

            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
            for (int i = 0; i < cur_chunk_size; ++i) {
                wav_chunk(i) = waveform(sample_offset + i);
...
@@ -90,11 +218,14 @@ int main(int argc, char* argv[]) {
        }

        int cur_idx = 0;
        kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
                                                 feats[0].NumCols());
        for (auto feat : feats) {
            for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
                for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
                    features(cur_idx, col_idx) =
                        (feat(row_idx, col_idx) - mean_[col_idx]) *
                        variance_[col_idx];
                }
                ++cur_idx;
            }
...
@@ -102,7 +233,8 @@ int main(int argc, char* argv[]) {
        feat_writer.Write(utt, features);

        cur_idx = 0;
        kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows,
                                                       feats[0].NumCols());
        for (auto feat : feats) {
            for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
                for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
...
speechx/examples/nnet/pp-model-test.cc  View file @ 41feecbd
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>

#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <thread>

#include "paddle_inference_api.h"
using std::cout;
using std::endl;
...
@@ -39,7 +53,8 @@ void model_forward_test() {
    std::vector<std::vector<float>> feats;
    produce_data(&feats);

    std::cout << "2. load the model" << std::endl;
    std::string model_graph = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    cout << "model path: " << model_graph << endl;
...
@@ -53,9 +68,10 @@ void model_forward_test() {
    cout << "DisableFCPadding: " << endl;
    auto predictor = paddle_infer::CreatePredictor(config);

    std::cout << "3. feat shape, row=" << feats.size()
              << ",col=" << feats[0].size() << std::endl;
    std::vector<float> pp_input_mat;
    for (const auto& item : feats) {
        pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
    }
...
@@ -64,10 +80,10 @@ void model_forward_test() {
    int col = feats[0].size();
    std::vector<std::string> input_names = predictor->GetInputNames();
    std::vector<std::string> output_names = predictor->GetOutputNames();
    for (auto name : input_names) {
        cout << "model input names: " << name << endl;
    }
    for (auto name : output_names) {
        cout << "model output names: " << name << endl;
    }
...
@@ -79,7 +95,8 @@ void model_forward_test() {
    input_tensor->CopyFromCpu(pp_input_mat.data());

    // input length
    std::unique_ptr<paddle_infer::Tensor> input_len =
        predictor->GetInputHandle(input_names[1]);
    std::vector<int> input_len_size = {1};
    input_len->Reshape(input_len_size);
    std::vector<int64_t> audio_len;
...
@@ -87,20 +104,28 @@ void model_forward_test() {
    input_len->CopyFromCpu(audio_len.data());

    // state_h
    std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
        predictor->GetInputHandle(input_names[2]);
    std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
    chunk_state_h_box->Reshape(chunk_state_h_box_shape);
    int chunk_state_h_box_size =
        std::accumulate(chunk_state_h_box_shape.begin(),
                        chunk_state_h_box_shape.end(),
                        1,
                        std::multiplies<int>());
    std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
    chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());

    // state_c
    std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
        predictor->GetInputHandle(input_names[3]);
    std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
    chunk_state_c_box->Reshape(chunk_state_c_box_shape);
    int chunk_state_c_box_size =
        std::accumulate(chunk_state_c_box_shape.begin(),
                        chunk_state_c_box_shape.end(),
                        1,
                        std::multiplies<int>());
    std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
    chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
...
@@ -108,18 +133,20 @@ void model_forward_test() {
    bool success = predictor->Run();

    // state_h out
    std::unique_ptr<paddle_infer::Tensor> h_out =
        predictor->GetOutputHandle(output_names[2]);
    std::vector<int> h_out_shape = h_out->shape();
    int h_out_size = std::accumulate(
        h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
    std::vector<float> h_out_data(h_out_size);
    h_out->CopyToCpu(h_out_data.data());

    // stage_c out
    std::unique_ptr<paddle_infer::Tensor> c_out =
        predictor->GetOutputHandle(output_names[3]);
    std::vector<int> c_out_shape = c_out->shape();
    int c_out_size = std::accumulate(
        c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
    std::vector<float> c_out_data(c_out_size);
    c_out->CopyToCpu(c_out_data.data());
...
@@ -128,8 +155,8 @@ void model_forward_test() {
        predictor->GetOutputHandle(output_names[0]);
    std::vector<int> output_shape = output_tensor->shape();
    std::vector<float> output_probs;
    int output_size = std::accumulate(
        output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
    output_probs.resize(output_size);
    output_tensor->CopyToCpu(output_probs.data());
    row = output_shape[1];
...
@@ -148,9 +175,11 @@ void model_forward_test() {
    }

    std::vector<std::vector<float>> log_feat = probs;
    std::cout << "probs, row: " << log_feat.size()
              << " col: " << log_feat[0].size() << std::endl;
    for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
        for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
             ++col_idx) {
            std::cout << log_feat[row_idx][col_idx] << " ";
        }
        std::cout << std::endl;
    }
...

speechx/speechx/base/basic_types.h

...
@@ -43,18 +43,18 @@ typedef unsigned long long uint64;
typedef signed int char32;

const uint8 kuint8max = ((uint8)0xFF);
const uint16 kuint16max = ((uint16)0xFFFF);
const uint32 kuint32max = ((uint32)0xFFFFFFFF);
const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = ((int8)0x80);
const int8 kint8max = ((int8)0x7F);
const int16 kint16min = ((int16)0x8000);
const int16 kint16max = ((int16)0x7FFF);
const int32 kint32min = ((int32)0x80000000);
const int32 kint32max = ((int32)0x7FFFFFFF);
const int64 kint64min = ((int64)(0x8000000000000000LL));
const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));

const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
speechx/speechx/base/common.h

...
@@ -15,22 +15,22 @@

#pragma once

#include <deque>
#include <fstream>
#include <iostream>
#include <istream>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "base/basic_types.h"
#include "base/flags.h"
#include "base/log.h"
#include "base/macros.h"
speechx/speechx/base/macros.h

speechx/speechx/base/thread_pool.h
...
@@ -23,28 +23,29 @@
#ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H

#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>

class ThreadPool {
  public:
    ThreadPool(size_t);
    template <class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type>;
    ~ThreadPool();

  private:
    // need to keep track of threads so we can join them
    std::vector<std::thread> workers;
    // the task queue
    std::queue<std::function<void()>> tasks;

    // synchronization
    std::mutex queue_mutex;
...
@@ -53,68 +54,57 @@ private:
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
    for (size_t i = 0; i < threads; ++i) {
        workers.emplace_back([this] {
            for (;;) {
                std::function<void()> task;
                {
                    std::unique_lock<std::mutex> lock(this->queue_mutex);
                    this->condition.wait(lock, [this] {
                        return this->stop || !this->tasks.empty();
                    });
                    if (this->stop && this->tasks.empty()) return;
                    task = std::move(this->tasks.front());
                    this->tasks.pop();
                }
                task();
            }
        });
    }
}
// add new work item to the pool
template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type> {
    using return_type = typename std::result_of<F(Args...)>::type;

    auto task = std::make_shared<std::packaged_task<return_type()>>(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...));

    std::future<return_type> res = task->get_future();
    {
        std::unique_lock<std::mutex> lock(queue_mutex);

        // don't allow enqueueing after stopping the pool
        if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");

        tasks.emplace([task]() { (*task)(); });
    }
    condition.notify_one();
    return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool() {
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread& worker : workers) worker.join();
}

#endif
speechx/speechx/decoder/common.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/basic_types.h"

struct DecoderResult {
...
speechx/speechx/decoder/ctc_beam_search_decoder.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h"

#include "base/basic_types.h"

...
@@ -9,25 +23,23 @@ namespace ppspeech {
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;

CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
    : opts_(opts),
      init_ext_scorer_(nullptr),
      blank_id(-1),
      space_id(-1),
      num_frame_decoded_(0),
      root(nullptr) {
    LOG(INFO) << "dict path: " << opts_.dict_file;
    if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
        LOG(INFO) << "load the dict failed";
    }
    LOG(INFO) << "read the vocabulary success, dict size: "
              << vocabulary_.size();

    LOG(INFO) << "language model path: " << opts_.lm_path;
    init_ext_scorer_ = std::make_shared<Scorer>(
        opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
}

void CTCBeamSearch::Reset() {
...
@@ -36,7 +48,6 @@ void CTCBeamSearch::Reset() {
}
void CTCBeamSearch::InitDecoder() {
    blank_id = 0;
    auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
...
@@ -51,10 +62,11 @@ void CTCBeamSearch::InitDecoder() {
    root = std::make_shared<PathTrie>();
    root->score = root->log_prob_b_prev = 0.0;
    prefixes.push_back(root.get());
    if (init_ext_scorer_ != nullptr &&
        !init_ext_scorer_->is_character_based()) {
        auto fst_dict =
            static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
        fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
        root->set_dictionary(dict_ptr);
        auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
...
@@ -62,23 +74,24 @@ void CTCBeamSearch::InitDecoder() {
    }
}

void CTCBeamSearch::Decode(
    std::shared_ptr<kaldi::DecodableInterface> decodable) {
    return;
}

int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }

// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(
    const std::shared_ptr<kaldi::DecodableInterface>& decodable,
    int max_frames) {
    while (max_frames > 0) {
        vector<vector<BaseFloat>> likelihood;
        if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
            break;
        }
        likelihood.push_back(
            decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
        AdvanceDecoding(likelihood);
        max_frames--;
    }
...
@@ -93,12 +106,13 @@ void CTCBeamSearch::ResetPrefixes() {
    }
}
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
                                     vector<string>& nbest_words) {
    kaldi::Timer timer;
    timer.Reset();
    AdvanceDecoding(probs);
    LOG(INFO) << "ctc decoding elapsed time(s) "
              << static_cast<float>(timer.Elapsed()) / 1000.0f;
    return 0;
}
...
@@ -124,12 +138,13 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
    double cutoff_prob = opts_.cutoff_prob;
    size_t cutoff_top_n = opts_.cutoff_top_n;
    vector<vector<double>> probs_seq(probs.size(),
                                     vector<double>(probs[0].size(), 0));
    int row = probs.size();
    int col = probs[0].size();
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < col; j++) {
            probs_seq[i][j] = static_cast<double>(probs[i][j]);
        }
    }
...
@@ -141,7 +156,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
    bool full_beam = false;
    if (init_ext_scorer_ != nullptr) {
        size_t num_prefixes = std::min(prefixes.size(), beam_size);
        std::sort(prefixes.begin(),
                  prefixes.begin() + num_prefixes,
                  prefix_compare);
        if (num_prefixes == 0) {
...
@@ -181,7 +197,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
    }  // for probs_seq
}

int32 CTCBeamSearch::SearchOneChar(
    const bool& full_beam,
    const std::pair<size_t, BaseFloat>& log_prob_idx,
    const BaseFloat& min_cutoff) {
    size_t beam_size = opts_.beam_size;
...
@@ -196,10 +213,8 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
    }
    if (c == blank_id) {
        prefix->log_prob_b_cur =
            log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
        continue;
    }
...
@@ -207,9 +222,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
    if (c == prefix->character) {
        // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
        prefix->log_prob_nb_cur = log_sum_exp(
            prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
    }

    // get new prefix
...
@@ -228,7 +241,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
    // language model scoring
    if (init_ext_scorer_ != nullptr &&
        (c == space_id || init_ext_scorer_->is_character_based())) {
        PathTrie* prefix_to_score = nullptr;
        // skip scoring the space
        if (init_ext_scorer_->is_character_based()) {
            prefix_to_score = prefix_new;
...
@@ -247,8 +260,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
            }
            // p_{nb}(l;x_{1:t})
            prefix_new->log_prob_nb_cur =
                log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
        }
    }  // end of loop over prefix

    return 0;
...
@@ -258,9 +270,7 @@ void CTCBeamSearch::CalculateApproxScore() {
    size_t beam_size = opts_.beam_size;
    size_t num_prefixes = std::min(prefixes.size(), beam_size);
    std::sort(
        prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);

    // compute approximate ctc score as the return score, without affecting
    // the return order of decoding result. To delete when decoder gets
    // stable.
...
@@ -274,8 +284,8 @@ void CTCBeamSearch::CalculateApproxScore() {
            // remove word insert
            approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
            // remove language model weight:
            approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
                          init_ext_scorer_->alpha;
        }
        prefixes[i]->approx_ctc = approx_ctc;
    }
...
@@ -283,13 +293,15 @@ void CTCBeamSearch::CalculateApproxScore() {
void CTCBeamSearch::LMRescore() {
    size_t beam_size = opts_.beam_size;
    if (init_ext_scorer_ != nullptr &&
        !init_ext_scorer_->is_character_based()) {
        for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
            auto prefix = prefixes[i];
            if (!prefix->is_empty() && prefix->character != space_id) {
                float score = 0.0;
                vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
                score = init_ext_scorer_->get_log_cond_prob(ngram) *
                        init_ext_scorer_->alpha;
                score += init_ext_scorer_->beta;
                prefix->score += score;
            }
speechx/speechx/decoder/ctc_beam_search_decoder.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h"
#include "util/parse-options.h"
...
@@ -17,26 +31,27 @@ struct CTCBeamSearchOptions {
    int beam_size;
    int cutoff_top_n;
    int num_proc_bsearch;

    CTCBeamSearchOptions()
        : dict_file("./model/words.txt"),
          lm_path("./model/lm.arpa"),
          alpha(1.9f),
          beta(5.0),
          beam_size(300),
          cutoff_prob(0.99f),
          cutoff_top_n(40),
          num_proc_bsearch(0) {}

    void Register(kaldi::OptionsItf* opts) {
        opts->Register("dict", &dict_file, "dict file ");
        opts->Register("lm-path", &lm_path, "language model file");
        opts->Register("alpha", &alpha, "alpha");
        opts->Register("beta", &beta, "beta");
        opts->Register(
            "beam-size", &beam_size, "beam size for beam search method");
        opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
        opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
        opts->Register(
            "num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
    }
};
...
@@ -50,11 +65,13 @@ class CTCBeamSearch {
    std::vector<std::pair<double, std::string>> GetNBestPath();
    std::string GetFinalBestPath();
    int NumFrameDecoded();
    int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
                          std::vector<std::string>& nbest_words);
    void AdvanceDecode(
        const std::shared_ptr<kaldi::DecodableInterface>& decodable,
        int max_frames);
    void Reset();

  private:
    void ResetPrefixes();
    int32 SearchOneChar(const bool& full_beam,
...
@@ -66,7 +83,7 @@ class CTCBeamSearch {
    CTCBeamSearchOptions opts_;
    std::shared_ptr<Scorer> init_ext_scorer_;  // todo separate later
    // std::vector<DecodeResult> decoder_results_;
    std::vector<std::string> vocabulary_;  // todo remove later
    size_t blank_id;
    int space_id;
...

speechx/speechx/frontend/fbank.h
...
@@ -24,7 +24,8 @@ class FbankExtractor : FeatureExtractorInterface {
  public:
    explicit FbankExtractor(
        const FbankOptions& opts,
        std::shared_ptr<FeatureExtractorInterface> pre_extractor);
    virtual void AcceptWaveform(
        const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
    virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
    virtual size_t Dim() const = 0;
...

speechx/speechx/frontend/feature_extractor_controller.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
speechx/speechx/frontend/feature_extractor_controller_impl.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
speechx/speechx/frontend/feature_extractor_interface.h
...
@@ -21,7 +21,8 @@ namespace ppspeech {
class FeatureExtractorInterface {
  public:
    virtual void AcceptWaveform(
        const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
    virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
    virtual size_t Dim() const = 0;
};
...
...
speechx/speechx/frontend/linear_spectrogram.cc
...
@@ -25,7 +25,7 @@ using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector;

// todo: remove later
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
                           vector<BaseFloat>* output) {
    if (input.Dim() == 0) return;
...
speechx/speechx/frontend/linear_spectrogram.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/feat/feature-window.h"
namespace ppspeech {

struct LinearSpectrogramOptions {
    kaldi::FrameExtractionOptions frame_opts;
    LinearSpectrogramOptions() : frame_opts() {}

    void Register(kaldi::OptionsItf* opts) { frame_opts.Register(opts); }
};

class LinearSpectrogram : public FeatureExtractorInterface {
  public:
    explicit LinearSpectrogram(
        const LinearSpectrogramOptions& opts,
        std::unique_ptr<FeatureExtractorInterface> base_extractor);
    virtual void AcceptWaveform(
        const kaldi::VectorBase<kaldi::BaseFloat>& input);
    virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
    virtual size_t Dim() const { return dim_; }
    void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);
...
speechx/speechx/frontend/normalizer.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/normalizer.h"
#include "kaldi/feat/cmvn.h"
...
@@ -16,7 +30,8 @@ DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
    dim_ = 0;
}

void DecibelNormalizer::AcceptWaveform(
    const kaldi::VectorBase<BaseFloat>& input) {
    dim_ = input.Dim();
    waveform_.Resize(input.Dim());
    waveform_.CopyFromVec(input);
...
@@ -27,7 +42,7 @@ void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
    Compute(waveform_, feat);
}

// todo: remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
                          vector<BaseFloat>* output) {
    if (input.Dim() == 0) return;
...
@@ -61,7 +76,7 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
    }

    // square
    for (auto& d : samples) {
        if (opts_.convert_int_float) {
            d = d * wave_float_normlization;
        }
...
@@ -74,14 +89,15 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
    gain = opts_.target_db - rms_db;
    if (gain > opts_.max_gain_db) {
        LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db
                   << "dB, "
                   << "because the probable gain exceeds opts_.max_gain_db "
                   << opts_.max_gain_db << "dB.";
        return false;
    }

    // Note that this is an in-place transformation.
    for (auto& item : samples) {
        // python: item *= 10.0 ** (gain / 20.0)
        item *= std::pow(10.0, gain / 20.0);
    }
...
@@ -100,21 +116,20 @@ void CMVN::AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) {
    return;
}

void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) { return; }

// feats contain num_frames feature.
void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
    KALDI_ASSERT(feats != NULL);
    int32 dim = stats_.NumCols() - 1;
    if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
        feats->Dim() % dim != 0) {
        KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
                  << stats_.NumCols() << ", feats " << feats->Dim() << 'x';
    }
    if (stats_.NumRows() == 1 && var_norm) {
        KALDI_ERR << "You requested variance normalization but no variance stats_ "
                  << "are supplied.";
    }
...
@@ -122,17 +137,20 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
    // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
    // computing an offset and representing it as stats_, we use a count of one.
    if (count < 1.0)
        KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
                     "normalization: "
                  << "count = " << count;
    if (!var_norm) {
        Vector<BaseFloat> offset(feats->Dim());
        SubVector<double> mean_stats(stats_.RowData(0), dim);
        Vector<double> mean_stats_apply(feats->Dim());
        // fill the data of mean_stats into mean_stats_apply, whose dim is
        // equal to the dim of the feature.
        // the dim of feats = dim * num_frames;
        for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
            SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx,
                                        dim);
            stats_tmp.CopyFromVec(mean_stats);
        }
        offset.AddVec(-1.0 / count, mean_stats_apply);
...
@@ -144,18 +162,18 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
    kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
    for (int32 d = 0; d < dim; d++) {
        double mean, offset, scale;
        mean = stats_(0, d) / count;
        double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
        if (var < floor) {
            KALDI_WARN << "Flooring cepstral variance from " << var << " to "
                       << floor;
            var = floor;
        }
        scale = 1.0 / sqrt(var);
        if (scale != scale || 1 / scale == 0.0)
            KALDI_ERR
                << "NaN or infinity in cepstral mean/variance computation";
        offset = -(mean * scale);
        for (int32 d_skip = d; d_skip < feats->Dim();) {
            norm(0, d_skip) = offset;
            norm(1, d_skip) = scale;
...
speechx/speechx/frontend/normalizer.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech {
...
@@ -12,26 +26,30 @@ struct DecibelNormalizerOptions {
    float target_db;
    float max_gain_db;
    bool convert_int_float;
    DecibelNormalizerOptions()
        : target_db(-20), max_gain_db(300.0), convert_int_float(false) {}

    void Register(kaldi::OptionsItf* opts) {
        opts->Register(
            "target-db", &target_db, "target db for db normalization");
        opts->Register(
            "max-gain-db", &max_gain_db, "max gain db for db normalization");
        opts->Register("convert-int-float",
                       &convert_int_float,
                       "if convert int samples to float");
    }
};

class DecibelNormalizer : public FeatureExtractorInterface {
  public:
    explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
    virtual void AcceptWaveform(
        const kaldi::VectorBase<kaldi::BaseFloat>& input);
    virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
    virtual size_t Dim() const { return dim_; }
    bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
                 kaldi::VectorBase<kaldi::BaseFloat>* feat) const;

  private:
    DecibelNormalizerOptions opts_;
    size_t dim_;
...
@@ -43,7 +61,8 @@ class DecibelNormalizer : public FeatureExtractorInterface {
class CMVN : public FeatureExtractorInterface {
  public:
    explicit CMVN(std::string cmvn_file);
    virtual void AcceptWaveform(
        const kaldi::VectorBase<kaldi::BaseFloat>& input);
    virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
    virtual size_t Dim() const { return stats_.NumCols() - 1; }
    bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
...
@@ -51,6 +70,7 @@ class CMVN : public FeatureExtractorInterface {
    // for test
    void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats);
    void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats);

  private:
    kaldi::Matrix<double> stats_;
    std::shared_ptr<FeatureExtractorInterface> base_extractor_;
...
speechx/speechx/frontend/window.h
...
@@ -13,4 +13,3 @@
// limitations under the License.

// extract the window of kaldi feat.
speechx/speechx/nnet/decodable-itf.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// itf/decodable-itf.h

// Copyright 2009-2011  Microsoft Corporation;  Saarland University;
...
@@ -42,8 +56,10 @@ namespace kaldi {
    For online decoding, where the features are coming in in real time, it is
    important to understand the IsLastFrame() and NumFramesReady() functions.
    There are two ways these are used: the old online-decoding code, in
    ../online/, and the new online-decoding code, in ../online2/.  In the old
    online-decoding code, the decoder would do:
    \code{.cc}
    for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
...
@@ -52,13 +68,16 @@ namespace kaldi {
    \endcode
    and the call to IsLastFrame would block if the features had not arrived yet.
    The decodable object would have to know when to terminate the decoding.
    This online-decoding mode is still supported, it is what happens when you
    call, for example, LatticeFasterDecoder::Decode().

    We realized that this "blocking" mode of decoding is not very convenient
    because it forces the program to be multi-threaded and makes it complex to
    control endpointing.  In the "new" decoding code, you don't call (for
    example) LatticeFasterDecoder::Decode(), you call
    LatticeFasterDecoder::InitDecoding(), and then each time you get more
    features, you provide them to the decodable object, and you call
    LatticeFasterDecoder::AdvanceDecoding(), which does something like this:
...
@@ -68,7 +87,8 @@ namespace kaldi {
    }
    \endcode
    So the decodable object never has IsLastFrame() called.  For decoding where
    you are starting with a matrix of features, the NumFramesReady() function
    will always just return the number of frames in the file, and IsLastFrame()
    will return true for the last frame.
...
@@ -82,30 +102,39 @@ namespace kaldi {
class DecodableInterface {
  public:
    /// Returns the log likelihood, which will be negated in the decoder.
    /// The "frame" starts from zero.  You should verify that
    /// NumFramesReady() > frame before calling this.
    virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;

    /// Returns true if this is the last frame.  Frames are zero-based, so the
    /// first frame is zero.  IsLastFrame(-1) will return false, unless the file
    /// is empty (which is a case that I'm not sure all the code will handle, so
    /// be careful).  Caution: the behavior of this function in an online
    /// setting is being changed somewhat.  In future it may return false in
    /// cases where we haven't yet decided to terminate decoding, but later true
    /// if we decide to terminate decoding.  The plan in future is to rely more
    /// on NumFramesReady(), and in future, IsLastFrame() would always return
    /// false in an online-decoding setting, and would only return true in a
    /// decoding-from-matrix setting where we want to allow the last delta or
    /// LDA features to be flushed out for compatibility with the baseline
    /// setup.
    virtual bool IsLastFrame(int32 frame) const = 0;

    /// The call NumFramesReady() will return the number of frames currently
    /// available for this decodable object.  This is for use in setups where
    /// you don't want the decoder to block while waiting for input.  This is
    /// newly added as of Jan 2014, and I hope, going forward, to rely on this
    /// mechanism more than IsLastFrame to know when to stop decoding.
    virtual int32 NumFramesReady() const {
        KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
        return -1;
    }
...
speechx/speechx/nnet/decodable.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/decodable.h"

namespace ppspeech {
...
@@ -5,18 +19,14 @@ namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Matrix;

Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
    : frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {}

void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
    frames_ready_ += likelihood.NumRows();
}

// Decodable::Init(DecodableConfig config) {
//}
bool Decodable::IsLastFrame(int32 frame) const {
...
@@ -24,18 +34,14 @@ bool Decodable::IsLastFrame(int32 frame) const {
    return finished_ && (frame == frames_ready_ - 1);
}

int32 Decodable::NumIndices() const { return 0; }

BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; }

void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
    nnet_->FeedForward(features, &nnet_cache_);
    frames_ready_ += nnet_cache_.NumRows();
    return;
}

std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {
...
speechx/speechx/nnet/decodable.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h"
#include "nnet/nnet_interface.h"
namespace ppspeech {
...
@@ -11,15 +25,18 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
  public:
    explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
    // void Init(DecodableOpts config);
    virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
    virtual bool IsLastFrame(int32 frame) const;
    virtual int32 NumIndices() const;
    virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
    // remove later
    void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
    // only for test, todo remove later
    void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature);
    void Reset();
    void InputFinished() { finished_ = true; }

  private:
    std::shared_ptr<FeatureExtractorInterface> frontend_;
    std::shared_ptr<NnetInterface> nnet_;
...
speechx/speechx/nnet/nnet_interface.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
...
@@ -10,10 +24,9 @@ namespace ppspeech {
class NnetInterface {
  public:
    virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
                             kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0;
    virtual void Reset() = 0;
    virtual ~NnetInterface() {}
};

}  // namespace ppspeech
speechx/speechx/nnet/paddle_nnet.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/paddle_nnet.h"
#include "absl/strings/str_split.h"
...
@@ -21,18 +35,18 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
        std::vector<std::string> tmp_shape;
        tmp_shape = absl::StrSplit(cache_shapes[i], "-");
        std::vector<int> cur_shape;
        std::transform(tmp_shape.begin(),
                       tmp_shape.end(),
                       std::back_inserter(cur_shape),
                       [](const std::string& s) { return atoi(s.c_str()); });
        cache_names_idx_[cache_names[i]] = i;
        std::shared_ptr<Tensor<BaseFloat>> cache_eout =
            std::make_shared<Tensor<BaseFloat>>(cur_shape);
        cache_encouts_.push_back(cache_eout);
    }
}
PaddleNnet
::
PaddleNnet
(
const
ModelOptions
&
opts
)
:
opts_
(
opts
)
{
PaddleNnet
::
PaddleNnet
(
const
ModelOptions
&
opts
)
:
opts_
(
opts
)
{
paddle_infer
::
Config
config
;
paddle_infer
::
Config
config
;
config
.
SetModel
(
opts
.
model_path
,
opts
.
params_path
);
config
.
SetModel
(
opts
.
model_path
,
opts
.
params_path
);
if
(
opts
.
use_gpu
)
{
if
(
opts
.
use_gpu
)
{
...
@@ -45,7 +59,8 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
    if (opts.enable_profile) {
        config.EnableProfile();
    }
    pool.reset(
        new paddle_infer::services::PredictorPool(config, opts.thread_num));
    if (pool == nullptr) {
        LOG(ERROR) << "create the predictor pool failed";
    }
...
@@ -68,16 +83,14 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
    std::vector<std::string> model_output_names = predictor->GetOutputNames();
    assert(output_names_vec.size() == model_output_names.size());
    for (size_t i = 0; i < output_names_vec.size(); i++) {
        assert(output_names_vec[i] == model_output_names[i]);
    }
    ReleasePredictor(predictor);
    InitCacheEncouts(opts);
}

void PaddleNnet::Reset() { InitCacheEncouts(opts_); }

paddle_infer::Predictor* PaddleNnet::GetPredictor() {
    LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
...
@@ -130,13 +143,14 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
    return cache_encouts_[iter->second];
}

void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
                             Matrix<BaseFloat>* inferences) {
    paddle_infer::Predictor* predictor = GetPredictor();
    int row = features.NumRows();
    int col = features.NumCols();
    std::vector<BaseFloat> feed_feature;
    // todo refactor feed feature: SmileGoat
    feed_feature.reserve(row * col);
    for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
        for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
            feed_feature.push_back(features(row_idx, col_idx));
...
@@ -146,22 +160,26 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
    std::vector<std::string> output_names = predictor->GetOutputNames();
    LOG(INFO) << "feat info: row=" << row << ", col= " << col;

    std::unique_ptr<paddle_infer::Tensor> input_tensor =
        predictor->GetInputHandle(input_names[0]);
    std::vector<int> INPUT_SHAPE = {1, row, col};
    input_tensor->Reshape(INPUT_SHAPE);
    input_tensor->CopyFromCpu(feed_feature.data());

    std::unique_ptr<paddle_infer::Tensor> input_len =
        predictor->GetInputHandle(input_names[1]);
    std::vector<int> input_len_size = {1};
    input_len->Reshape(input_len_size);
    std::vector<int64_t> audio_len;
    audio_len.push_back(row);
    input_len->CopyFromCpu(audio_len.data());

    std::unique_ptr<paddle_infer::Tensor> h_box =
        predictor->GetInputHandle(input_names[2]);
    shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
    h_box->Reshape(h_cache->get_shape());
    h_box->CopyFromCpu(h_cache->get_data().data());

    std::unique_ptr<paddle_infer::Tensor> c_box =
        predictor->GetInputHandle(input_names[3]);
    shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
    c_box->Reshape(c_cache->get_shape());
    c_box->CopyFromCpu(c_cache->get_data().data());
...
@@ -172,10 +190,12 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
    }
    LOG(INFO) << "get the model success";

    std::unique_ptr<paddle_infer::Tensor> h_out =
        predictor->GetOutputHandle(output_names[2]);
    assert(h_cache->get_shape() == h_out->shape());
    h_out->CopyToCpu(h_cache->get_data().data());

    std::unique_ptr<paddle_infer::Tensor> c_out =
        predictor->GetOutputHandle(output_names[3]);
    assert(c_cache->get_shape() == c_out->shape());
    c_out->CopyToCpu(c_cache->get_data().data());
...
@@ -187,13 +207,14 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
    col = output_shape[2];
    vector<float> inferences_result;
    inferences->Resize(row, col);
    inferences_result.resize(row * col);
    output_tensor->CopyToCpu(inferences_result.data());
    ReleasePredictor(predictor);

    for (int row_idx = 0; row_idx < row; ++row_idx) {
        for (int col_idx = 0; col_idx < col; ++col_idx) {
            (*inferences)(row_idx, col_idx) =
                inferences_result[col * row_idx + col_idx];
        }
    }
}
...
speechx/speechx/nnet/paddle_nnet.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "base/common.h"
#include "nnet/nnet_interface.h"
#include "paddle_inference_api.h"
#include "kaldi/matrix/kaldi-matrix.h"
...
@@ -24,19 +38,27 @@ struct ModelOptions {
    std::string cache_shape;
    bool enable_fc_padding;
    bool enable_profile;
    ModelOptions()
        : model_path("../../../../model/paddle_online_deepspeech/model/"
                     "avg_1.jit.pdmodel"),
          params_path("../../../../model/paddle_online_deepspeech/model/"
                      "avg_1.jit.pdiparams"),
          thread_num(2),
          use_gpu(false),
          input_names(
              "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_"
              "box"),
          output_names(
              "save_infer_model/scale_0.tmp_1,save_infer_model/"
              "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
              "scale_3.tmp_1"),
          cache_names("chunk_state_h_box,chunk_state_c_box"),
          cache_shape("3-1-1024,3-1-1024"),
          switch_ir_optim(false),
          enable_fc_padding(false),
          enable_profile(false) {}

    void Register(kaldi::OptionsItf* opts) {
        opts->Register("model-path", &model_path, "model file path");
...
@@ -47,37 +69,37 @@ struct ModelOptions {
        opts->Register("output-names", &output_names, "paddle output names");
        opts->Register("cache-names", &cache_names, "cache names");
        opts->Register("cache-shape", &cache_shape, "cache shape");
        opts->Register("switch-ir-optiom",
                       &switch_ir_optim,
                       "paddle SwitchIrOptim option");
        opts->Register("enable-fc-padding",
                       &enable_fc_padding,
                       "paddle EnableFCPadding option");
        opts->Register("enable-profile",
                       &enable_profile,
                       "paddle EnableProfile option");
    }
};

template <typename T>
class Tensor {
  public:
    Tensor() {}
    Tensor(const std::vector<int>& shape) : _shape(shape) {
        int data_size = std::accumulate(
            _shape.begin(), _shape.end(), 1, std::multiplies<int>());
        LOG(INFO) << "data size: " << data_size;
        _data.resize(data_size, 0);
    }
    void reshape(const std::vector<int>& shape) {
        _shape = shape;
        int data_size = std::accumulate(
            _shape.begin(), _shape.end(), 1, std::multiplies<int>());
        _data.resize(data_size, 0);
    }
    const std::vector<int>& get_shape() const { return _shape; }
    std::vector<T>& get_data() { return _data; }

  private:
    std::vector<int> _shape;
    std::vector<T> _data;
};
...
@@ -88,7 +110,8 @@ class PaddleNnet : public NnetInterface {
    virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
                             kaldi::Matrix<kaldi::BaseFloat>* inferences);
    virtual void Reset();
    std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
        const std::string& name);
    void InitCacheEncouts(const ModelOptions& opts);

  private:
...
speechx/speechx/utils/file_utils.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "utils/file_utils.h"
namespace ppspeech {
...
@@ -17,5 +31,4 @@ bool ReadFileToVector(const std::string& filename,
    return true;
}
}  // namespace ppspeech
\ No newline at end of file
speechx/speechx/utils/file_utils.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
namespace ppspeech {

bool ReadFileToVector(const std::string& filename,
                      std::vector<std::string>* data);
}  // namespace ppspeech