Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
bb07144c
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bb07144c
编写于
3月 14, 2022
作者:
Y
YangZhou
提交者:
GitHub
3月 14, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1559 from SmileGoat/align_nnet_decoder
align nnet decoder & refactor
上级
bedd2de4
5a0e0b02
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
190 addition
and
170 deletion
+190
-170
speechx/examples/decoder/CMakeLists.txt
speechx/examples/decoder/CMakeLists.txt
+3
-3
speechx/examples/decoder/offline_decoder_main.cc
speechx/examples/decoder/offline_decoder_main.cc
+40
-15
speechx/speechx/decoder/ctc_beam_search_decoder.cc
speechx/speechx/decoder/ctc_beam_search_decoder.cc
+8
-10
speechx/speechx/decoder/ctc_beam_search_decoder.h
speechx/speechx/decoder/ctc_beam_search_decoder.h
+4
-6
speechx/speechx/frontend/raw_audio.h
speechx/speechx/frontend/raw_audio.h
+9
-7
speechx/speechx/nnet/decodable-itf.h
speechx/speechx/nnet/decodable-itf.h
+45
-72
speechx/speechx/nnet/decodable.cc
speechx/speechx/nnet/decodable.cc
+38
-11
speechx/speechx/nnet/decodable.h
speechx/speechx/nnet/decodable.h
+17
-7
speechx/speechx/nnet/nnet_interface.h
speechx/speechx/nnet/nnet_interface.h
+4
-2
speechx/speechx/nnet/paddle_nnet.cc
speechx/speechx/nnet/paddle_nnet.cc
+15
-29
speechx/speechx/nnet/paddle_nnet.h
speechx/speechx/nnet/paddle_nnet.h
+7
-8
未找到文件。
speechx/examples/decoder/CMakeLists.txt
浏览文件 @
bb07144c
cmake_minimum_required
(
VERSION 3.14 FATAL_ERROR
)
cmake_minimum_required
(
VERSION 3.14 FATAL_ERROR
)
add_executable
(
offline-decoder-main
${
CMAKE_CURRENT_SOURCE_DIR
}
/offline-decoder-main.cc
)
add_executable
(
offline_decoder_main
${
CMAKE_CURRENT_SOURCE_DIR
}
/offline_decoder_main.cc
)
target_include_directories
(
offline-decoder-main PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_include_directories
(
offline_decoder_main PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
${
DEPS
}
)
target_link_libraries
(
offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
${
DEPS
}
)
\ No newline at end of file
speechx/examples/decoder/offline
-decoder-
main.cc
→
speechx/examples/decoder/offline
_decoder_
main.cc
浏览文件 @
bb07144c
...
@@ -17,50 +17,75 @@
...
@@ -17,50 +17,75 @@
#include "base/flags.h"
#include "base/flags.h"
#include "base/log.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/raw_audio.h"
#include "kaldi/util/table-types.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
#include "nnet/paddle_nnet.h"
DEFINE_string
(
feature_respecifier
,
""
,
"test nnet prob"
);
DEFINE_string
(
feature_respecifier
,
""
,
"test feature rspecifier"
);
DEFINE_string
(
model_path
,
"avg_1.jit.pdmodel"
,
"paddle nnet model"
);
DEFINE_string
(
param_path
,
"avg_1.jit.pdiparams"
,
"paddle nnet model param"
);
DEFINE_string
(
dict_file
,
"vocab.txt"
,
"vocabulary of lm"
);
DEFINE_string
(
lm_path
,
"lm.klm"
,
"language model"
);
using
kaldi
::
BaseFloat
;
using
kaldi
::
BaseFloat
;
using
kaldi
::
Matrix
;
using
kaldi
::
Matrix
;
using
std
::
vector
;
using
std
::
vector
;
// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//}
int
main
(
int
argc
,
char
*
argv
[])
{
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
InitGoogleLogging
(
argv
[
0
]);
kaldi
::
SequentialBaseFloatMatrixReader
feature_reader
(
kaldi
::
SequentialBaseFloatMatrixReader
feature_reader
(
FLAGS_feature_respecifier
);
FLAGS_feature_respecifier
);
std
::
string
model_graph
=
FLAGS_model_path
;
std
::
string
model_params
=
FLAGS_param_path
;
std
::
string
dict_file
=
FLAGS_dict_file
;
std
::
string
lm_path
=
FLAGS_lm_path
;
// test nnet_output --> decoder result
int32
num_done
=
0
,
num_err
=
0
;
int32
num_done
=
0
,
num_err
=
0
;
ppspeech
::
CTCBeamSearchOptions
opts
;
ppspeech
::
CTCBeamSearchOptions
opts
;
opts
.
dict_file
=
dict_file
;
opts
.
lm_path
=
lm_path
;
ppspeech
::
CTCBeamSearch
decoder
(
opts
);
ppspeech
::
CTCBeamSearch
decoder
(
opts
);
ppspeech
::
ModelOptions
model_opts
;
ppspeech
::
ModelOptions
model_opts
;
model_opts
.
model_path
=
model_graph
;
model_opts
.
params_path
=
model_params
;
std
::
shared_ptr
<
ppspeech
::
PaddleNnet
>
nnet
(
std
::
shared_ptr
<
ppspeech
::
PaddleNnet
>
nnet
(
new
ppspeech
::
PaddleNnet
(
model_opts
));
new
ppspeech
::
PaddleNnet
(
model_opts
));
std
::
shared_ptr
<
ppspeech
::
RawDataCache
>
raw_data
(
new
ppspeech
::
RawDataCache
());
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
(
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
(
new
ppspeech
::
Decodable
(
nnet
));
new
ppspeech
::
Decodable
(
nnet
,
raw_data
));
//
int32 chunk_size = 35;
int32
chunk_size
=
35
;
decoder
.
InitDecoder
();
decoder
.
InitDecoder
();
for
(;
!
feature_reader
.
Done
();
feature_reader
.
Next
())
{
for
(;
!
feature_reader
.
Done
();
feature_reader
.
Next
())
{
string
utt
=
feature_reader
.
Key
();
string
utt
=
feature_reader
.
Key
();
const
kaldi
::
Matrix
<
BaseFloat
>
feature
=
feature_reader
.
Value
();
const
kaldi
::
Matrix
<
BaseFloat
>
feature
=
feature_reader
.
Value
();
decodable
->
FeedFeatures
(
feature
);
raw_data
->
SetDim
(
feature
.
NumCols
());
decoder
.
AdvanceDecode
(
decodable
,
8
);
int32
row_idx
=
0
;
decodable
->
InputFinished
();
int32
num_chunks
=
feature
.
NumRows
()
/
chunk_size
;
for
(
int
chunk_idx
=
0
;
chunk_idx
<
num_chunks
;
++
chunk_idx
)
{
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
feature_chunk
(
chunk_size
*
feature
.
NumCols
());
for
(
int
row_id
=
0
;
row_id
<
chunk_size
;
++
row_id
)
{
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
tmp
(
feature
,
row_idx
);
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
f_chunk_tmp
(
feature_chunk
.
Data
()
+
row_id
*
feature
.
NumCols
(),
feature
.
NumCols
());
f_chunk_tmp
.
CopyFromVec
(
tmp
);
row_idx
++
;
}
raw_data
->
Accept
(
feature_chunk
);
if
(
chunk_idx
==
num_chunks
-
1
)
{
raw_data
->
SetFinished
();
}
decoder
.
AdvanceDecode
(
decodable
);
}
std
::
string
result
;
std
::
string
result
;
result
=
decoder
.
GetFinalBestPath
();
result
=
decoder
.
GetFinalBestPath
();
KALDI_LOG
<<
" the result of "
<<
utt
<<
" is "
<<
result
;
KALDI_LOG
<<
" the result of "
<<
utt
<<
" is "
<<
result
;
...
...
speechx/speechx/decoder/ctc_beam_search_decoder.cc
浏览文件 @
bb07144c
...
@@ -79,21 +79,19 @@ void CTCBeamSearch::Decode(
...
@@ -79,21 +79,19 @@ void CTCBeamSearch::Decode(
return
;
return
;
}
}
int32
CTCBeamSearch
::
NumFrameDecoded
()
{
return
num_frame_decoded_
;
}
int32
CTCBeamSearch
::
NumFrameDecoded
()
{
return
num_frame_decoded_
+
1
;
}
// todo rename, refactor
// todo rename, refactor
void
CTCBeamSearch
::
AdvanceDecode
(
void
CTCBeamSearch
::
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
,
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
)
{
int
max_frames
)
{
while
(
1
)
{
while
(
max_frames
>
0
)
{
vector
<
vector
<
BaseFloat
>>
likelihood
;
vector
<
vector
<
BaseFloat
>>
likelihood
;
if
(
decodable
->
IsLastFrame
(
NumFrameDecoded
()
+
1
))
{
vector
<
BaseFloat
>
frame_prob
;
break
;
bool
flag
=
}
decodable
->
FrameLogLikelihood
(
num_frame_decoded_
,
&
frame_prob
);
likelihood
.
push_back
(
if
(
flag
==
false
)
break
;
decodable
->
FrameLogLikelihood
(
NumFrameDecoded
()
+
1
)
);
likelihood
.
push_back
(
frame_prob
);
AdvanceDecoding
(
likelihood
);
AdvanceDecoding
(
likelihood
);
max_frames
--
;
}
}
}
}
...
...
speechx/speechx/decoder/ctc_beam_search_decoder.h
浏览文件 @
bb07144c
...
@@ -32,8 +32,8 @@ struct CTCBeamSearchOptions {
...
@@ -32,8 +32,8 @@ struct CTCBeamSearchOptions {
int
cutoff_top_n
;
int
cutoff_top_n
;
int
num_proc_bsearch
;
int
num_proc_bsearch
;
CTCBeamSearchOptions
()
CTCBeamSearchOptions
()
:
dict_file
(
"
./model/words
.txt"
),
:
dict_file
(
"
vocab
.txt"
),
lm_path
(
"
./model/lm.arpa
"
),
lm_path
(
"
lm.klm
"
),
alpha
(
1.9
f
),
alpha
(
1.9
f
),
beta
(
5.0
),
beta
(
5.0
),
beam_size
(
300
),
beam_size
(
300
),
...
@@ -68,8 +68,7 @@ class CTCBeamSearch {
...
@@ -68,8 +68,7 @@ class CTCBeamSearch {
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
int
DecodeLikelihoods
(
const
std
::
vector
<
std
::
vector
<
BaseFloat
>>&
probs
,
std
::
vector
<
std
::
string
>&
nbest_words
);
std
::
vector
<
std
::
string
>&
nbest_words
);
void
AdvanceDecode
(
void
AdvanceDecode
(
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
,
const
std
::
shared_ptr
<
kaldi
::
DecodableInterface
>&
decodable
);
int
max_frames
);
void
Reset
();
void
Reset
();
private:
private:
...
@@ -83,7 +82,6 @@ class CTCBeamSearch {
...
@@ -83,7 +82,6 @@ class CTCBeamSearch {
CTCBeamSearchOptions
opts_
;
CTCBeamSearchOptions
opts_
;
std
::
shared_ptr
<
Scorer
>
init_ext_scorer_
;
// todo separate later
std
::
shared_ptr
<
Scorer
>
init_ext_scorer_
;
// todo separate later
// std::vector<DecodeResult> decoder_results_;
std
::
vector
<
std
::
string
>
vocabulary_
;
// todo remove later
std
::
vector
<
std
::
string
>
vocabulary_
;
// todo remove later
size_t
blank_id
;
size_t
blank_id
;
int
space_id
;
int
space_id
;
...
...
speechx/speechx/frontend/raw_audio.h
浏览文件 @
bb07144c
...
@@ -18,6 +18,8 @@
...
@@ -18,6 +18,8 @@
#include "base/common.h"
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/feature_extractor_interface.h"
#pragma once
namespace
ppspeech
{
namespace
ppspeech
{
class
RawAudioCache
:
public
FeatureExtractorInterface
{
class
RawAudioCache
:
public
FeatureExtractorInterface
{
...
@@ -45,13 +47,12 @@ class RawAudioCache : public FeatureExtractorInterface {
...
@@ -45,13 +47,12 @@ class RawAudioCache : public FeatureExtractorInterface {
DISALLOW_COPY_AND_ASSIGN
(
RawAudioCache
);
DISALLOW_COPY_AND_ASSIGN
(
RawAudioCache
);
};
};
// it is a data
source to test
different frontend module.
// it is a data
source for testing
different frontend module.
// it
Accepts waves or feats.
// it
accepts waves or feats.
class
RawDataCache
:
public
FeatureExtractorInterface
{
class
RawDataCache
:
public
FeatureExtractorInterface
{
public:
public:
explicit
RawDataCache
()
{
finished_
=
false
;
}
explicit
RawDataCache
()
{
finished_
=
false
;
}
virtual
void
Accept
(
virtual
void
Accept
(
const
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>&
inputs
)
{
const
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>&
inputs
)
{
data_
=
inputs
;
data_
=
inputs
;
}
}
virtual
bool
Read
(
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
feats
)
{
virtual
bool
Read
(
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
feats
)
{
...
@@ -62,14 +63,15 @@ class RawDataCache: public FeatureExtractorInterface {
...
@@ -62,14 +63,15 @@ class RawDataCache: public FeatureExtractorInterface {
data_
.
Resize
(
0
);
data_
.
Resize
(
0
);
return
true
;
return
true
;
}
}
//the dim is data_ length
virtual
size_t
Dim
()
const
{
return
dim_
;
}
virtual
size_t
Dim
()
const
{
return
data_
.
Dim
();
}
virtual
void
SetFinished
()
{
finished_
=
true
;
}
virtual
void
SetFinished
()
{
finished_
=
true
;
}
virtual
bool
IsFinished
()
const
{
return
finished_
;
}
virtual
bool
IsFinished
()
const
{
return
finished_
;
}
void
SetDim
(
int32
dim
)
{
dim_
=
dim
;
}
private:
private:
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
data_
;
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
data_
;
bool
finished_
;
bool
finished_
;
int32
dim_
;
DISALLOW_COPY_AND_ASSIGN
(
RawDataCache
);
DISALLOW_COPY_AND_ASSIGN
(
RawDataCache
);
};
};
...
...
speechx/speechx/nnet/decodable-itf.h
浏览文件 @
bb07144c
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// itf/decodable-itf.h
// itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
...
@@ -56,10 +42,8 @@ namespace kaldi {
...
@@ -56,10 +42,8 @@ namespace kaldi {
For online decoding, where the features are coming in in real time, it is
For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions.
important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in
There are two ways these are used: the old online-decoding code, in ../online/,
../online/,
and the new online-decoding code, in ../online2/. In the old online-decoding
and the new online-decoding code, in ../online2/. In the old
online-decoding
code, the decoder would do:
code, the decoder would do:
\code{.cc}
\code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
...
@@ -68,16 +52,13 @@ namespace kaldi {
...
@@ -68,16 +52,13 @@ namespace kaldi {
\endcode
\endcode
and the call to IsLastFrame would block if the features had not arrived yet.
and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This
The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call,
online-decoding mode is still supported, it is what happens when you call, for
for
example, LatticeFasterDecoder::Decode().
example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient
We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to
because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for
control endpointing. In the "new" decoding code, you don't call (for example)
example)
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
LatticeFasterDecoder::Decode(), you call
LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable
and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this:
something like this:
...
@@ -87,8 +68,7 @@ namespace kaldi {
...
@@ -87,8 +68,7 @@ namespace kaldi {
}
}
\endcode
\endcode
So the decodable object never has IsLastFrame() called. For decoding where
So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function
you are starting with a matrix of features, the NumFramesReady() function will
will
always just return the number of frames in the file, and IsLastFrame() will
always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame.
return true for the last frame.
...
@@ -102,39 +82,30 @@ namespace kaldi {
...
@@ -102,39 +82,30 @@ namespace kaldi {
class
DecodableInterface
{
class
DecodableInterface
{
public:
public:
/// Returns the log likelihood, which will be negated in the decoder.
/// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() >
/// The "frame" starts from zero. You should verify that NumFramesReady() > frame
/// frame
/// before calling this.
/// before calling this.
virtual
BaseFloat
LogLikelihood
(
int32
frame
,
int32
index
)
=
0
;
virtual
BaseFloat
LogLikelihood
(
int32
frame
,
int32
index
)
=
0
;
/// Returns true if this is the last frame. Frames are zero-based, so the
/// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so
/// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). Caution: the behavior of this function in an online
/// be careful). Caution: the behavior of this function in an online setting
/// setting
/// is being changed somewhat. In future it may return false in cases where
/// is being changed somewhat. In future it may return false in cases where
/// we haven't yet decided to terminate decoding, but later true if we
/// we haven't yet decided to terminate decoding, but later true if we decide
/// decide
/// to terminate decoding. The plan in future is to rely more on
/// to terminate decoding. The plan in future is to rely more on
/// NumFramesReady(), and in future, IsLastFrame() would always return false
/// NumFramesReady(), and in future, IsLastFrame() would always return false
/// in an online-decoding setting, and would only return true in a
/// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or
/// decoding-from-matrix setting where we want to allow the last delta or LDA
/// LDA
/// features to be flushed out for compatibility with the baseline setup.
/// features to be flushed out for compatibility with the baseline setup.
virtual
bool
IsLastFrame
(
int32
frame
)
const
=
0
;
virtual
bool
IsLastFrame
(
int32
frame
)
const
=
0
;
/// The call NumFramesReady() will return the number of frames currently
/// The call NumFramesReady() will return the number of frames currently available
/// available
/// for this decodable object. This is for use in setups where you don't want the
/// for this decodable object. This is for use in setups where you don't
/// decoder to block while waiting for input. This is newly added as of Jan 2014,
/// want the
/// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
/// decoder to block while waiting for input. This is newly added as of Jan
/// 2014,
/// and I hope, going forward, to rely on this mechanism more than
/// IsLastFrame to
/// know when to stop decoding.
/// know when to stop decoding.
virtual
int32
NumFramesReady
()
const
{
virtual
int32
NumFramesReady
()
const
{
KALDI_ERR
KALDI_ERR
<<
"NumFramesReady() not implemented for this decodable type."
;
<<
"NumFramesReady() not implemented for this decodable type."
;
return
-
1
;
return
-
1
;
}
}
...
@@ -143,7 +114,9 @@ class DecodableInterface {
...
@@ -143,7 +114,9 @@ class DecodableInterface {
/// this is for compatibility with OpenFst).
/// this is for compatibility with OpenFst).
virtual
int32
NumIndices
()
const
=
0
;
virtual
int32
NumIndices
()
const
=
0
;
virtual
std
::
vector
<
BaseFloat
>
FrameLogLikelihood
(
int32
frame
)
=
0
;
virtual
bool
FrameLogLikelihood
(
int32
frame
,
std
::
vector
<
kaldi
::
BaseFloat
>*
likelihood
)
=
0
;
virtual
~
DecodableInterface
()
{}
virtual
~
DecodableInterface
()
{}
};
};
...
...
speechx/speechx/nnet/decodable.cc
浏览文件 @
bb07144c
...
@@ -18,9 +18,16 @@ namespace ppspeech {
...
@@ -18,9 +18,16 @@ namespace ppspeech {
using
kaldi
::
BaseFloat
;
using
kaldi
::
BaseFloat
;
using
kaldi
::
Matrix
;
using
kaldi
::
Matrix
;
using
std
::
vector
;
using
kaldi
::
Vector
;
Decodable
::
Decodable
(
const
std
::
shared_ptr
<
NnetInterface
>&
nnet
)
Decodable
::
Decodable
(
const
std
::
shared_ptr
<
NnetInterface
>&
nnet
,
:
frontend_
(
NULL
),
nnet_
(
nnet
),
finished_
(
false
),
frames_ready_
(
0
)
{}
const
std
::
shared_ptr
<
FeatureExtractorInterface
>&
frontend
)
:
frontend_
(
frontend
),
nnet_
(
nnet
),
finished_
(
false
),
frame_offset_
(
0
),
frames_ready_
(
0
)
{}
void
Decodable
::
Acceptlikelihood
(
const
Matrix
<
BaseFloat
>&
likelihood
)
{
void
Decodable
::
Acceptlikelihood
(
const
Matrix
<
BaseFloat
>&
likelihood
)
{
frames_ready_
+=
likelihood
.
NumRows
();
frames_ready_
+=
likelihood
.
NumRows
();
...
@@ -31,26 +38,46 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
...
@@ -31,26 +38,46 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
bool
Decodable
::
IsLastFrame
(
int32
frame
)
const
{
bool
Decodable
::
IsLastFrame
(
int32
frame
)
const
{
CHECK_LE
(
frame
,
frames_ready_
);
CHECK_LE
(
frame
,
frames_ready_
);
return
finished_
&&
(
frame
==
frames_ready_
-
1
);
return
IsInputFinished
()
&&
(
frame
==
frames_ready_
-
1
);
}
}
int32
Decodable
::
NumIndices
()
const
{
return
0
;
}
int32
Decodable
::
NumIndices
()
const
{
return
0
;
}
BaseFloat
Decodable
::
LogLikelihood
(
int32
frame
,
int32
index
)
{
return
0
;
}
BaseFloat
Decodable
::
LogLikelihood
(
int32
frame
,
int32
index
)
{
CHECK_LE
(
index
,
nnet_cache_
.
NumCols
());
return
0
;
}
void
Decodable
::
FeedFeatures
(
const
Matrix
<
kaldi
::
BaseFloat
>&
features
)
{
bool
Decodable
::
EnsureFrameHaveComputed
(
int32
frame
)
{
nnet_
->
FeedForward
(
features
,
&
nnet_cache_
);
if
(
frame
>=
frames_ready_
)
{
return
AdvanceChunk
();
}
return
true
;
}
bool
Decodable
::
AdvanceChunk
()
{
Vector
<
BaseFloat
>
features
;
if
(
frontend_
->
Read
(
&
features
)
==
false
)
{
return
false
;
}
int32
nnet_dim
=
0
;
Vector
<
BaseFloat
>
inferences
;
nnet_
->
FeedForward
(
features
,
frontend_
->
Dim
(),
&
inferences
,
&
nnet_dim
);
nnet_cache_
.
Resize
(
inferences
.
Dim
()
/
nnet_dim
,
nnet_dim
);
nnet_cache_
.
CopyRowsFromVec
(
inferences
);
frame_offset_
=
frames_ready_
;
frames_ready_
+=
nnet_cache_
.
NumRows
();
frames_ready_
+=
nnet_cache_
.
NumRows
();
return
;
return
true
;
}
}
std
::
vector
<
BaseFloat
>
Decodable
::
FrameLogLikelihood
(
int32
frame
)
{
bool
Decodable
::
FrameLogLikelihood
(
int32
frame
,
vector
<
BaseFloat
>*
likelihood
)
{
std
::
vector
<
BaseFloat
>
result
;
std
::
vector
<
BaseFloat
>
result
;
result
.
reserve
(
nnet_cache_
.
NumCols
());
if
(
EnsureFrameHaveComputed
(
frame
)
==
false
)
return
false
;
likelihood
->
resize
(
nnet_cache_
.
NumCols
());
for
(
int32
idx
=
0
;
idx
<
nnet_cache_
.
NumCols
();
++
idx
)
{
for
(
int32
idx
=
0
;
idx
<
nnet_cache_
.
NumCols
();
++
idx
)
{
result
[
idx
]
=
nnet_cache_
(
frame
,
idx
);
(
*
likelihood
)[
idx
]
=
nnet_cache_
(
frame
-
frame_offset_
,
idx
);
}
}
return
result
;
return
true
;
}
}
void
Decodable
::
Reset
()
{
void
Decodable
::
Reset
()
{
...
...
speechx/speechx/nnet/decodable.h
浏览文件 @
bb07144c
...
@@ -24,25 +24,35 @@ struct DecodableOpts;
...
@@ -24,25 +24,35 @@ struct DecodableOpts;
class
Decodable
:
public
kaldi
::
DecodableInterface
{
class
Decodable
:
public
kaldi
::
DecodableInterface
{
public:
public:
explicit
Decodable
(
const
std
::
shared_ptr
<
NnetInterface
>&
nnet
);
explicit
Decodable
(
const
std
::
shared_ptr
<
NnetInterface
>&
nnet
,
const
std
::
shared_ptr
<
FeatureExtractorInterface
>&
frontend
);
// void Init(DecodableOpts config);
// void Init(DecodableOpts config);
virtual
kaldi
::
BaseFloat
LogLikelihood
(
int32
frame
,
int32
index
);
virtual
kaldi
::
BaseFloat
LogLikelihood
(
int32
frame
,
int32
index
);
virtual
bool
IsLastFrame
(
int32
frame
)
const
;
virtual
bool
IsLastFrame
(
int32
frame
)
const
;
virtual
int32
NumIndices
()
const
;
virtual
int32
NumIndices
()
const
;
virtual
std
::
vector
<
BaseFloat
>
FrameLogLikelihood
(
int32
frame
);
virtual
bool
FrameLogLikelihood
(
int32
frame
,
void
Acceptlikelihood
(
std
::
vector
<
kaldi
::
BaseFloat
>*
likelihood
);
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
likelihood
);
// remove later
// for offline test
void
FeedFeatures
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
void
Acceptlikelihood
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
likelihood
);
feature
);
// only for test, todo remove later
void
Reset
();
void
Reset
();
void
InputFinished
()
{
finished_
=
true
;
}
bool
IsInputFinished
()
const
{
return
frontend_
->
IsFinished
();
}
bool
EnsureFrameHaveComputed
(
int32
frame
);
private:
private:
bool
AdvanceChunk
();
std
::
shared_ptr
<
FeatureExtractorInterface
>
frontend_
;
std
::
shared_ptr
<
FeatureExtractorInterface
>
frontend_
;
std
::
shared_ptr
<
NnetInterface
>
nnet_
;
std
::
shared_ptr
<
NnetInterface
>
nnet_
;
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>
nnet_cache_
;
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>
nnet_cache_
;
// std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;
bool
finished_
;
bool
finished_
;
int32
frame_offset_
;
int32
frames_ready_
;
int32
frames_ready_
;
// todo: feature frame mismatch with nnet inference frame
// eg: 35 frame features output 8 frame inferences
// so use subsampled_frame
int32
current_log_post_subsampled_offset_
;
int32
num_chunk_computed_
;
};
};
}
// namespace ppspeech
}
// namespace ppspeech
speechx/speechx/nnet/nnet_interface.h
浏览文件 @
bb07144c
...
@@ -23,8 +23,10 @@ namespace ppspeech {
...
@@ -23,8 +23,10 @@ namespace ppspeech {
class
NnetInterface
{
class
NnetInterface
{
public:
public:
virtual
void
FeedForward
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
features
,
virtual
void
FeedForward
(
const
kaldi
::
Vector
<
kaldi
::
BaseFloat
>&
features
,
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>*
inferences
)
=
0
;
int32
feature_dim
,
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
inferences
,
int32
*
inference_dim
)
=
0
;
virtual
void
Reset
()
=
0
;
virtual
void
Reset
()
=
0
;
virtual
~
NnetInterface
()
{}
virtual
~
NnetInterface
()
{}
};
};
...
...
speechx/speechx/nnet/paddle_nnet.cc
浏览文件 @
bb07144c
...
@@ -21,6 +21,7 @@ using std::vector;
...
@@ -21,6 +21,7 @@ using std::vector;
using
std
::
string
;
using
std
::
string
;
using
std
::
shared_ptr
;
using
std
::
shared_ptr
;
using
kaldi
::
Matrix
;
using
kaldi
::
Matrix
;
using
kaldi
::
Vector
;
void
PaddleNnet
::
InitCacheEncouts
(
const
ModelOptions
&
opts
)
{
void
PaddleNnet
::
InitCacheEncouts
(
const
ModelOptions
&
opts
)
{
std
::
vector
<
std
::
string
>
cache_names
;
std
::
vector
<
std
::
string
>
cache_names
;
...
@@ -143,34 +144,27 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
...
@@ -143,34 +144,27 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
return
cache_encouts_
[
iter
->
second
];
return
cache_encouts_
[
iter
->
second
];
}
}
void
PaddleNnet
::
FeedForward
(
const
Matrix
<
BaseFloat
>&
features
,
void
PaddleNnet
::
FeedForward
(
const
Vector
<
BaseFloat
>&
features
,
Matrix
<
BaseFloat
>*
inferences
)
{
int32
feature_dim
,
Vector
<
BaseFloat
>*
inferences
,
int32
*
inference_dim
)
{
paddle_infer
::
Predictor
*
predictor
=
GetPredictor
();
paddle_infer
::
Predictor
*
predictor
=
GetPredictor
();
int
row
=
features
.
NumRows
();
int
feat_row
=
features
.
Dim
()
/
feature_dim
;
int
col
=
features
.
NumCols
();
std
::
vector
<
BaseFloat
>
feed_feature
;
// todo refactor feed feature: SmileGoat
feed_feature
.
reserve
(
row
*
col
);
for
(
size_t
row_idx
=
0
;
row_idx
<
features
.
NumRows
();
++
row_idx
)
{
for
(
size_t
col_idx
=
0
;
col_idx
<
features
.
NumCols
();
++
col_idx
)
{
feed_feature
.
push_back
(
features
(
row_idx
,
col_idx
));
}
}
std
::
vector
<
std
::
string
>
input_names
=
predictor
->
GetInputNames
();
std
::
vector
<
std
::
string
>
input_names
=
predictor
->
GetInputNames
();
std
::
vector
<
std
::
string
>
output_names
=
predictor
->
GetOutputNames
();
std
::
vector
<
std
::
string
>
output_names
=
predictor
->
GetOutputNames
();
LOG
(
INFO
)
<<
"feat info: row
="
<<
row
<<
", col= "
<<
col
;
LOG
(
INFO
)
<<
"feat info: row
s, cols: "
<<
feat_row
<<
", "
<<
feature_dim
;
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_tensor
=
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_tensor
=
predictor
->
GetInputHandle
(
input_names
[
0
]);
predictor
->
GetInputHandle
(
input_names
[
0
]);
std
::
vector
<
int
>
INPUT_SHAPE
=
{
1
,
row
,
col
};
std
::
vector
<
int
>
INPUT_SHAPE
=
{
1
,
feat_row
,
feature_dim
};
input_tensor
->
Reshape
(
INPUT_SHAPE
);
input_tensor
->
Reshape
(
INPUT_SHAPE
);
input_tensor
->
CopyFromCpu
(
fe
ed_feature
.
d
ata
());
input_tensor
->
CopyFromCpu
(
fe
atures
.
D
ata
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_len
=
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_len
=
predictor
->
GetInputHandle
(
input_names
[
1
]);
predictor
->
GetInputHandle
(
input_names
[
1
]);
std
::
vector
<
int
>
input_len_size
=
{
1
};
std
::
vector
<
int
>
input_len_size
=
{
1
};
input_len
->
Reshape
(
input_len_size
);
input_len
->
Reshape
(
input_len_size
);
std
::
vector
<
int64_t
>
audio_len
;
std
::
vector
<
int64_t
>
audio_len
;
audio_len
.
push_back
(
row
);
audio_len
.
push_back
(
feat_
row
);
input_len
->
CopyFromCpu
(
audio_len
.
data
());
input_len
->
CopyFromCpu
(
audio_len
.
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
h_box
=
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
h_box
=
...
@@ -203,20 +197,12 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
...
@@ -203,20 +197,12 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
output_tensor
=
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
output_tensor
=
predictor
->
GetOutputHandle
(
output_names
[
0
]);
predictor
->
GetOutputHandle
(
output_names
[
0
]);
std
::
vector
<
int
>
output_shape
=
output_tensor
->
shape
();
std
::
vector
<
int
>
output_shape
=
output_tensor
->
shape
();
row
=
output_shape
[
1
];
int32
row
=
output_shape
[
1
];
col
=
output_shape
[
2
];
int32
col
=
output_shape
[
2
];
vector
<
float
>
inferences_result
;
inferences
->
Resize
(
row
*
col
);
inferences
->
Resize
(
row
,
col
);
*
inference_dim
=
col
;
inferences_result
.
resize
(
row
*
col
);
output_tensor
->
CopyToCpu
(
inferences
->
Data
());
output_tensor
->
CopyToCpu
(
inferences_result
.
data
());
ReleasePredictor
(
predictor
);
ReleasePredictor
(
predictor
);
for
(
int
row_idx
=
0
;
row_idx
<
row
;
++
row_idx
)
{
for
(
int
col_idx
=
0
;
col_idx
<
col
;
++
col_idx
)
{
(
*
inferences
)(
row_idx
,
col_idx
)
=
inferences_result
[
col
*
row_idx
+
col_idx
];
}
}
}
}
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/nnet/paddle_nnet.h
浏览文件 @
bb07144c
...
@@ -39,12 +39,8 @@ struct ModelOptions {
...
@@ -39,12 +39,8 @@ struct ModelOptions {
bool
enable_fc_padding
;
bool
enable_fc_padding
;
bool
enable_profile
;
bool
enable_profile
;
ModelOptions
()
ModelOptions
()
:
model_path
(
:
model_path
(
"avg_1.jit.pdmodel"
),
"../../../../model/paddle_online_deepspeech/model/"
params_path
(
"avg_1.jit.pdiparams"
),
"avg_1.jit.pdmodel"
),
params_path
(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdiparams"
),
thread_num
(
2
),
thread_num
(
2
),
use_gpu
(
false
),
use_gpu
(
false
),
input_names
(
input_names
(
...
@@ -107,8 +103,11 @@ class Tensor {
...
@@ -107,8 +103,11 @@ class Tensor {
class
PaddleNnet
:
public
NnetInterface
{
class
PaddleNnet
:
public
NnetInterface
{
public:
public:
PaddleNnet
(
const
ModelOptions
&
opts
);
PaddleNnet
(
const
ModelOptions
&
opts
);
virtual
void
FeedForward
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
features
,
virtual
void
FeedForward
(
const
kaldi
::
Vector
<
kaldi
::
BaseFloat
>&
features
,
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>*
inferences
);
int32
feature_dim
,
kaldi
::
Vector
<
kaldi
::
BaseFloat
>*
inferences
,
int32
*
inference_dim
);
void
Dim
();
virtual
void
Reset
();
virtual
void
Reset
();
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>
GetCacheEncoder
(
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>
GetCacheEncoder
(
const
std
::
string
&
name
);
const
std
::
string
&
name
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录