PaddlePaddle / DeepSpeech · commit 41feecbd
Commit 41feecbd, authored Mar 07, 2022 by Hui Zhang

    format

Parent: 11dc485d

Showing 32 changed files with 2112 additions and 1653 deletions (+2112 −1653)
.pre-commit-config.yaml (+2, −2)
paddlespeech/s2t/io/sampler.py (+1, −1)
paddlespeech/t2s/modules/transformer/repeat.py (+1, −1)
speechx/examples/README.md (+0, −1)
speechx/examples/decoder/offline-decoder-main.cc (+55, −38)
speechx/examples/feat/feature-mfcc-test.cc (+645, −611)
speechx/examples/feat/linear-spectrogram-main.cc (+231, −99)
speechx/examples/nnet/pp-model-test.cc (+78, −49)
speechx/speechx/base/basic_types.h (+27, −27)
speechx/speechx/base/common.h (+5, −5)
speechx/speechx/base/macros.h (+2, −2)
speechx/speechx/base/thread_pool.h (+41, −51)
speechx/speechx/decoder/common.h (+17, −3)
speechx/speechx/decoder/ctc_beam_search_decoder.cc (+231, −219)
speechx/speechx/decoder/ctc_beam_search_decoder.h (+57, −40)
speechx/speechx/frontend/fbank.h (+5, −4)
speechx/speechx/frontend/feature_extractor_controller.h (+14, −0)
speechx/speechx/frontend/feature_extractor_controller_impl.h (+14, −0)
speechx/speechx/frontend/feature_extractor_interface.h (+2, −1)
speechx/speechx/frontend/linear_spectrogram.cc (+104, −104)
speechx/speechx/frontend/linear_spectrogram.h (+24, −11)
speechx/speechx/frontend/normalizer.cc (+143, −125)
speechx/speechx/frontend/normalizer.h (+33, −13)
speechx/speechx/frontend/window.h (+0, −1)
speechx/speechx/nnet/decodable-itf.h (+72, −43)
speechx/speechx/nnet/decodable.cc (+34, −28)
speechx/speechx/nnet/decodable.h (+23, −6)
speechx/speechx/nnet/nnet_interface.h (+15, −2)
speechx/speechx/nnet/paddle_nnet.cc (+123, −102)
speechx/speechx/nnet/paddle_nnet.h (+83, −60)
speechx/speechx/utils/file_utils.cc (+14, −1)
speechx/speechx/utils/file_utils.h (+16, −3)
.pre-commit-config.yaml — View file @ 41feecbd
...
@@ -50,13 +50,13 @@ repos:
         entry: bash .pre-commit-hooks/clang-format.hook -i
         language: system
         files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
-        exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
+        exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
     -   id: copyright_checker
         name: copyright_checker
         entry: python .pre-commit-hooks/copyright-check.hook
         language: system
         files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
-        exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
+        exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
 -   repo: https://github.com/asottile/reorder_python_imports
     rev: v2.4.0
     hooks:
...
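The tightened `exclude` patterns above are ordinary Python regexes with a leading lookahead, which pre-commit matches against repo-relative file paths. A minimal sketch of what the updated pattern now skips (the example paths are illustrative, not from the diff):

```python
import re

# Pattern copied from the "+" line of the diff: skip vendored kaldi code
# and the new speechx/patch tree when running clang-format.
exclude = re.compile(
    r"(?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$")

print(bool(exclude.search("speechx/speechx/kaldi/base/io-funcs.cc")))  # True: skipped
print(bool(exclude.search("speechx/patch/openfst/src/lib/flags.cc")))  # True: skipped
print(bool(exclude.search("speechx/speechx/base/common.h")))           # False: formatted
```

The lookahead anchors the match to paths that begin with one of the excluded prefixes, while the trailing group restricts the rule to source-file extensions.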
paddlespeech/s2t/io/sampler.py — View file @ 41feecbd
...
@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
     """
     rng = np.random.RandomState(epoch)
     shift_len = rng.randint(0, batch_size - 1)
-    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
+    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
     rng.shuffle(batch_indices)
     batch_indices = [item for batch in batch_indices for item in batch]
     assert clipped is False
...
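The line touched by this formatting-only change uses the classic "grouper" idiom: `zip(*[iter(seq)] * n)` advances one shared iterator `n` steps per batch, yielding fixed-size tuples and silently dropping any remainder. A small standalone illustration:

```python
# Same chunking idiom as _batch_shuffle, on toy data.
indices = list(range(10))
shift_len = 3
batch_size = 2

# One iterator, referenced batch_size times, so zip pulls consecutive items.
batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
print(batches)  # [(3, 4), (5, 6), (7, 8)] — trailing 9 is dropped

# Flatten back, as the function does after shuffling the batches.
flat = [item for batch in batches for item in batch]
print(flat)     # [3, 4, 5, 6, 7, 8]
```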
paddlespeech/t2s/modules/transformer/repeat.py — View file @ 41feecbd
...
@@ -36,4 +36,4 @@ def repeat(N, fn):
     Returns:
         MultiSequential: Repeated model instance.
     """
-    return MultiSequential(*[fn(n) for n in range(N)])
+    return MultiSequential(*[fn(n) for n in range(N)])
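`repeat(N, fn)` splats N freshly built sub-modules into a `MultiSequential` container. A toy stand-in (plain Python; this `MultiSequential` is a simplified stand-in for the real paddlespeech class) shows why `fn` must be a factory rather than a pre-built instance:

```python
class MultiSequential(list):
    """Toy stand-in for the real container; holds one module per repeat."""
    def __init__(self, *layers):
        super().__init__(layers)

def repeat(N, fn):
    # fn(n) is called N times, so every repeat gets its own module
    # (and therefore its own parameters in the real framework).
    return MultiSequential(*[fn(n) for n in range(N)])

blocks = repeat(3, lambda n: {"layer_index": n})
print(len(blocks))             # 3
print(blocks[0] is blocks[1])  # False: each call to fn makes a fresh module
```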
speechx/examples/README.md — View file @ 41feecbd
...
@@ -3,4 +3,3 @@
 * decoder - offline decoder
 * feat - mfcc, linear
 * nnet - ds2 nn
-
speechx/examples/decoder/offline-decoder-main.cc — View file @ 41feecbd
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"

DEFINE_string(feature_respecifier, "", "test nnet prob");
...
@@ -13,7 +27,7 @@ using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;

// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
//                   int32 chunk_size,
//                   std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
...
@@ -23,7 +37,8 @@ int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialBaseFloatMatrixReader feature_reader(
        FLAGS_feature_respecifier);

    // test nnet_output --> decoder result
    int32 num_done = 0, num_err = 0;
...
@@ -32,11 +47,13 @@ int main(int argc, char* argv[]) {
    ppspeech::CTCBeamSearch decoder(opts);

    ppspeech::ModelOptions model_opts;
    std::shared_ptr<ppspeech::PaddleNnet> nnet(
        new ppspeech::PaddleNnet(model_opts));
    std::shared_ptr<ppspeech::Decodable> decodable(
        new ppspeech::Decodable(nnet));

    // int32 chunk_size = 35;
    decoder.InitDecoder();

    for (; !feature_reader.Done(); feature_reader.Next()) {
        string utt = feature_reader.Key();
...
speechx/examples/feat/feature-mfcc-test.cc — View file @ 41feecbd
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// feat/feature-mfcc-test.cc
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
...
...
@@ -20,17 +34,15 @@
#include <iostream>

#include "base/kaldi-math.h"
#include "feat/feature-mfcc.h"
#include "feat/wave-reader.h"
#include "matrix/kaldi-matrix-inl.h"

using namespace kaldi;

static void UnitTestReadWave() {
    std::cout << "=== UnitTestReadWave() ===\n";

    Vector<BaseFloat> v, v2;
...
...
@@ -47,15 +59,15 @@ static void UnitTestReadWave() {
        v.CopyFromVec(data.Row(0));
    }

    std::cout << "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
    std::ifstream input("test_data/test_matlab.ascii");
    KALDI_ASSERT(input.good());
    v2.Read(input, false);
    input.close();

    std::cout << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
    KALDI_ASSERT(v.Dim() == v2.Dim());
    for (int32 i = 0; i < v.Dim(); i++) {
        KALDI_ASSERT(v(i) == v2(i));
...
...
@@ -66,11 +78,9 @@ static void UnitTestReadWave() {
    // std::cout << v;
    std::cout << "Test passed :)\n\n";
}

/**
 */
static void UnitTestSimple() {
...
...
@@ -81,7 +91,7 @@ static void UnitTestSimple() {
    // init with noise
    for (int32 i = 0; i < v.Dim(); i++) {
        v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
    }
    std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
...
...
@@ -147,9 +157,7 @@ static void UnitTestHTKCompare1() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;
...
...
@@ -158,32 +166,36 @@ static void UnitTestHTKCompare1() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (i_old != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.1",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
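The `HtkHeader` initializer above fills in HTK's 12-byte binary parameter-file header: nSamples (int32), sampPeriod (int32, in 100 ns units, so 100000 means a 10 ms frame shift), sampSize (int16, bytes per frame), and parmKind (int16; the octal literal `021406` is MFCC with the `_D`, `_A`, and `_0` qualifier bits). A sketch of the same header packed in Python (big-endian, as HTK files are; `num_rows`/`num_cols` are illustrative values, not from the diff):

```python
import struct

num_rows, num_cols = 428, 39  # illustrative frame count / feature dim
header = struct.pack(">iihh",
                     num_rows,       # nSamples
                     100000,         # sampPeriod: 10 ms in 100 ns units
                     4 * num_cols,   # sampSize: sizeof(float) * NumCols()
                     0o21406)        # parmKind: MFCC_D_A_0
print(len(header))  # 12
```

Note that `0o21406` decodes to base kind MFCC (6) plus the delta (0400), acceleration (01000), and C0 (020000) qualifier bits.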
...
...
@@ -231,9 +243,7 @@ static void UnitTestHTKCompare2() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;
...
...
@@ -242,32 +252,36 @@ static void UnitTestHTKCompare2() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (i_old != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.2",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
...
...
@@ -305,7 +319,7 @@ static void UnitTestHTKCompare3() {
    op.htk_compat = true;
    op.use_energy = true;  // Use energy.
    op.mel_opts.low_freq = 20.0;
    // op.mel_opts.debug_mel = true;
    op.mel_opts.htk_mode = true;

    Mfcc mfcc(op);
...
...
@@ -316,9 +330,7 @@ static void UnitTestHTKCompare3() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;
...
...
@@ -327,32 +339,36 @@ static void UnitTestHTKCompare3() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (static_cast<int32>(i_old) != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.3",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
...
...
@@ -399,9 +415,7 @@ static void UnitTestHTKCompare4() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;
...
...
@@ -410,32 +424,36 @@ static void UnitTestHTKCompare4() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (static_cast<int32>(i_old) != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.4",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
...
...
@@ -476,7 +494,8 @@ static void UnitTestHTKCompare5() {
    op.mel_opts.vtln_high = 7500.0;
    op.mel_opts.htk_mode = true;

    BaseFloat vtln_warp = 1.1;  // our approach identical to htk for warp factor >1,
                                // differs slightly for higher mel bins if warp_factor <0.9
    Mfcc mfcc(op);
...
...
@@ -487,9 +506,7 @@ static void UnitTestHTKCompare5() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;
...
...
@@ -498,32 +515,36 @@ static void UnitTestHTKCompare5() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (static_cast<int32>(i_old) != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.5",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
...
...
@@ -572,9 +593,7 @@ static void UnitTestHTKCompare6() {
    DeltaFeaturesOptions delta_opts;
    Matrix<BaseFloat> kaldi_features;
    ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);

    // compare the results
    bool passed = true;
...
...
@@ -583,32 +602,36 @@ static void UnitTestHTKCompare6() {
    KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
    // Ignore ends-- we make slightly different choices than
    // HTK about how to treat the deltas at the ends.
    for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
        for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
            BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
            if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
                // print the non-matching data only once per-line
                if (static_cast<int32>(i_old) != i) {
                    std::cout << "\n\n\n[HTK-row: " << i << "] "
                              << htk_features.Row(i) << "\n";
                    std::cout << "[Kaldi-row: " << i << "] "
                              << kaldi_features.Row(i) << "\n\n\n";
                    i_old = i;
                }
                // print indices of non-matching cells
                std::cout << "[" << i << ", " << j << "]";
                passed = false;
            }
        }
    }
    if (!passed) KALDI_ERR << "Test failed";

    // write the htk features for later inspection
    HtkHeader header = {
        kaldi_features.NumRows(),
        100000,  // 10ms
        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
        021406  // MFCC_D_A_0
    };
    {
        std::ofstream os("tmp.test.wav.fea_kaldi.6",
                         std::ios::out | std::ios::binary);
        WriteHtk(os, kaldi_features, header);
    }
...
...
@@ -619,36 +642,51 @@ static void UnitTestHTKCompare6() {
void UnitTestVtln() {
    // Test the function VtlnWarpFreq.
    BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
              vtln_high_cutoff = 7400;
    for (size_t i = 0; i < 100; i++) {
        BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
        AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
                                           low_freq, high_freq, warp_factor,
                                           freq),
                    freq / warp_factor);

        AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
                                           low_freq, high_freq, warp_factor,
                                           low_freq),
                    low_freq);
        AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
                                           low_freq, high_freq, warp_factor,
                                           high_freq),
                    high_freq);
        BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
                  freq3 = freq2 + (high_freq - freq2) * RandUniform();  // freq3>=freq2
        BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
                                              low_freq, high_freq, warp_factor,
                                              freq2);
        BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
                                              low_freq, high_freq, warp_factor,
                                              freq3);
        KALDI_ASSERT(w3 >= w2);  // increasing function.
        BaseFloat w3dash = MelBanks::VtlnWarpFreq(
            vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
        AssertEqual(w3dash, freq3);
    }
}
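The properties exercised above follow from VTLN's piecewise-linear frequency warp: inside the cutoffs the warp is a pure scale `freq / warp_factor`, while linear edge segments pin `low_freq` and `high_freq` in place. A hedged Python sketch of Kaldi's documented `VtlnWarpFreq` construction (this is a re-derivation for illustration, not code from the diff):

```python
def vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff,
                   low_freq, high_freq, warp, freq):
    """Piecewise-linear VTLN warp, per Kaldi's MelBanks::VtlnWarpFreq."""
    if freq < low_freq or freq > high_freq:
        return freq  # outside the band: identity
    l = vtln_low_cutoff * max(1.0, warp)
    h = vtln_high_cutoff * min(1.0, warp)
    scale = 1.0 / warp
    if freq < l:
        # linear segment joining (low_freq, low_freq) to (l, scale * l)
        scale_left = (scale * l - low_freq) / (l - low_freq)
        return low_freq + scale_left * (freq - low_freq)
    elif freq < h:
        return scale * freq  # middle region: pure 1/warp scaling
    else:
        # linear segment joining (h, scale * h) to (high_freq, high_freq)
        scale_right = (high_freq - scale * h) / (high_freq - h)
        return high_freq + scale_right * (freq - high_freq)

# Same checks as the unit test, with the diff's constants:
print(vtln_warp_freq(20, 7400, 10, 7800, 1.1, 5000))  # 5000 / 1.1
print(vtln_warp_freq(20, 7400, 10, 7800, 1.1, 10))    # 10.0 (endpoint fixed)
print(vtln_warp_freq(20, 7400, 10, 7800, 1.1, 7800))  # 7800.0 (endpoint fixed)
```

With the test's constants, 5000 Hz always falls in the middle region for warp factors in [0.9, 1.1], which is why the test can assert the exact value `freq / warp_factor`.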
...
...
@@ -670,11 +708,9 @@ static void UnitTestFeat() {
}

int main() {
    try {
        for (int i = 0; i < 5; i++) UnitTestFeat();
        std::cout << "Tests succeeded.\n";
        return 0;
    } catch (const std::exception& e) {
...
...
@@ -682,5 +718,3 @@ int main() {
        return 1;
    }
}
speechx/examples/feat/linear-spectrogram-main.cc — View file @ 41feecbd
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"

DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
...
...
@@ -15,12 +29,120 @@ DEFINE_string(feature_check_wspecifier, "", "test wav ark");
DEFINE_string
(
cmvn_write_path
,
"./cmvn.ark"
,
"test wav ark"
);
std
::
vector
<
float
>
mean_
{
-
13730251.531853663
,
-
12982852.199316509
,
-
13673844.299583456
,
-
13089406.559646806
,
-
12673095.524938712
,
-
12823859.223276224
,
-
13590267.158903603
,
-
14257618.467152044
,
-
14374605.116185192
,
-
14490009.21822485
,
-
14849827.158924166
,
-
15354435.470563512
,
-
15834149.206532761
,
-
16172971.985514281
,
-
16348740.496746974
,
-
16423536.699409386
,
-
16556246.263649225
,
-
16744088.772748645
,
-
16916184.08510357
,
-
17054034.840031497
,
-
17165612.509455364
,
-
17255955.470915023
,
-
17322572.527648456
,
-
17408943.862033736
,
-
17521554.799865916
,
-
17620623.254924215
,
-
17699792.395918526
,
-
17723364.411134344
,
-
17741483.4433254
,
-
17747426.888704527
,
-
17733315.928209435
,
-
17748780.160905756
,
-
17808336.883775543
,
-
17895918.671983004
,
-
18009812.59173023
,
-
18098188.66548325
,
-
18195798.958462656
,
-
18293617.62980999
,
-
18397432.92077201
,
-
18505834.787318766
,
-
18585451.8100908
,
-
18652438.235649142
,
-
18700960.306275308
,
-
18734944.58792185
,
-
18737426.313365128
,
-
18735347.165987637
,
-
18738813.444170244
,
-
18737086.848890636
,
-
18731576.2474336
,
-
18717405.44095871
,
-
18703089.25545657
,
-
18691014.546456724
,
-
18692460.568905357
,
-
18702119.628629155
,
-
18727710.621126678
,
-
18761582.72034647
,
-
18806745.835547544
,
-
18850674.8692112
,
-
18884431.510951452
,
-
18919999.992506847
,
-
18939303.799078144
,
-
18952946.273760635
,
-
18980289.22996379
,
-
19011610.17803294
,
-
19040948.61805145
,
-
19061021.429847397
,
-
19112055.53768819
,
-
19149667.414264943
,
-
19201127.05091321
,
-
19270250.82564605
,
-
19334606.883057203
,
-
19390513.336589377
,
-
19444176.259208687
,
-
19502755.000038862
,
-
19544333.014549147
,
-
19612668.183176614
,
-
19681902.19006569
,
-
19771969.951249883
,
-
19873329.723376893
,
-
19996752.59235844
,
-
20110031.131400537
,
-
20231658.612529557
,
-
20319378.894054495
,
-
20378534.45718066
,
-
20413332.089584175
,
-
20438147.844177883
,
-
20443710.248040095
,
-
20465457.02238927
,
-
20488610.969337028
,
-
20516295.16424432
,
-
20541423.795738827
,
-
20553192.874953747
,
-
20573605.50701977
,
-
20577871.61936797
,
-
20571807.008916274
,
-
20556242.38912231
,
-
20542199.30819195
,
-
20521239.063551214
,
-
20519150.80004532
,
-
20527204.80248933
,
-
20536933.769257784
,
-
20543470.522332076
,
-
20549700.089992985
,
-
20551525.24958494
,
-
20554873.406493705
,
-
20564277.65794227
,
-
20572211.740052115
,
-
20574305.69550465
,
-
20575494.450104576
,
-
20567092.577932164
,
-
20549302.929608088
,
-
20545445.11878376
,
-
20546625.326603737
,
-
20549190.03499401
,
-
20554824.947828256
,
-
20568341.378989458
,
-
20577582.331383612
,
-
20577980.519402675
,
-
20566603.03458152
,
-
20560131.592262644
,
-
20552166.469060015
,
-
20549063.06763577
,
-
20544490.562339947
,
-
20539817.82346569
,
-
20528747.715731595
,
-
20518026.24576161
,
-
20510977.844974525
,
-
20506874.36087992
,
-
20506731.11977665
,
-
20510482.133420516
,
-
20507760.92101862
,
-
20494644.834457114
,
-
20480107.89304893
,
-
20461312.091867123
,
-
20442941.75080173
,
-
20426123.02834838
,
-
20424607.675283
,
-
20426810.369107097
,
-
20434024.50097819
,
-
20437404.75544205
,
-
20447688.63916367
,
-
20460893.335563846
,
-
20482922.735127095
,
-
20503610.119434915
,
-
20527062.76448319
,
-
20557830.035128627
,
-
20593274.72068722
,
-
20632528.452965066
,
-
20673637.471334763
,
-
20733106.97143075
,
-
        20842921.0447562, -21054357.83621519, -21416569.534189366,
        -21978460.272811692, -22753170.052172784, -23671344.10563395,
        -24613499.293358143, -25406477.12230188, -25884377.82156489,
        -26049040.62791664, -26996879.104431007};
    std::vector<float> variance_{
        213747175.10846674, 188395815.34302503, 212706429.10966414,
        199109025.81461075, 189235901.23864496, 194901336.53253657,
        217481594.29306737, 238689869.12327808, 243977501.24115244,
        248479623.6431067, 259766741.47116545, 275516766.7790273,
        291271202.3691234, 302693239.8220509, 308627358.3997694,
        311143911.38788426, 315446105.07731867, 321705430.9341829,
        327458907.4659941, 332245072.43223983, 336251717.5935284,
        339694069.7639722, 342188204.4322228, 345587110.31313115,
        349903086.2875232, 353660214.20643026, 356700344.5270885,
        357665362.3529641, 358493352.05658793, 358857951.620328,
        358375239.52774596, 358899733.6342954, 361051818.3511561,
        364361716.05025816, 368750322.3771452, 372047800.6462831,
        375655861.1349018, 379358519.1980013, 383327605.3935181,
        387458599.282341, 390434692.3406868, 392994486.35057056,
        394874418.04603153, 396230525.79763395, 396365592.0414835,
        396334819.8242737, 396488353.19250053, 396438877.00744957,
        396197980.4459586, 395590921.6672991, 395001107.62072515,
        394528291.7318225, 394593110.424006, 395018405.59353715,
        396110577.5415993, 397506704.0371068, 399400197.4657644,
        401243568.2468382, 402687134.7805103, 404136047.2872507,
        404883170.001883, 405522253.219517, 406660365.3626476,
        407919346.0991902, 409045348.5384909, 409759588.7889818,
        411974821.8564483, 413489718.78201455, 415535392.56684107,
        418466481.97674364, 421104678.35678065, 423405392.5200779,
        425550570.40798235, 427929423.9579701, 429585274.253478,
        432368493.55181056, 435193587.13513297, 438886855.20476013,
        443058876.8633751, 448181232.5093362, 452883835.6332396,
        458056721.77926534, 461816531.22735566, 464363620.1970998,
        465886343.5057493, 466928872.0651, 467180536.42647296,
        468111848.70714295, 469138695.3071312, 470378429.6930793,
        471517958.7132626, 472109050.4262365, 473087417.0177867,
        473381322.04648733, 473220195.85483915, 472666071.8998819,
        472124669.87879956, 471298571.411737, 471251033.2902761,
        471672676.43128747, 472177147.2193172, 472572361.7711908,
        472968783.7751127, 473156295.4164052, 473398034.82676554,
        473897703.5203811, 474328271.33112127, 474452670.98002136,
        474549003.99284613, 474252887.13567275, 473557462.909069,
        473483385.85193115, 473609738.04855174, 473746944.82085115,
        474016729.91696435, 474617321.94138587, 475045097.237122,
        475125402.586558, 474664112.9824912, 474426247.5800283,
        474104075.42796475, 473978219.7273978, 473773171.7798875,
        473578534.69508696, 473102924.16904145, 472651240.5232615,
        472374383.1810912, 472209479.6956096, 472202298.8921673,
        472370090.76781124, 472220933.99374026, 471625467.37106377,
        470994646.51883453, 470182428.9637543, 469348211.5939578,
        468570387.4467277, 468540442.7225135, 468672018.90414184,
        468994346.9533251, 469138757.58201426, 469553915.95710236,
        470134523.38582784, 471082421.62055486, 471962316.51804745,
        472939745.1708408, 474250621.5944825, 475773933.43199486,
        477465399.71087736, 479218782.61382693, 481752299.7930922,
        486608947.8984568, 496119403.2067917, 512730085.5704984,
        539048915.2641417, 576285298.3548826, 621610270.2240586,
        669308196.4436442, 710656993.5957186, 736344437.3725077,
        745481288.0241544, 801121432.9925804};
    int count_ = 912592;

    void WriteMatrix() {
        kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
        for (size_t idx = 0; idx < mean_.size(); ++idx) {
            cmvn_stats(0, idx) = mean_[idx];
            cmvn_stats(1, idx) = variance_[idx];
...
...
@@ -33,12 +155,15 @@ int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
        FLAGS_wav_rspecifier);
    kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
    kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(
        FLAGS_feature_check_wspecifier);
    WriteMatrix();

    // test feature linear_spectrogram: wave --> decibel_normalizer --> hanning
    // window --> linear_spectrogram --> cmvn
    int32 num_done = 0, num_err = 0;
    ppspeech::LinearSpectrogramOptions opt;
    opt.frame_opts.frame_length_ms = 20;
...
...
@@ -46,7 +171,8 @@ int main(int argc, char* argv[]) {
    ppspeech::DecibelNormalizerOptions db_norm_opt;
    std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
        new ppspeech::DecibelNormalizer(db_norm_opt));
    ppspeech::LinearSpectrogram linear_spectrogram(
        opt, std::move(base_feature_extractor));
    ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
...
...
@@ -66,16 +192,18 @@ int main(int argc, char* argv[]) {
    for (; !wav_reader.Done(); wav_reader.Next()) {
        std::string utt = wav_reader.Key();
        const kaldi::WaveData& wave_data = wav_reader.Value();

        int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        int tot_samples = waveform.Dim();
        int sample_offset = 0;
        std::vector<kaldi::Matrix<BaseFloat>> feats;
        int feature_rows = 0;
        while (sample_offset < tot_samples) {
            int cur_chunk_size =
                std::min(chunk_sample_size, tot_samples - sample_offset);
            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
            for (int i = 0; i < cur_chunk_size; ++i) {
                wav_chunk(i) = waveform(sample_offset + i);
...
...
@@ -90,11 +218,14 @@ int main(int argc, char* argv[]) {
        }
        int cur_idx = 0;
        kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
                                                 feats[0].NumCols());
        for (auto feat : feats) {
            for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
                for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
                    features(cur_idx, col_idx) =
                        (feat(row_idx, col_idx) - mean_[col_idx]) *
                        variance_[col_idx];
                }
                ++cur_idx;
            }
...
...
@@ -102,7 +233,8 @@ int main(int argc, char* argv[]) {
        feat_writer.Write(utt, features);

        cur_idx = 0;
        kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows,
                                                       feats[0].NumCols());
        for (auto feat : feats) {
            for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
                for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
...
...
speechx/examples/nnet/pp-model-test.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gflags/gflags.h>

#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <thread>

#include "paddle_inference_api.h"
using std::cout;
using std::endl;
...
...
@@ -39,7 +53,8 @@ void model_forward_test() {
    std::vector<std::vector<float>> feats;
    produce_data(&feats);

    std::cout << "2. load the model" << std::endl;
    std::string model_graph = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    cout << "model path: " << model_graph << endl;
...
...
@@ -53,9 +68,10 @@ void model_forward_test() {
    cout << "DisableFCPadding: " << endl;
    auto predictor = paddle_infer::CreatePredictor(config);

    std::cout << "3. feat shape, row=" << feats.size()
              << ",col=" << feats[0].size() << std::endl;

    std::vector<float> pp_input_mat;
    for (const auto& item : feats) {
        pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
    }
...
...
@@ -64,10 +80,10 @@ void model_forward_test() {
    int col = feats[0].size();
    std::vector<std::string> input_names = predictor->GetInputNames();
    std::vector<std::string> output_names = predictor->GetOutputNames();
    for (auto name : input_names) {
        cout << "model input names: " << name << endl;
    }
    for (auto name : output_names) {
        cout << "model output names: " << name << endl;
    }
...
...
@@ -79,7 +95,8 @@ void model_forward_test() {
    input_tensor->CopyFromCpu(pp_input_mat.data());

    // input length
    std::unique_ptr<paddle_infer::Tensor> input_len =
        predictor->GetInputHandle(input_names[1]);
    std::vector<int> input_len_size = {1};
    input_len->Reshape(input_len_size);
    std::vector<int64_t> audio_len;
...
...
@@ -87,20 +104,28 @@ void model_forward_test() {
    input_len->CopyFromCpu(audio_len.data());

    // state_h
    std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
        predictor->GetInputHandle(input_names[2]);
    std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
    chunk_state_h_box->Reshape(chunk_state_h_box_shape);
    int chunk_state_h_box_size =
        std::accumulate(chunk_state_h_box_shape.begin(),
                        chunk_state_h_box_shape.end(),
                        1,
                        std::multiplies<int>());
    std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
    chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());

    // state_c
    std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
        predictor->GetInputHandle(input_names[3]);
    std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
    chunk_state_c_box->Reshape(chunk_state_c_box_shape);
    int chunk_state_c_box_size =
        std::accumulate(chunk_state_c_box_shape.begin(),
                        chunk_state_c_box_shape.end(),
                        1,
                        std::multiplies<int>());
    std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
    chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
...
...
@@ -108,18 +133,20 @@ void model_forward_test() {
    bool success = predictor->Run();

    // state_h out
    std::unique_ptr<paddle_infer::Tensor> h_out =
        predictor->GetOutputHandle(output_names[2]);
    std::vector<int> h_out_shape = h_out->shape();
    int h_out_size = std::accumulate(
        h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
    std::vector<float> h_out_data(h_out_size);
    h_out->CopyToCpu(h_out_data.data());

    // state_c out
    std::unique_ptr<paddle_infer::Tensor> c_out =
        predictor->GetOutputHandle(output_names[3]);
    std::vector<int> c_out_shape = c_out->shape();
    int c_out_size = std::accumulate(
        c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
    std::vector<float> c_out_data(c_out_size);
    c_out->CopyToCpu(c_out_data.data());
...
...
@@ -128,8 +155,8 @@ void model_forward_test() {
        predictor->GetOutputHandle(output_names[0]);
    std::vector<int> output_shape = output_tensor->shape();
    std::vector<float> output_probs;
    int output_size = std::accumulate(
        output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
    output_probs.resize(output_size);
    output_tensor->CopyToCpu(output_probs.data());
    row = output_shape[1];
...
...
@@ -148,9 +175,11 @@ void model_forward_test() {
    }
    std::vector<std::vector<float>> log_feat = probs;
    std::cout << "probs, row: " << log_feat.size()
              << " col: " << log_feat[0].size() << std::endl;
    for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
        for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
             ++col_idx) {
            std::cout << log_feat[row_idx][col_idx] << " ";
        }
        std::cout << std::endl;
...
...
speechx/speechx/base/basic_types.h
...
...
@@ -43,18 +43,18 @@ typedef unsigned long long uint64;
typedef signed int char32;

const uint8 kuint8max = ((uint8)0xFF);
const uint16 kuint16max = ((uint16)0xFFFF);
const uint32 kuint32max = ((uint32)0xFFFFFFFF);
const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = ((int8)0x80);
const int8 kint8max = ((int8)0x7F);
const int16 kint16min = ((int16)0x8000);
const int16 kint16max = ((int16)0x7FFF);
const int32 kint32min = ((int32)0x80000000);
const int32 kint32max = ((int32)0x7FFFFFFF);
const int64 kint64min = ((int64)(0x8000000000000000LL));
const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));

const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
speechx/speechx/base/common.h
...
...
@@ -15,22 +15,22 @@
#pragma once

#include <deque>
#include <fstream>
#include <iostream>
#include <istream>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "base/basic_types.h"
#include "base/flags.h"
#include "base/log.h"
#include "base/macros.h"
speechx/speechx/base/macros.h
speechx/speechx/base/thread_pool.h
...
...
@@ -23,28 +23,29 @@
#ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H

#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>

class ThreadPool {
  public:
    ThreadPool(size_t);
    template <class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type>;
    ~ThreadPool();

  private:
    // need to keep track of threads so we can join them
    std::vector<std::thread> workers;
    // the task queue
    std::queue<std::function<void()>> tasks;

    // synchronization
    std::mutex queue_mutex;
...
...
@@ -53,68 +54,57 @@ private:
};

// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
    for (size_t i = 0; i < threads; ++i)
        workers.emplace_back([this] {
            for (;;) {
                std::function<void()> task;

                {
                    std::unique_lock<std::mutex> lock(this->queue_mutex);
                    this->condition.wait(lock, [this] {
                        return this->stop || !this->tasks.empty();
                    });
                    if (this->stop && this->tasks.empty()) return;
                    task = std::move(this->tasks.front());
                    this->tasks.pop();
                }

                task();
            }
        });
}

// add new work item to the pool
template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type> {
    using return_type = typename std::result_of<F(Args...)>::type;

    auto task = std::make_shared<std::packaged_task<return_type()>>(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...));

    std::future<return_type> res = task->get_future();
    {
        std::unique_lock<std::mutex> lock(queue_mutex);

        // don't allow enqueueing after stopping the pool
        if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");

        tasks.emplace([task]() { (*task)(); });
    }
    condition.notify_one();
    return res;
}

// the destructor joins all threads
inline ThreadPool::~ThreadPool() {
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread& worker : workers) worker.join();
}

#endif
speechx/speechx/decoder/common.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/basic_types.h"
struct DecoderResult {
...
...
speechx/speechx/decoder/ctc_beam_search_decoder.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h"
...
...
@@ -9,25 +23,23 @@ namespace ppspeech {
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;

CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
    : opts_(opts),
      init_ext_scorer_(nullptr),
      blank_id(-1),
      space_id(-1),
      num_frame_decoded_(0),
      root(nullptr) {
    LOG(INFO) << "dict path: " << opts_.dict_file;
    if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
        LOG(INFO) << "load the dict failed";
    }
    LOG(INFO) << "read the vocabulary success, dict size: "
              << vocabulary_.size();

    LOG(INFO) << "language model path: " << opts_.lm_path;
    init_ext_scorer_ = std::make_shared<Scorer>(
        opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
}

void CTCBeamSearch::Reset() {
...
...
@@ -36,7 +48,6 @@ void CTCBeamSearch::Reset() {
}

void CTCBeamSearch::InitDecoder() {
    blank_id = 0;
    auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
...
...
@@ -51,10 +62,11 @@ void CTCBeamSearch::InitDecoder() {
    root = std::make_shared<PathTrie>();
    root->score = root->log_prob_b_prev = 0.0;
    prefixes.push_back(root.get());
    if (init_ext_scorer_ != nullptr &&
        !init_ext_scorer_->is_character_based()) {
        auto fst_dict =
            static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
        fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
        root->set_dictionary(dict_ptr);
        auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
...
...
@@ -62,23 +74,24 @@ void CTCBeamSearch::InitDecoder() {
    }
}

void CTCBeamSearch::Decode(
    std::shared_ptr<kaldi::DecodableInterface> decodable) {
    return;
}

int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }

// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(
    const std::shared_ptr<kaldi::DecodableInterface>& decodable,
    int max_frames) {
    while (max_frames > 0) {
        vector<vector<BaseFloat>> likelihood;
        if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
            break;
        }
        likelihood.push_back(
            decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
        AdvanceDecoding(likelihood);
        max_frames--;
    }
...
...
@@ -93,12 +106,13 @@ void CTCBeamSearch::ResetPrefixes() {
    }
}

int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
                                     vector<string>& nbest_words) {
    kaldi::Timer timer;
    timer.Reset();
    AdvanceDecoding(probs);
    LOG(INFO) << "ctc decoding elapsed time(s) "
              << static_cast<float>(timer.Elapsed()) / 1000.0f;
    return 0;
}
...
...
@@ -124,12 +138,13 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
    double cutoff_prob = opts_.cutoff_prob;
    size_t cutoff_top_n = opts_.cutoff_top_n;
    vector<vector<double>> probs_seq(probs.size(),
                                     vector<double>(probs[0].size(), 0));
    int row = probs.size();
    int col = probs[0].size();
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < col; j++) {
            probs_seq[i][j] = static_cast<double>(probs[i][j]);
        }
    }
...
...
@@ -141,7 +156,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
    bool full_beam = false;
    if (init_ext_scorer_ != nullptr) {
        size_t num_prefixes = std::min(prefixes.size(), beam_size);
        std::sort(prefixes.begin(),
                  prefixes.begin() + num_prefixes,
                  prefix_compare);
        if (num_prefixes == 0) {
...
...
@@ -181,7 +197,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
    }  // for probs_seq
}

int32 CTCBeamSearch::SearchOneChar(
    const bool& full_beam,
    const std::pair<size_t, BaseFloat>& log_prob_idx,
    const BaseFloat& min_cutoff) {
    size_t beam_size = opts_.beam_size;
...
...
@@ -196,10 +213,8 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
        }
        if (c == blank_id) {
            prefix->log_prob_b_cur = log_sum_exp(
                prefix->log_prob_b_cur, log_prob_c + prefix->score);
            continue;
        }
...
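The blank-probability update above relies on `log_sum_exp` to add probabilities that are stored in log space. A stable sketch of that helper (assumed semantics; the real one lives in the ctc_decoders utilities):

```cpp
#include <algorithm>
#include <cmath>

// log(exp(a) + exp(b)) computed without overflow by factoring out the
// max term. Sketch of the helper used by the prefix updates above.
inline float LogSumExp(float a, float b) {
    float m = std::max(a, b);
    if (std::isinf(m) && m < 0) return m;  // both terms are log(0)
    return m + std::log(std::exp(a - m) + std::exp(b - m));
}
```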
...
@@ -207,9 +222,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
    if (c == prefix->character) {
        // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
        prefix->log_prob_nb_cur = log_sum_exp(
            prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
    }

    // get new prefix
...
...
@@ -228,7 +241,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
    // language model scoring
    if (init_ext_scorer_ != nullptr &&
        (c == space_id || init_ext_scorer_->is_character_based())) {
        PathTrie* prefix_to_score = nullptr;
        // skip scoring the space
        if (init_ext_scorer_->is_character_based()) {
            prefix_to_score = prefix_new;
...
...
@@ -247,8 +260,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
        }
        // p_{nb}(l;x_{1:t})
        prefix_new->log_prob_nb_cur =
            log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
        }
    }
    // end of loop over prefix
    return 0;
...
...
@@ -258,9 +270,7 @@ void CTCBeamSearch::CalculateApproxScore() {
    size_t beam_size = opts_.beam_size;
    size_t num_prefixes = std::min(prefixes.size(), beam_size);
    std::sort(
        prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);

    // compute approximate ctc score as the return score, without affecting the
    // return order of decoding result. To delete when decoder gets stable.
...
...
@@ -274,8 +284,8 @@ void CTCBeamSearch::CalculateApproxScore() {
        // remove word insert
        approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
        // remove language model weight:
        approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
                      init_ext_scorer_->alpha;
    }
    prefixes[i]->approx_ctc = approx_ctc;
}
...
...
@@ -283,13 +293,15 @@ void CTCBeamSearch::CalculateApproxScore() {
void CTCBeamSearch::LMRescore() {
    size_t beam_size = opts_.beam_size;
    if (init_ext_scorer_ != nullptr &&
        !init_ext_scorer_->is_character_based()) {
        for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
            auto prefix = prefixes[i];
            if (!prefix->is_empty() && prefix->character != space_id) {
                float score = 0.0;
                vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
                score = init_ext_scorer_->get_log_cond_prob(ngram) *
                        init_ext_scorer_->alpha;
                score += init_ext_scorer_->beta;
                prefix->score += score;
            }
...
...
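The rescoring above is standard shallow fusion: each scored n-gram contributes `alpha * log P_lm(ngram) + beta` (a word-insertion bonus) to the prefix score. The arithmetic in isolation, with a stand-in function rather than the `Scorer` API:

```cpp
// Shallow-fusion update as in LMRescore: alpha weights the language-model
// log probability, beta is a per-word insertion bonus. Stand-in function,
// not the Scorer interface.
float FuseLmScore(float prefix_score, float lm_log_prob, float alpha,
                  float beta) {
    return prefix_score + alpha * lm_log_prob + beta;
}
```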
speechx/speechx/decoder/ctc_beam_search_decoder.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h"
#include "util/parse-options.h"
...
...
@@ -17,26 +31,27 @@ struct CTCBeamSearchOptions {
    int beam_size;
    int cutoff_top_n;
    int num_proc_bsearch;

    CTCBeamSearchOptions()
        : dict_file("./model/words.txt"),
          lm_path("./model/lm.arpa"),
          alpha(1.9f),
          beta(5.0),
          beam_size(300),
          cutoff_prob(0.99f),
          cutoff_top_n(40),
          num_proc_bsearch(0) {}

    void Register(kaldi::OptionsItf* opts) {
        opts->Register("dict", &dict_file, "dict file");
        opts->Register("lm-path", &lm_path, "language model file");
        opts->Register("alpha", &alpha, "alpha");
        opts->Register("beta", &beta, "beta");
        opts->Register(
            "beam-size", &beam_size, "beam size for beam search method");
        opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
        opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
        opts->Register(
            "num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
    }
};
...
...
@@ -50,11 +65,13 @@ class CTCBeamSearch {
    std::vector<std::pair<double, std::string>> GetNBestPath();
    std::string GetFinalBestPath();
    int NumFrameDecoded();
    int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
                          std::vector<std::string>& nbest_words);
    void AdvanceDecode(
        const std::shared_ptr<kaldi::DecodableInterface>& decodable,
        int max_frames);
    void Reset();

  private:
    void ResetPrefixes();
    int32 SearchOneChar(const bool& full_beam,
...
...
@@ -66,7 +83,7 @@ class CTCBeamSearch {
    CTCBeamSearchOptions opts_;
    std::shared_ptr<Scorer> init_ext_scorer_;
    // todo separate later
    // std::vector<DecodeResult> decoder_results_;
    std::vector<std::string> vocabulary_;  // todo remove later
    size_t blank_id;
    int space_id;
...
speechx/speechx/frontend/fbank.h
...
...
@@ -24,7 +24,8 @@ class FbankExtractor : FeatureExtractorInterface {
  public:
    explicit FbankExtractor(
        const FbankOptions& opts,
        std::shared_ptr<FeatureExtractorInterface> pre_extractor);
    virtual void AcceptWaveform(
        const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
    virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
    virtual size_t Dim() const = 0;
...
speechx/speechx/frontend/feature_extractor_controller.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
speechx/speechx/frontend/feature_extractor_controller_impl.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
speechx/speechx/frontend/feature_extractor_interface.h
...
...
@@ -21,7 +21,8 @@ namespace ppspeech {
class FeatureExtractorInterface {
  public:
    virtual void AcceptWaveform(
        const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
    virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
    virtual size_t Dim() const = 0;
};
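This small interface is what lets the frontend stages compose: `LinearSpectrogram` and `CMVN` each take a base extractor and transform what it produces. A toy sketch of that decorator-style chaining, using `std::vector<float>` in place of kaldi vectors (all names are illustrative, not the real ppspeech classes):

```cpp
#include <memory>
#include <utility>
#include <vector>

// Minimal stand-in for FeatureExtractorInterface.
struct ToyExtractor {
    virtual void AcceptWaveform(const std::vector<float>& input) = 0;
    virtual void Read(std::vector<float>* feat) = 0;
    virtual ~ToyExtractor() {}
};

// Source stage: buffers the waveform it is given.
struct ToySource : ToyExtractor {
    std::vector<float> buf_;
    void AcceptWaveform(const std::vector<float>& input) override {
        buf_ = input;
    }
    void Read(std::vector<float>* feat) override { *feat = buf_; }
};

// Wrapping stage: forwards input to its base extractor, transforms on
// Read() -- the shape a normalizer -> spectrogram chain takes.
struct ToyScaler : ToyExtractor {
    std::unique_ptr<ToyExtractor> base_;
    explicit ToyScaler(std::unique_ptr<ToyExtractor> base)
        : base_(std::move(base)) {}
    void AcceptWaveform(const std::vector<float>& input) override {
        base_->AcceptWaveform(input);
    }
    void Read(std::vector<float>* feat) override {
        base_->Read(feat);
        for (float& f : *feat) f *= 2.0f;  // stand-in transform
    }
};
```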
...
...
speechx/speechx/frontend/linear_spectrogram.cc
...
...
@@ -25,7 +25,7 @@ using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector;

// todo remove later
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
                           vector<BaseFloat>* output) {
    if (input.Dim() == 0) return;
...
...
speechx/speechx/frontend/linear_spectrogram.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/feat/feature-window.h"

namespace ppspeech {

struct LinearSpectrogramOptions {
    kaldi::FrameExtractionOptions frame_opts;
    LinearSpectrogramOptions() : frame_opts() {}

    void Register(kaldi::OptionsItf* opts) { frame_opts.Register(opts); }
};

class LinearSpectrogram : public FeatureExtractorInterface {
  public:
    explicit LinearSpectrogram(
        const LinearSpectrogramOptions& opts,
        std::unique_ptr<FeatureExtractorInterface> base_extractor);
    virtual void AcceptWaveform(
        const kaldi::VectorBase<kaldi::BaseFloat>& input);
    virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
    virtual size_t Dim() const { return dim_; }
    void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);
...
...
speechx/speechx/frontend/normalizer.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/normalizer.h"
#include "kaldi/feat/cmvn.h"
...
...
@@ -16,7 +30,8 @@ DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
    dim_ = 0;
}

void DecibelNormalizer::AcceptWaveform(
    const kaldi::VectorBase<BaseFloat>& input) {
    dim_ = input.Dim();
    waveform_.Resize(input.Dim());
    waveform_.CopyFromVec(input);
...
...
@@ -27,7 +42,7 @@ void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
    Compute(waveform_, feat);
}

// todo remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
                          vector<BaseFloat>* output) {
    if (input.Dim() == 0) return;
...
...
@@ -61,7 +76,7 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
    }
    // square
    for (auto& d : samples) {
        if (opts_.convert_int_float) {
            d = d * wave_float_normlization;
        }
...
...
@@ -74,14 +89,15 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
    gain = opts_.target_db - rms_db;
    if (gain > opts_.max_gain_db) {
        LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db
                   << "dB, because the probable gain exceeds opts_.max_gain_db"
                   << opts_.max_gain_db << "dB.";
        return false;
    }
    // Note that this is an in-place transformation.
    for (auto& item : samples) {
        // python: item *= 10.0 ** (gain / 20.0)
        item *= std::pow(10.0, gain / 20.0);
    }
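The gain math above is: measure the RMS level in dB, then multiply every sample by `10^((target_db - rms_db) / 20)` so the segment lands on the target level. A self-contained sketch of that computation, assuming float samples, mirroring what `DecibelNormalizer::Compute` does:

```cpp
#include <cmath>
#include <vector>

// Scale samples so their RMS level equals target_db; refuse if the
// required gain exceeds max_gain_db, as the guard above does.
bool NormalizeToDb(std::vector<float>* samples, double target_db,
                   double max_gain_db) {
    double sq = 0.0;
    for (float s : *samples) sq += static_cast<double>(s) * s;
    double rms_db = 10.0 * std::log10(sq / samples->size());
    double gain = target_db - rms_db;
    if (gain > max_gain_db) return false;
    double factor = std::pow(10.0, gain / 20.0);  // dB -> linear amplitude
    for (float& s : *samples) s *= factor;
    return true;
}
```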
...
...
@@ -100,21 +116,20 @@ void CMVN::AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) {
    return;
}

void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) { return; }

// feats contain num_frames feature.
void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
    KALDI_ASSERT(feats != NULL);
    int32 dim = stats_.NumCols() - 1;
    if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
        feats->Dim() % dim != 0) {
        KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
                  << stats_.NumCols() << ", feats " << feats->Dim() << 'x';
    }
    if (stats_.NumRows() == 1 && var_norm) {
        KALDI_ERR << "You requested variance normalization but no variance "
                  << "stats_ are supplied.";
    }
...
...
@@ -122,17 +137,20 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
    // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
    // computing an offset and representing it as stats_, we use a count of one.
    if (count < 1.0)
        KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
                     "normalization: count = " << count;
    if (!var_norm) {
        Vector<BaseFloat> offset(feats->Dim());
        SubVector<double> mean_stats(stats_.RowData(0), dim);
        Vector<double> mean_stats_apply(feats->Dim());
        // fill the data of mean_stats into mean_stats_apply, whose dim equals
        // the dim of the feature; the dim of feats = dim * num_frames.
        for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
            SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx,
                                        dim);
            stats_tmp.CopyFromVec(mean_stats);
        }
        offset.AddVec(-1.0 / count, mean_stats_apply);
...
...
@@ -144,18 +162,18 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
    kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
    for (int32 d = 0; d < dim; d++) {
        double mean, offset, scale;
        mean = stats_(0, d) / count;
        double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
        if (var < floor) {
            KALDI_WARN << "Flooring cepstral variance from " << var << " to "
                       << floor;
            var = floor;
        }
        scale = 1.0 / sqrt(var);
        if (scale != scale || 1 / scale == 0.0)
            KALDI_ERR << "NaN or infinity in cepstral mean/variance "
                         "computation";
        offset = -(mean * scale);
        for (int32 d_skip = d; d_skip < feats->Dim();) {
            norm(0, d_skip) = offset;
            norm(1, d_skip) = scale;
...
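The loop above turns accumulated stats (row 0 of `stats_` holds per-dimension sums, row 1 sums of squares) into an affine transform `feat * scale + offset` with zero mean and unit variance. The per-dimension arithmetic, sketched in isolation:

```cpp
#include <cmath>

// From sum, sum-of-squares, and count, derive the CMVN scale and offset
// exactly as ApplyCMVN does above, including the variance floor.
void CmvnScaleOffset(double sum, double sumsq, double count, double* scale,
                     double* offset) {
    double mean = sum / count;
    double var = sumsq / count - mean * mean;
    if (var < 1.0e-20) var = 1.0e-20;  // flooring, as in the code above
    *scale = 1.0 / std::sqrt(var);
    *offset = -(mean * *scale);
}
```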
speechx/speechx/frontend/normalizer.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"

namespace ppspeech {
...
...
@@ -12,26 +26,30 @@ struct DecibelNormalizerOptions {
    float target_db;
    float max_gain_db;
    bool convert_int_float;

    DecibelNormalizerOptions()
        : target_db(-20),
          max_gain_db(300.0),
          convert_int_float(false) {}

    void Register(kaldi::OptionsItf* opts) {
        opts->Register(
            "target-db", &target_db, "target db for db normalization");
        opts->Register(
            "max-gain-db", &max_gain_db, "max gain db for db normalization");
        opts->Register("convert-int-float",
                       &convert_int_float,
                       "if convert int samples to float");
    }
};

class DecibelNormalizer : public FeatureExtractorInterface {
  public:
    explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
    virtual void AcceptWaveform(
        const kaldi::VectorBase<kaldi::BaseFloat>& input);
    virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
    virtual size_t Dim() const { return dim_; }
    bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
                 kaldi::VectorBase<kaldi::BaseFloat>* feat) const;

  private:
    DecibelNormalizerOptions opts_;
    size_t dim_;
...
...
@@ -43,7 +61,8 @@ class DecibelNormalizer : public FeatureExtractorInterface {
class CMVN : public FeatureExtractorInterface {
  public:
    explicit CMVN(std::string cmvn_file);
    virtual void AcceptWaveform(
        const kaldi::VectorBase<kaldi::BaseFloat>& input);
    virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
    virtual size_t Dim() const { return stats_.NumCols() - 1; }
    bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
...
...
@@ -51,6 +70,7 @@ class CMVN : public FeatureExtractorInterface {
    // for test
    void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats);
    void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats);

  private:
    kaldi::Matrix<double> stats_;
    std::shared_ptr<FeatureExtractorInterface> base_extractor_;
...
...
speechx/speechx/frontend/window.h
...
...
@@ -13,4 +13,3 @@
// limitations under the License.
// extract the window of kaldi feat.
speechx/speechx/nnet/decodable-itf.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
...
...
@@ -42,8 +56,10 @@ namespace kaldi {
   For online decoding, where the features are coming in in real time, it is
   important to understand the IsLastFrame() and NumFramesReady() functions.
   There are two ways these are used: the old online-decoding code, in
   ../online/, and the new online-decoding code, in ../online2/. In the old
   online-decoding code, the decoder would do:
   \code{.cc}
   for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
...
...
@@ -52,13 +68,16 @@ namespace kaldi {
   \endcode
   and the call to IsLastFrame would block if the features had not arrived
   yet. The decodable object would have to know when to terminate the
   decoding. This online-decoding mode is still supported, it is what happens
   when you call, for example, LatticeFasterDecoder::Decode().

   We realized that this "blocking" mode of decoding is not very convenient
   because it forces the program to be multi-threaded and makes it complex to
   control endpointing. In the "new" decoding code, you don't call (for
   example) LatticeFasterDecoder::Decode(), you call
   LatticeFasterDecoder::InitDecoding(), and then each time you get more
   features, you provide them to the decodable object, and you call
   LatticeFasterDecoder::AdvanceDecoding(), which does something like this:
...
...
@@ -68,7 +87,8 @@ namespace kaldi {
   }
   \endcode
   So the decodable object never has IsLastFrame() called. For decoding where
   you are starting with a matrix of features, the NumFramesReady() function
   will always just return the number of frames in the file, and IsLastFrame()
   will return true for the last frame.
...
@@ -82,30 +102,39 @@ namespace kaldi {
class DecodableInterface {
  public:
    /// Returns the log likelihood, which will be negated in the decoder.
    /// The "frame" starts from zero. You should verify that
    /// NumFramesReady() > frame before calling this.
    virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;

    /// Returns true if this is the last frame. Frames are zero-based, so the
    /// first frame is zero. IsLastFrame(-1) will return false, unless the file
    /// is empty (which is a case that I'm not sure all the code will handle,
    /// so be careful). Caution: the behavior of this function in an online
    /// setting is being changed somewhat. In future it may return false in
    /// cases where we haven't yet decided to terminate decoding, but later
    /// true if we decide to terminate decoding. The plan in future is to rely
    /// more on NumFramesReady(), and in future, IsLastFrame() would always
    /// return false in an online-decoding setting, and would only return true
    /// in a decoding-from-matrix setting where we want to allow the last
    /// delta or LDA features to be flushed out for compatibility with the
    /// baseline setup.
    virtual bool IsLastFrame(int32 frame) const = 0;

    /// The call NumFramesReady() will return the number of frames currently
    /// available for this decodable object. This is for use in setups where
    /// you don't want the decoder to block while waiting for input. This is
    /// newly added as of Jan 2014, and I hope, going forward, to rely on this
    /// mechanism more than IsLastFrame to know when to stop decoding.
    virtual int32 NumFramesReady() const {
        KALDI_ERR << "NumFramesReady() not implemented for this decodable "
                     "type.";
        return -1;
    }
...
...
speechx/speechx/nnet/decodable.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/decodable.h"
namespace ppspeech {
...
...
@@ -5,18 +19,14 @@ namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Matrix;

Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
    : frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {}

void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
    frames_ready_ += likelihood.NumRows();
}

// Decodable::Init(DecodableConfig config) {
//}

bool Decodable::IsLastFrame(int32 frame) const {
...
...
@@ -24,18 +34,14 @@ bool Decodable::IsLastFrame(int32 frame) const {
    return finished_ && (frame == frames_ready_ - 1);
}

int32 Decodable::NumIndices() const { return 0; }

BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; }

void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
    nnet_->FeedForward(features, &nnet_cache_);
    frames_ready_ += nnet_cache_.NumRows();
    return;
}

std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {
...
...
speechx/speechx/nnet/decodable.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h"
#include "nnet/nnet_interface.h"

namespace ppspeech {
...
...
@@ -11,15 +25,18 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
  public:
    explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
    // void Init(DecodableOpts config);
    virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
    virtual bool IsLastFrame(int32 frame) const;
    virtual int32 NumIndices() const;
    virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
    // only for test, todo remove later
    void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
    // remove later
    void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature);
    void Reset();
    void InputFinished() { finished_ = true; }

  private:
    std::shared_ptr<FeatureExtractorInterface> frontend_;
    std::shared_ptr<NnetInterface> nnet_;
...
...
speechx/speechx/nnet/nnet_interface.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
...
...
@@ -10,10 +24,9 @@ namespace ppspeech {
class NnetInterface {
  public:
    virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
                             kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0;
    virtual void Reset() = 0;
    virtual ~NnetInterface() {}
};

}  // namespace ppspeech
speechx/speechx/nnet/paddle_nnet.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/paddle_nnet.h"
#include "absl/strings/str_split.h"
...
...
@@ -21,18 +35,18 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
        std::vector<std::string> tmp_shape;
        tmp_shape = absl::StrSplit(cache_shapes[i], "-");
        std::vector<int> cur_shape;
        std::transform(tmp_shape.begin(),
                       tmp_shape.end(),
                       std::back_inserter(cur_shape),
                       [](const std::string& s) { return atoi(s.c_str()); });
        cache_names_idx_[cache_names[i]] = i;
        std::shared_ptr<Tensor<BaseFloat>> cache_eout =
            std::make_shared<Tensor<BaseFloat>>(cur_shape);
        cache_encouts_.push_back(cache_eout);
    }
}
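`InitCacheEncouts` parses "-"-delimited shape strings into integer dims via `absl::StrSplit` plus `std::transform`. The same parsing with only the standard library, as a sketch (the `"5-1-1024"` example string is hypothetical, not taken from the config):

```cpp
#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>

// Split a '-'-delimited shape string into ints, matching what the
// absl::StrSplit + std::transform pair above produces.
std::vector<int> ParseShape(const std::string& s) {
    std::vector<int> shape;
    std::stringstream ss(s);
    std::string piece;
    while (std::getline(ss, piece, '-')) {
        shape.push_back(std::atoi(piece.c_str()));
    }
    return shape;
}
```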
PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
    paddle_infer::Config config;
    config.SetModel(opts.model_path, opts.params_path);
    if (opts.use_gpu) {
...
...
@@ -45,7 +59,8 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
if
(
opts
.
enable_profile
)
{
config
.
EnableProfile
();
}
pool
.
reset
(
new
paddle_infer
::
services
::
PredictorPool
(
config
,
opts
.
thread_num
));
pool
.
reset
(
new
paddle_infer
::
services
::
PredictorPool
(
config
,
opts
.
thread_num
));
if
(
pool
==
nullptr
)
{
LOG
(
ERROR
)
<<
"create the predictor pool failed"
;
}
...
...
@@ -68,16 +83,14 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
std
::
vector
<
std
::
string
>
model_output_names
=
predictor
->
GetOutputNames
();
assert
(
output_names_vec
.
size
()
==
model_output_names
.
size
());
for
(
size_t
i
=
0
;
i
<
output_names_vec
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
output_names_vec
.
size
();
i
++
)
{
assert
(
output_names_vec
[
i
]
==
model_output_names
[
i
]);
}
ReleasePredictor
(
predictor
);
InitCacheEncouts
(
opts
);
}
void
PaddleNnet
::
Reset
()
{
InitCacheEncouts
(
opts_
);
}
void
PaddleNnet
::
Reset
()
{
InitCacheEncouts
(
opts_
);
}
paddle_infer
::
Predictor
*
PaddleNnet
::
GetPredictor
()
{
LOG
(
INFO
)
<<
"attempt to get a new predictor instance "
<<
std
::
endl
;
...
...
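GetPredictor/ReleasePredictor check predictors out of the PredictorPool and return them, so one PaddleNnet can serve multiple threads. A toy analog of that acquire/release discipline with a mutex-guarded slot list (SlotPool is illustrative only; paddle_infer's real pool manages Predictor instances, not integer slots):

```cpp
#include <cassert>
#include <mutex>
#include <vector>

// Fixed set of reusable slots guarded by a mutex: Acquire() hands out a free
// slot id (or -1 when all are busy), Release() returns it to the pool.
class SlotPool {
  public:
    explicit SlotPool(int n) : in_use_(n, false) {}

    int Acquire() {
        std::lock_guard<std::mutex> lk(mu_);
        for (size_t i = 0; i < in_use_.size(); ++i) {
            if (!in_use_[i]) {
                in_use_[i] = true;
                return static_cast<int>(i);
            }
        }
        return -1;  // pool exhausted
    }

    void Release(int id) {
        std::lock_guard<std::mutex> lk(mu_);
        in_use_[id] = false;
    }

  private:
    std::mutex mu_;
    std::vector<bool> in_use_;
};
```

Every FeedForward call below follows this pattern: acquire a predictor, run inference, release it before returning.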
@@ -130,13 +143,14 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
    return cache_encouts_[iter->second];
}

void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
                             Matrix<BaseFloat>* inferences) {
    paddle_infer::Predictor* predictor = GetPredictor();
    int row = features.NumRows();
    int col = features.NumCols();
    std::vector<BaseFloat> feed_feature;
    // todo refactor feed feature: SmileGoat
    feed_feature.reserve(row * col);
    for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
        for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
            feed_feature.push_back(features(row_idx, col_idx));
...
...
@@ -146,22 +160,26 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
std
::
vector
<
std
::
string
>
output_names
=
predictor
->
GetOutputNames
();
LOG
(
INFO
)
<<
"feat info: row="
<<
row
<<
", col= "
<<
col
;
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_tensor
=
predictor
->
GetInputHandle
(
input_names
[
0
]);
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_tensor
=
predictor
->
GetInputHandle
(
input_names
[
0
]);
std
::
vector
<
int
>
INPUT_SHAPE
=
{
1
,
row
,
col
};
input_tensor
->
Reshape
(
INPUT_SHAPE
);
input_tensor
->
CopyFromCpu
(
feed_feature
.
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_len
=
predictor
->
GetInputHandle
(
input_names
[
1
]);
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_len
=
predictor
->
GetInputHandle
(
input_names
[
1
]);
std
::
vector
<
int
>
input_len_size
=
{
1
};
input_len
->
Reshape
(
input_len_size
);
std
::
vector
<
int64_t
>
audio_len
;
audio_len
.
push_back
(
row
);
input_len
->
CopyFromCpu
(
audio_len
.
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
h_box
=
predictor
->
GetInputHandle
(
input_names
[
2
]);
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
h_box
=
predictor
->
GetInputHandle
(
input_names
[
2
]);
shared_ptr
<
Tensor
<
BaseFloat
>>
h_cache
=
GetCacheEncoder
(
input_names
[
2
]);
h_box
->
Reshape
(
h_cache
->
get_shape
());
h_box
->
CopyFromCpu
(
h_cache
->
get_data
().
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
c_box
=
predictor
->
GetInputHandle
(
input_names
[
3
]);
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
c_box
=
predictor
->
GetInputHandle
(
input_names
[
3
]);
shared_ptr
<
Tensor
<
float
>>
c_cache
=
GetCacheEncoder
(
input_names
[
3
]);
c_box
->
Reshape
(
c_cache
->
get_shape
());
c_box
->
CopyFromCpu
(
c_cache
->
get_data
().
data
());
...
...
@@ -172,10 +190,12 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
}
LOG
(
INFO
)
<<
"get the model success"
;
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
h_out
=
predictor
->
GetOutputHandle
(
output_names
[
2
]);
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
h_out
=
predictor
->
GetOutputHandle
(
output_names
[
2
]);
assert
(
h_cache
->
get_shape
()
==
h_out
->
shape
());
h_out
->
CopyToCpu
(
h_cache
->
get_data
().
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
c_out
=
predictor
->
GetOutputHandle
(
output_names
[
3
]);
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
c_out
=
predictor
->
GetOutputHandle
(
output_names
[
3
]);
assert
(
c_cache
->
get_shape
()
==
c_out
->
shape
());
c_out
->
CopyToCpu
(
c_cache
->
get_data
().
data
());
...
...
@@ -187,13 +207,14 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
col
=
output_shape
[
2
];
vector
<
float
>
inferences_result
;
inferences
->
Resize
(
row
,
col
);
inferences_result
.
resize
(
row
*
col
);
inferences_result
.
resize
(
row
*
col
);
output_tensor
->
CopyToCpu
(
inferences_result
.
data
());
ReleasePredictor
(
predictor
);
for
(
int
row_idx
=
0
;
row_idx
<
row
;
++
row_idx
)
{
for
(
int
col_idx
=
0
;
col_idx
<
col
;
++
col_idx
)
{
(
*
inferences
)(
row_idx
,
col_idx
)
=
inferences_result
[
col
*
row_idx
+
col_idx
];
(
*
inferences
)(
row_idx
,
col_idx
)
=
inferences_result
[
col
*
row_idx
+
col_idx
];
}
}
}
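Both the feed-in loop and the copy-out loop above rely on row-major addressing: element (r, c) of a row x col matrix lives at flat index col * r + c. The mapping in isolation (FlatIndex is a hypothetical helper for illustration):

```cpp
#include <cassert>

// Row-major flattening used by FeedForward: the matrix cell (row_idx, col_idx)
// of a matrix with `col` columns maps to this offset in the flat buffer.
int FlatIndex(int col, int row_idx, int col_idx) {
    return col * row_idx + col_idx;
}
```

Getting this stride wrong (e.g. using `row` instead of `col`) silently transposes or shears the output, which is why the loops index with `col` as the stride.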
...
...
speechx/speechx/nnet/paddle_nnet.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/nnet_interface.h"
#include "paddle_inference_api.h"
...
...
@@ -24,19 +38,27 @@ struct ModelOptions {
std
::
string
cache_shape
;
bool
enable_fc_padding
;
bool
enable_profile
;
ModelOptions
()
:
model_path
(
"../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"
),
params_path
(
"../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"
),
ModelOptions
()
:
model_path
(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdmodel"
),
params_path
(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdiparams"
),
thread_num
(
2
),
use_gpu
(
false
),
input_names
(
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"
),
output_names
(
"save_infer_model/scale_0.tmp_1,save_infer_model/scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/scale_3.tmp_1"
),
input_names
(
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_"
"box"
),
output_names
(
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1"
),
cache_names
(
"chunk_state_h_box,chunk_state_c_box"
),
cache_shape
(
"3-1-1024,3-1-1024"
),
switch_ir_optim
(
false
),
enable_fc_padding
(
false
),
enable_profile
(
false
)
{
}
enable_profile
(
false
)
{}
void
Register
(
kaldi
::
OptionsItf
*
opts
)
{
opts
->
Register
(
"model-path"
,
&
model_path
,
"model file path"
);
...
...
@@ -47,37 +69,37 @@ struct ModelOptions {
opts
->
Register
(
"output-names"
,
&
output_names
,
"paddle output names"
);
opts
->
Register
(
"cache-names"
,
&
cache_names
,
"cache names"
);
opts
->
Register
(
"cache-shape"
,
&
cache_shape
,
"cache shape"
);
opts
->
Register
(
"switch-ir-optiom"
,
&
switch_ir_optim
,
"paddle SwitchIrOptim option"
);
opts
->
Register
(
"enable-fc-padding"
,
&
enable_fc_padding
,
"paddle EnableFCPadding option"
);
opts
->
Register
(
"enable-profile"
,
&
enable_profile
,
"paddle EnableProfile option"
);
opts
->
Register
(
"switch-ir-optiom"
,
&
switch_ir_optim
,
"paddle SwitchIrOptim option"
);
opts
->
Register
(
"enable-fc-padding"
,
&
enable_fc_padding
,
"paddle EnableFCPadding option"
);
opts
->
Register
(
"enable-profile"
,
&
enable_profile
,
"paddle EnableProfile option"
);
}
};
template <typename T>
class Tensor {
  public:
    Tensor() {}
    Tensor(const std::vector<int>& shape) : _shape(shape) {
        int data_size = std::accumulate(
            _shape.begin(), _shape.end(), 1, std::multiplies<int>());
        LOG(INFO) << "data size: " << data_size;
        _data.resize(data_size, 0);
    }
    void reshape(const std::vector<int>& shape) {
        _shape = shape;
        int data_size = std::accumulate(
            _shape.begin(), _shape.end(), 1, std::multiplies<int>());
        _data.resize(data_size, 0);
    }
    const std::vector<int>& get_shape() const { return _shape; }
    std::vector<T>& get_data() { return _data; }

  private:
    std::vector<int> _shape;
    std::vector<T> _data;
};
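The Tensor constructor and reshape both size the backing buffer as the product of the shape dimensions, via std::accumulate with std::multiplies. The same computation in isolation (ShapeNumel is a hypothetical name for illustration):

```cpp
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// Number of elements implied by a shape vector: the product of its
// dimensions, with 1 as the multiplicative identity for the empty shape.
int ShapeNumel(const std::vector<int>& shape) {
    return std::accumulate(
        shape.begin(), shape.end(), 1, std::multiplies<int>());
}
```

For the default cache shape "3-1-1024" this yields 3 * 1 * 1024 = 3072 floats per cache tensor.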
...
...
@@ -88,7 +110,8 @@ class PaddleNnet : public NnetInterface {
virtual
void
FeedForward
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
features
,
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>*
inferences
);
virtual
void
Reset
();
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>
GetCacheEncoder
(
const
std
::
string
&
name
);
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>
GetCacheEncoder
(
const
std
::
string
&
name
);
void
InitCacheEncouts
(
const
ModelOptions
&
opts
);
private:
...
...
speechx/speechx/utils/file_utils.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "utils/file_utils.h"
namespace ppspeech {
...
...
@@ -17,5 +31,4 @@ bool ReadFileToVector(const std::string& filename,
    return true;
}

}  // namespace ppspeech
\ No newline at end of file
speechx/speechx/utils/file_utils.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"

namespace ppspeech {

bool ReadFileToVector(const std::string& filename,
                      std::vector<std::string>* data);

}  // namespace ppspeech
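The declaration above promises to read a file into a vector of lines; the body is elided from this diff. A plausible sketch of the core loop, generalized to any std::istream so it can be exercised without touching the filesystem (ReadStreamToVector is a hypothetical helper, not the repository's actual implementation):

```cpp
#include <cassert>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>

// Read a stream line by line into *data; the file-based version would open
// a std::ifstream on `filename`, check is_open(), and delegate to this.
bool ReadStreamToVector(std::istream& in, std::vector<std::string>* data) {
    std::string line;
    while (std::getline(in, line)) {
        data->push_back(line);
    }
    return true;
}
```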