PaddlePaddle / DeepSpeech · commit 5383dff2 (unverified)

Revert "align nnet decoder & refactor"

Author: YangZhou, Mar 14, 2022. Committer: GitHub, Mar 14, 2022.
Parent commit: bb07144c
Changed files: 11
Showing 11 changed files with 170 additions and 190 deletions (+170 / -190).
speechx/examples/decoder/CMakeLists.txt              +3   -3
speechx/examples/decoder/offline-decoder-main.cc     +15  -40
speechx/speechx/decoder/ctc_beam_search_decoder.cc   +10  -8
speechx/speechx/decoder/ctc_beam_search_decoder.h    +6   -4
speechx/speechx/frontend/raw_audio.h                 +7   -9
speechx/speechx/nnet/decodable-itf.h                 +72  -45
speechx/speechx/nnet/decodable.cc                    +11  -38
speechx/speechx/nnet/decodable.h                     +7   -17
speechx/speechx/nnet/nnet_interface.h                +2   -4
speechx/speechx/nnet/paddle_nnet.cc                  +29  -15
speechx/speechx/nnet/paddle_nnet.h                   +8   -7
speechx/examples/decoder/CMakeLists.txt

 cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

-add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
-target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
-target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
+add_executable(offline-decoder-main ${CMAKE_CURRENT_SOURCE_DIR}/offline-decoder-main.cc)
+target_include_directories(offline-decoder-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
\ No newline at end of file
speechx/examples/decoder/offline_decoder_main.cc → speechx/examples/decoder/offline-decoder-main.cc
@@ -17,75 +17,50 @@
 #include "base/flags.h"
 #include "base/log.h"
 #include "decoder/ctc_beam_search_decoder.h"
 #include "frontend/raw_audio.h"
 #include "kaldi/util/table-types.h"
 #include "nnet/decodable.h"
 #include "nnet/paddle_nnet.h"

-DEFINE_string(feature_respecifier, "", "test feature rspecifier");
-DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
-DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
-DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
-DEFINE_string(lm_path, "lm.klm", "language model");
+DEFINE_string(feature_respecifier, "", "test nnet prob");

 using kaldi::BaseFloat;
 using kaldi::Matrix;
 using std::vector;

+// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
+//                   int32 chunk_size,
+//                   std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
+//}
+
 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
     google::InitGoogleLogging(argv[0]);

     kaldi::SequentialBaseFloatMatrixReader feature_reader(
         FLAGS_feature_respecifier);

-    std::string model_graph = FLAGS_model_path;
-    std::string model_params = FLAGS_param_path;
-    std::string dict_file = FLAGS_dict_file;
-    std::string lm_path = FLAGS_lm_path;

     // test nnet_output --> decoder result
     int32 num_done = 0, num_err = 0;

     ppspeech::CTCBeamSearchOptions opts;
-    opts.dict_file = dict_file;
-    opts.lm_path = lm_path;
     ppspeech::CTCBeamSearch decoder(opts);

     ppspeech::ModelOptions model_opts;
-    model_opts.model_path = model_graph;
-    model_opts.params_path = model_params;
     std::shared_ptr<ppspeech::PaddleNnet> nnet(
         new ppspeech::PaddleNnet(model_opts));
-    std::shared_ptr<ppspeech::RawDataCache> raw_data(
-        new ppspeech::RawDataCache());
     std::shared_ptr<ppspeech::Decodable> decodable(
-        new ppspeech::Decodable(nnet, raw_data));
+        new ppspeech::Decodable(nnet));

-    int32 chunk_size = 35;
+    // int32 chunk_size = 35;
     decoder.InitDecoder();

     for (; !feature_reader.Done(); feature_reader.Next()) {
         string utt = feature_reader.Key();
         const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
-        raw_data->SetDim(feature.NumCols());
-        int32 row_idx = 0;
-        int32 num_chunks = feature.NumRows() / chunk_size;
-        for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
-            kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
-                                                          feature.NumCols());
-            for (int row_id = 0; row_id < chunk_size; ++row_id) {
-                kaldi::SubVector<kaldi::BaseFloat> tmp(feature, row_idx);
-                kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
-                    feature_chunk.Data() + row_id * feature.NumCols(),
-                    feature.NumCols());
-                f_chunk_tmp.CopyFromVec(tmp);
-                row_idx++;
-            }
-            raw_data->Accept(feature_chunk);
-            if (chunk_idx == num_chunks - 1) {
-                raw_data->SetFinished();
-            }
-            decoder.AdvanceDecode(decodable);
-        }
+        decodable->FeedFeatures(feature);
+        decoder.AdvanceDecode(decodable, 8);
+        decodable->InputFinished();
         std::string result;
         result = decoder.GetFinalBestPath();
         KALDI_LOG << " the result of " << utt << " is " << result;

@@ -96,4 +71,4 @@ int main(int argc, char* argv[]) {
     KALDI_LOG << "Done " << num_done << " utterances, " << num_err
               << " with errors.";
     return (num_done != 0 ? 0 : 1);
-}
+}
\ No newline at end of file
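Net effect on the example binary: the 35-frame chunking loop through RawDataCache is gone, and the restored flow feeds the whole utterance to Decodable, advances the decoder a bounded number of frames, and flushes. A minimal sketch of that per-utterance call sequence, assuming the headers above and an already-initialised decoder (an illustration of the restored calls, not code from the commit):

#include <memory>
#include <string>
#include "decoder/ctc_beam_search_decoder.h"
#include "nnet/decodable.h"

// Sketch: one utterance through the restored offline pipeline.
std::string DecodeUtterance(ppspeech::CTCBeamSearch& decoder,
                            const std::shared_ptr<ppspeech::Decodable>& decodable,
                            const kaldi::Matrix<kaldi::BaseFloat>& feature) {
    decodable->FeedFeatures(feature);     // run the nnet over all frames at once
    decoder.AdvanceDecode(decodable, 8);  // consume at most 8 frames of posteriors
    decodable->InputFinished();           // mark end of input for IsLastFrame()
    return decoder.GetFinalBestPath();    // best path over what was decoded
}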
speechx/speechx/decoder/ctc_beam_search_decoder.cc

@@ -79,19 +79,21 @@ void CTCBeamSearch::Decode(
     return;
 }

-int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; }
+int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }

 // todo rename, refactor
 void CTCBeamSearch::AdvanceDecode(
-    const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
-    while (1) {
+    const std::shared_ptr<kaldi::DecodableInterface>& decodable,
+    int max_frames) {
+    while (max_frames > 0) {
         vector<vector<BaseFloat>> likelihood;
-        vector<BaseFloat> frame_prob;
-        bool flag =
-            decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob);
-        if (flag == false) break;
-        likelihood.push_back(frame_prob);
+        if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
+            break;
+        }
+        likelihood.push_back(
+            decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
         AdvanceDecoding(likelihood);
+        max_frames--;
     }
 }
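The new signature moves loop ownership to the caller: each AdvanceDecode() call consumes at most max_frames frames, so a streaming driver can interleave feeding features with decoding. A hypothetical driver illustrating that contract (the loop condition, and the assumption that NumFrameDecoded() is callable from outside the class, are not part of this commit):

#include <memory>
#include "decoder/ctc_beam_search_decoder.h"

// Hypothetical streaming driver around the bounded AdvanceDecode() contract.
void DrainDecoder(ppspeech::CTCBeamSearch& decoder,
                  const std::shared_ptr<kaldi::DecodableInterface>& decodable,
                  int frames_per_call) {
    // Keep taking small bites until the decodable reports the last frame.
    // Assumes NumFrameDecoded() is accessible to the caller.
    while (!decodable->IsLastFrame(decoder.NumFrameDecoded())) {
        decoder.AdvanceDecode(decodable, frames_per_call);
    }
}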
speechx/speechx/decoder/ctc_beam_search_decoder.h

@@ -32,8 +32,8 @@ struct CTCBeamSearchOptions {
     int cutoff_top_n;
     int num_proc_bsearch;
     CTCBeamSearchOptions()
-        : dict_file("vocab.txt"),
-          lm_path("lm.klm"),
+        : dict_file("./model/words.txt"),
+          lm_path("./model/lm.arpa"),
           alpha(1.9f),
           beta(5.0),
           beam_size(300),

@@ -68,7 +68,8 @@ class CTCBeamSearch {
     int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
                           std::vector<std::string>& nbest_words);
     void AdvanceDecode(
-        const std::shared_ptr<kaldi::DecodableInterface>& decodable);
+        const std::shared_ptr<kaldi::DecodableInterface>& decodable,
+        int max_frames);
     void Reset();

   private:

@@ -82,7 +83,8 @@ class CTCBeamSearch {
     CTCBeamSearchOptions opts_;
     std::shared_ptr<Scorer> init_ext_scorer_;  // todo separate later
-    std::vector<std::string> vocabulary_;      // todo remove later
+    // std::vector<DecodeResult> decoder_results_;
+    std::vector<std::string> vocabulary_;  // todo remove later
     size_t blank_id;
     int space_id;
     std::shared_ptr<PathTrie> root;
speechx/speechx/frontend/raw_audio.h

@@ -18,8 +18,6 @@
 #include "base/common.h"
 #include "frontend/feature_extractor_interface.h"
 #pragma once

 namespace ppspeech {

 class RawAudioCache : public FeatureExtractorInterface {

@@ -47,12 +45,13 @@ class RawAudioCache : public FeatureExtractorInterface {
     DISALLOW_COPY_AND_ASSIGN(RawAudioCache);
 };

-// it is a data source for testing different frontend module.
-// it accepts waves or feats.
+// it is a data source to test different frontend module.
+// it Accepts waves or feats.
 class RawDataCache : public FeatureExtractorInterface {
   public:
     explicit RawDataCache() { finished_ = false; }
     virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
         data_ = inputs;
     }
     virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {

@@ -63,15 +62,14 @@ class RawDataCache : public FeatureExtractorInterface {
         data_.Resize(0);
         return true;
     }
-    virtual size_t Dim() const { return dim_; }
+    // the dim is data_ length
+    virtual size_t Dim() const { return data_.Dim(); }
     virtual void SetFinished() { finished_ = true; }
     virtual bool IsFinished() const { return finished_; }
-    void SetDim(int32 dim) { dim_ = dim; }

   private:
     kaldi::Vector<kaldi::BaseFloat> data_;
     bool finished_;
-    int32 dim_;

     DISALLOW_COPY_AND_ASSIGN(RawDataCache);
 };
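As restored, RawDataCache is a pass-through test source: Accept() stores one block, Read() hands it back and clears the cache, and Dim() is derived from the stored data rather than a separately tracked dim_. A small usage sketch under those assumptions (the 80-dim block is only an illustrative value, and the include paths assume the speechx layout):

#include "frontend/raw_audio.h"
#include "kaldi/matrix/kaldi-vector.h"

// Sketch: exercising the restored RawDataCache in isolation.
void RawDataCacheExample() {
    ppspeech::RawDataCache cache;

    kaldi::Vector<kaldi::BaseFloat> block(80);  // e.g. one 80-dim feature frame
    block.Set(1.0);

    cache.Accept(block);       // store the block in data_
    size_t dim = cache.Dim();  // data_.Dim() == 80; no SetDim() needed any more
    (void)dim;

    kaldi::Vector<kaldi::BaseFloat> out;
    cache.Read(&out);          // copies the stored block out and clears data_
    cache.SetFinished();       // caller signals end of stream
}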
speechx/speechx/nnet/decodable-itf.h

+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 // itf/decodable-itf.h

 // Copyright 2009-2011  Microsoft Corporation;  Saarland University;

@@ -42,8 +56,10 @@ namespace kaldi {
    For online decoding, where the features are coming in in real time, it is
    important to understand the IsLastFrame() and NumFramesReady() functions.
-   There are two ways these are used: the old online-decoding code, in ../online/,
-   and the new online-decoding code, in ../online2/.  In the old online-decoding
+   There are two ways these are used: the old online-decoding code, in
+   ../online/,
+   and the new online-decoding code, in ../online2/.  In the old
+   online-decoding
    code, the decoder would do:
    \code{.cc}
    for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {

@@ -52,13 +68,16 @@ namespace kaldi {
    \endcode
    and the call to IsLastFrame would block if the features had not arrived yet.
    The decodable object would have to know when to terminate the decoding.  This
-   online-decoding mode is still supported, it is what happens when you call, for
+   online-decoding mode is still supported, it is what happens when you call,
+   for
    example, LatticeFasterDecoder::Decode().

    We realized that this "blocking" mode of decoding is not very convenient
    because it forces the program to be multi-threaded and makes it complex to
-   control endpointing.  In the "new" decoding code, you don't call (for example)
-   LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
+   control endpointing.  In the "new" decoding code, you don't call (for
+   example)
+   LatticeFasterDecoder::Decode(), you call
+   LatticeFasterDecoder::InitDecoding(),
    and then each time you get more features, you provide them to the decodable
    object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
    something like this:

@@ -68,7 +87,8 @@ namespace kaldi {
    }
    \endcode
    So the decodable object never has IsLastFrame() called.  For decoding where
-   you are starting with a matrix of features, the NumFramesReady() function will
+   you are starting with a matrix of features, the NumFramesReady() function
+   will
    always just return the number of frames in the file, and IsLastFrame() will
    return true for the last frame.

@@ -80,45 +100,52 @@ namespace kaldi {
    frame of the file once we've decided to terminate decoding.
 */
 class DecodableInterface {
-  public:
-    /// Returns the log likelihood, which will be negated in the decoder.
-    /// The "frame" starts from zero.  You should verify that NumFramesReady() > frame
-    /// before calling this.
-    virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
-
-    /// Returns true if this is the last frame.  Frames are zero-based, so the
-    /// first frame is zero.  IsLastFrame(-1) will return false, unless the file
-    /// is empty (which is a case that I'm not sure all the code will handle, so
-    /// be careful).  Caution: the behavior of this function in an online setting
-    /// is being changed somewhat.  In future it may return false in cases where
-    /// we haven't yet decided to terminate decoding, but later true if we decide
-    /// to terminate decoding.  The plan in future is to rely more on
-    /// NumFramesReady(), and in future, IsLastFrame() would always return false
-    /// in an online-decoding setting, and would only return true in a
-    /// decoding-from-matrix setting where we want to allow the last delta or LDA
-    /// features to be flushed out for compatibility with the baseline setup.
-    virtual bool IsLastFrame(int32 frame) const = 0;
-
-    /// The call NumFramesReady() will return the number of frames currently available
-    /// for this decodable object.  This is for use in setups where you don't want the
-    /// decoder to block while waiting for input.  This is newly added as of Jan 2014,
-    /// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
-    /// know when to stop decoding.
-    virtual int32 NumFramesReady() const {
-        KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
-        return -1;
-    }
-
-    /// Returns the number of states in the acoustic model
-    /// (they will be indexed one-based, i.e. from 1 to NumIndices();
-    /// this is for compatibility with OpenFst).
-    virtual int32 NumIndices() const = 0;
-
-    virtual bool FrameLogLikelihood(
-        int32 frame, std::vector<kaldi::BaseFloat>* likelihood) = 0;
-
-    virtual ~DecodableInterface() {}
+  public:
+    /// Returns the log likelihood, which will be negated in the decoder.
+    /// The "frame" starts from zero.  You should verify that NumFramesReady() >
+    /// frame
+    /// before calling this.
+    virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
+
+    /// Returns true if this is the last frame.  Frames are zero-based, so the
+    /// first frame is zero.  IsLastFrame(-1) will return false, unless the file
+    /// is empty (which is a case that I'm not sure all the code will handle, so
+    /// be careful).  Caution: the behavior of this function in an online
+    /// setting
+    /// is being changed somewhat.  In future it may return false in cases where
+    /// we haven't yet decided to terminate decoding, but later true if we
+    /// decide
+    /// to terminate decoding.  The plan in future is to rely more on
+    /// NumFramesReady(), and in future, IsLastFrame() would always return false
+    /// in an online-decoding setting, and would only return true in a
+    /// decoding-from-matrix setting where we want to allow the last delta or
+    /// LDA
+    /// features to be flushed out for compatibility with the baseline setup.
+    virtual bool IsLastFrame(int32 frame) const = 0;
+
+    /// The call NumFramesReady() will return the number of frames currently
+    /// available
+    /// for this decodable object.  This is for use in setups where you don't
+    /// want the
+    /// decoder to block while waiting for input.  This is newly added as of Jan
+    /// 2014,
+    /// and I hope, going forward, to rely on this mechanism more than
+    /// IsLastFrame to
+    /// know when to stop decoding.
+    virtual int32 NumFramesReady() const {
+        KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
+        return -1;
+    }
+
+    /// Returns the number of states in the acoustic model
+    /// (they will be indexed one-based, i.e. from 1 to NumIndices();
+    /// this is for compatibility with OpenFst).
+    virtual int32 NumIndices() const = 0;
+
+    virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame) = 0;
+
+    virtual ~DecodableInterface() {}
 };

 /// @}
 }  // namespace kaldi
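The substantive interface change is FrameLogLikelihood(): it now returns a frame's per-symbol log-likelihoods by value instead of reporting availability through a bool and an out-parameter. A toy sketch of what an implementation of the new signature can look like, backed by a precomputed posterior matrix (the class name and backing store are hypothetical, and the remaining pure-virtual methods are omitted):

#include <vector>
#include "nnet/decodable-itf.h"

// Hypothetical decodable over a precomputed posterior matrix.
class MatrixDecodable : public kaldi::DecodableInterface {
  public:
    explicit MatrixDecodable(const kaldi::Matrix<kaldi::BaseFloat>& post)
        : post_(post) {}

    virtual std::vector<kaldi::BaseFloat> FrameLogLikelihood(kaldi::int32 frame) {
        std::vector<kaldi::BaseFloat> row(post_.NumCols());
        for (kaldi::int32 i = 0; i < post_.NumCols(); ++i) {
            row[i] = post_(frame, i);  // copy one cached row out by value
        }
        return row;
    }

    // LogLikelihood(), IsLastFrame(), NumIndices() omitted in this sketch.

  private:
    kaldi::Matrix<kaldi::BaseFloat> post_;
};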
speechx/speechx/nnet/decodable.cc

@@ -18,16 +18,9 @@ namespace ppspeech {
 using kaldi::BaseFloat;
 using kaldi::Matrix;
 using std::vector;
 using kaldi::Vector;

-Decodable::Decodable(
-    const std::shared_ptr<NnetInterface>& nnet,
-    const std::shared_ptr<FeatureExtractorInterface>& frontend)
-    : frontend_(frontend),
-      nnet_(nnet),
-      finished_(false),
-      frame_offset_(0),
-      frames_ready_(0) {}
+Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
+    : frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {}

 void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
     frames_ready_ += likelihood.NumRows();

@@ -38,46 +31,26 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
 bool Decodable::IsLastFrame(int32 frame) const {
     CHECK_LE(frame, frames_ready_);
-    return IsInputFinished() && (frame == frames_ready_ - 1);
+    return finished_ && (frame == frames_ready_ - 1);
 }

 int32 Decodable::NumIndices() const { return 0; }

-BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
-    CHECK_LE(index, nnet_cache_.NumCols());
-    return 0;
-}
+BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; }

-bool Decodable::EnsureFrameHaveComputed(int32 frame) {
-    if (frame >= frames_ready_) {
-        return AdvanceChunk();
-    }
-    return true;
-}
-
-bool Decodable::AdvanceChunk() {
-    Vector<BaseFloat> features;
-    if (frontend_->Read(&features) == false) {
-        return false;
-    }
-    int32 nnet_dim = 0;
-    Vector<BaseFloat> inferences;
-    nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
-    nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
-    nnet_cache_.CopyRowsFromVec(inferences);
-    frame_offset_ = frames_ready_;
+void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
+    nnet_->FeedForward(features, &nnet_cache_);
     frames_ready_ += nnet_cache_.NumRows();
-    return true;
+    return;
 }

-bool Decodable::FrameLogLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
-    if (EnsureFrameHaveComputed(frame) == false) return false;
-    likelihood->resize(nnet_cache_.NumCols());
+std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {
+    std::vector<BaseFloat> result;
+    result.reserve(nnet_cache_.NumCols());
     for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
-        (*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx);
+        result[idx] = nnet_cache_(frame, idx);
     }
-    return true;
+    return result;
 }

 void Decodable::Reset() {
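After the revert, Decodable is a thin cache: FeedFeatures() pushes the whole feature matrix through the nnet once, and FrameLogLikelihood(frame) copies one row of nnet_cache_ out by value. A standalone sketch of that per-frame copy, using plain Kaldi matrix calls only (it sizes the output with resize() before the indexed writes):

#include <vector>
#include "kaldi/matrix/kaldi-matrix.h"

// Copy row `frame` of a posterior cache into a std::vector, as
// Decodable::FrameLogLikelihood does after this commit.
std::vector<kaldi::BaseFloat> RowToVector(
    const kaldi::Matrix<kaldi::BaseFloat>& nnet_cache, int frame) {
    std::vector<kaldi::BaseFloat> result;
    result.resize(nnet_cache.NumCols());  // resize (not reserve) before indexed writes
    for (int idx = 0; idx < nnet_cache.NumCols(); ++idx) {
        result[idx] = nnet_cache(frame, idx);
    }
    return result;
}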
speechx/speechx/nnet/decodable.h

@@ -24,35 +24,25 @@ struct DecodableOpts;
 class Decodable : public kaldi::DecodableInterface {
   public:
-    explicit Decodable(
-        const std::shared_ptr<NnetInterface>& nnet,
-        const std::shared_ptr<FeatureExtractorInterface>& frontend);
+    explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
     // void Init(DecodableOpts config);
     virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
     virtual bool IsLastFrame(int32 frame) const;
     virtual int32 NumIndices() const;
-    virtual bool FrameLogLikelihood(int32 frame,
-                                    std::vector<kaldi::BaseFloat>* likelihood);
-    // for offline test
-    void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
+    virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
+    void Acceptlikelihood(
+        const kaldi::Matrix<kaldi::BaseFloat>& likelihood);  // remove later
+    void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>&
+                          feature);  // only for test, todo remove later
     void Reset();
-    bool IsInputFinished() const { return frontend_->IsFinished(); }
-    bool EnsureFrameHaveComputed(int32 frame);
+    void InputFinished() { finished_ = true; }

   private:
-    bool AdvanceChunk();
     std::shared_ptr<FeatureExtractorInterface> frontend_;
     std::shared_ptr<NnetInterface> nnet_;
     kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
-    // std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;
     bool finished_;
-    int32 frame_offset_;
     int32 frames_ready_;
-    // todo: feature frame mismatch with nnet inference frame
-    // eg: 35 frame features output 8 frame inferences
-    // so use subsampled_frame
-    int32 current_log_post_subsampled_offset_;
-    int32 num_chunk_computed_;
 };

 }  // namespace ppspeech
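With the frontend removed from this header, end of input is signalled explicitly: the driver calls InputFinished(), and IsLastFrame() combines the finished_ flag with the frames_ready_ counter. A tiny standalone sketch of that predicate, mirroring the members above:

// Mirrors Decodable's end-of-utterance bookkeeping after this commit.
struct EndOfInputState {
    bool finished = false;  // set by InputFinished()
    int frames_ready = 0;   // grown by FeedFeatures() / Acceptlikelihood()

    bool IsLastFrame(int frame) const {
        return finished && (frame == frames_ready - 1);
    }
};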
speechx/speechx/nnet/nnet_interface.h

@@ -23,10 +23,8 @@ namespace ppspeech {
 class NnetInterface {
   public:
-    virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
-                             int32 feature_dim,
-                             kaldi::Vector<kaldi::BaseFloat>* inferences,
-                             int32* inference_dim) = 0;
+    virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
+                             kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0;
     virtual void Reset() = 0;
     virtual ~NnetInterface() {}
 };
speechx/speechx/nnet/paddle_nnet.cc

@@ -21,7 +21,6 @@ using std::vector;
 using std::string;
 using std::shared_ptr;
 using kaldi::Matrix;
-using kaldi::Vector;

 void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
     std::vector<std::string> cache_names;

@@ -144,27 +143,34 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
     return cache_encouts_[iter->second];
 }

-void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
-                             int32 feature_dim,
-                             Vector<BaseFloat>* inferences,
-                             int32* inference_dim) {
+void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
+                             Matrix<BaseFloat>* inferences) {
     paddle_infer::Predictor* predictor = GetPredictor();

-    int feat_row = features.Dim() / feature_dim;
+    int row = features.NumRows();
+    int col = features.NumCols();
+    std::vector<BaseFloat> feed_feature;
+    // todo refactor feed feature: SmileGoat
+    feed_feature.reserve(row * col);
+    for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
+        for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
+            feed_feature.push_back(features(row_idx, col_idx));
+        }
+    }

     std::vector<std::string> input_names = predictor->GetInputNames();
     std::vector<std::string> output_names = predictor->GetOutputNames();

-    LOG(INFO) << "feat info: rows, cols: " << feat_row << ", " << feature_dim;
+    LOG(INFO) << "feat info: row=" << row << ", col= " << col;

     std::unique_ptr<paddle_infer::Tensor> input_tensor =
         predictor->GetInputHandle(input_names[0]);
-    std::vector<int> INPUT_SHAPE = {1, feat_row, feature_dim};
+    std::vector<int> INPUT_SHAPE = {1, row, col};
     input_tensor->Reshape(INPUT_SHAPE);
-    input_tensor->CopyFromCpu(features.Data());
+    input_tensor->CopyFromCpu(feed_feature.data());

     std::unique_ptr<paddle_infer::Tensor> input_len =
         predictor->GetInputHandle(input_names[1]);
     std::vector<int> input_len_size = {1};
     input_len->Reshape(input_len_size);
     std::vector<int64_t> audio_len;
-    audio_len.push_back(feat_row);
+    audio_len.push_back(row);
     input_len->CopyFromCpu(audio_len.data());

     std::unique_ptr<paddle_infer::Tensor> h_box =

@@ -197,12 +203,20 @@ void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
     std::unique_ptr<paddle_infer::Tensor> output_tensor =
         predictor->GetOutputHandle(output_names[0]);
     std::vector<int> output_shape = output_tensor->shape();
-    int32 row = output_shape[1];
-    int32 col = output_shape[2];
-    inferences->Resize(row * col);
-    *inference_dim = col;
-    output_tensor->CopyToCpu(inferences->Data());
+    row = output_shape[1];
+    col = output_shape[2];
+    vector<float> inferences_result;
+    inferences->Resize(row, col);
+    inferences_result.resize(row * col);
+    output_tensor->CopyToCpu(inferences_result.data());
     ReleasePredictor(predictor);
+
+    for (int row_idx = 0; row_idx < row; ++row_idx) {
+        for (int col_idx = 0; col_idx < col; ++col_idx) {
+            (*inferences)(row_idx, col_idx) =
+                inferences_result[col * row_idx + col_idx];
+        }
+    }
 }

 }  // namespace ppspeech
\ No newline at end of file
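Both directions of the new FeedForward() copy go through flat row-major buffers: the input kaldi::Matrix is flattened element by element before CopyFromCpu(), and the output tensor is copied back into the result matrix with the index col * row_idx + col_idx. A standalone sketch of that layout using std containers only (no Paddle calls), to make the index math explicit:

#include <cstddef>
#include <vector>

// Row-major flatten, as done for feed_feature before CopyFromCpu() above.
std::vector<float> Flatten(const std::vector<std::vector<float>>& mat) {
    std::vector<float> flat;
    if (mat.empty()) return flat;
    flat.reserve(mat.size() * mat[0].size());
    for (const auto& row : mat)
        for (float v : row) flat.push_back(v);
    return flat;
}

// Row-major unflatten, as done for inferences_result after CopyToCpu() above.
std::vector<std::vector<float>> Unflatten(const std::vector<float>& flat,
                                          std::size_t rows, std::size_t cols) {
    std::vector<std::vector<float>> mat(rows, std::vector<float>(cols));
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
            mat[r][c] = flat[cols * r + c];  // same index math as the diff
    return mat;
}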
speechx/speechx/nnet/paddle_nnet.h

@@ -39,8 +39,12 @@ struct ModelOptions {
     bool enable_fc_padding;
     bool enable_profile;
     ModelOptions()
-        : model_path("avg_1.jit.pdmodel"),
-          params_path("avg_1.jit.pdiparams"),
+        : model_path(
+              "../../../../model/paddle_online_deepspeech/model/"
+              "avg_1.jit.pdmodel"),
+          params_path(
+              "../../../../model/paddle_online_deepspeech/model/"
+              "avg_1.jit.pdiparams"),
           thread_num(2),
           use_gpu(false),
           input_names(

@@ -103,11 +107,8 @@ class Tensor {
 class PaddleNnet : public NnetInterface {
   public:
     PaddleNnet(const ModelOptions& opts);
-    virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
-                             int32 feature_dim,
-                             kaldi::Vector<kaldi::BaseFloat>* inferences,
-                             int32* inference_dim);
-    void Dim();
+    virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
+                             kaldi::Matrix<kaldi::BaseFloat>* inferences);
     virtual void Reset();
     std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
         const std::string& name);
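Since ModelOptions now defaults to a relative model directory, a caller whose model lives elsewhere has to override the paths before constructing PaddleNnet. A hedged sketch of that (the path strings are placeholders, not paths from the commit):

#include <memory>
#include "nnet/paddle_nnet.h"

// Sketch: construct a PaddleNnet with explicit model/parameter paths.
std::shared_ptr<ppspeech::PaddleNnet> MakeNnet() {
    ppspeech::ModelOptions opts;
    // Placeholder paths: override the hard-coded defaults from this header.
    opts.model_path = "/path/to/avg_1.jit.pdmodel";
    opts.params_path = "/path/to/avg_1.jit.pdiparams";
    return std::make_shared<ppspeech::PaddleNnet>(opts);
}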