Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
86eb7189
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
86eb7189
编写于
10月 14, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add u2 recg
上级
7dc9cba3
变更
24
显示空白变更内容
内联
并排
Showing
24 changed file
with
693 addition
and
215 deletion
+693
-215
speechx/examples/codelab/u2/local/decode.sh
speechx/examples/codelab/u2/local/decode.sh
+1
-1
speechx/speechx/decoder/CMakeLists.txt
speechx/speechx/decoder/CMakeLists.txt
+14
-7
speechx/speechx/decoder/common.h
speechx/speechx/decoder/common.h
+29
-2
speechx/speechx/decoder/ctc_beam_search_opt.h
speechx/speechx/decoder/ctc_beam_search_opt.h
+0
-64
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
+12
-2
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
+15
-30
speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
...hx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
+9
-19
speechx/speechx/decoder/ctc_prefix_beam_search_result.h
speechx/speechx/decoder/ctc_prefix_beam_search_result.h
+0
-41
speechx/speechx/decoder/decoder_itf.h
speechx/speechx/decoder/decoder_itf.h
+4
-0
speechx/speechx/decoder/param.h
speechx/speechx/decoder/param.h
+5
-30
speechx/speechx/decoder/recognizer.cc
speechx/speechx/decoder/recognizer.cc
+6
-0
speechx/speechx/decoder/recognizer.h
speechx/speechx/decoder/recognizer.h
+4
-9
speechx/speechx/decoder/recognizer_main.cc
speechx/speechx/decoder/recognizer_main.cc
+28
-1
speechx/speechx/decoder/u2_recognizer.cc
speechx/speechx/decoder/u2_recognizer.cc
+209
-0
speechx/speechx/decoder/u2_recognizer.h
speechx/speechx/decoder/u2_recognizer.h
+164
-0
speechx/speechx/decoder/u2_recognizer_main.cc
speechx/speechx/decoder/u2_recognizer_main.cc
+137
-0
speechx/speechx/frontend/audio/feature_pipeline.cc
speechx/speechx/frontend/audio/feature_pipeline.cc
+1
-1
speechx/speechx/frontend/audio/feature_pipeline.h
speechx/speechx/frontend/audio/feature_pipeline.h
+15
-2
speechx/speechx/nnet/ds2_nnet.cc
speechx/speechx/nnet/ds2_nnet.cc
+1
-0
speechx/speechx/nnet/ds2_nnet.h
speechx/speechx/nnet/ds2_nnet.h
+2
-0
speechx/speechx/nnet/nnet_itf.h
speechx/speechx/nnet/nnet_itf.h
+8
-1
speechx/speechx/nnet/u2_nnet.h
speechx/speechx/nnet/u2_nnet.h
+1
-2
speechx/speechx/protocol/websocket/CMakeLists.txt
speechx/speechx/protocol/websocket/CMakeLists.txt
+0
-2
speechx/speechx/protocol/websocket/websocket_server_main.cc
speechx/speechx/protocol/websocket/websocket_server_main.cc
+28
-1
未找到文件。
speechx/examples/codelab/u2/local/decode.sh
浏览文件 @
86eb7189
#!/bin/bash
#!/bin/bash
set
-
x
set
+
x
set
-e
set
-e
.
path.sh
.
path.sh
...
...
speechx/speechx/decoder/CMakeLists.txt
浏览文件 @
86eb7189
...
@@ -9,6 +9,7 @@ add_library(decoder STATIC
...
@@ -9,6 +9,7 @@ add_library(decoder STATIC
ctc_prefix_beam_search_decoder.cc
ctc_prefix_beam_search_decoder.cc
ctc_tlg_decoder.cc
ctc_tlg_decoder.cc
recognizer.cc
recognizer.cc
u2_recognizer.cc
)
)
target_link_libraries
(
decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings
)
target_link_libraries
(
decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings
)
...
@@ -28,10 +29,16 @@ endforeach()
...
@@ -28,10 +29,16 @@ endforeach()
# u2
# u2
set
(
bin_name ctc_prefix_beam_search_decoder_main
)
set
(
TEST_BINS
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
u2_recognizer_main
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
ctc_prefix_beam_search_decoder_main
target_link_libraries
(
${
bin_name
}
nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
)
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
foreach
(
bin_name IN LISTS TEST_BINS
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
)
add_executable
(
${
bin_name
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
bin_name
}
.cc
)
\ No newline at end of file
target_include_directories
(
${
bin_name
}
PRIVATE
${
SPEECHX_ROOT
}
${
SPEECHX_ROOT
}
/kaldi
)
target_link_libraries
(
${
bin_name
}
nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
)
target_compile_options
(
${
bin_name
}
PRIVATE
${
PADDLE_COMPILE_FLAGS
}
)
target_include_directories
(
${
bin_name
}
PRIVATE
${
pybind11_INCLUDE_DIRS
}
${
PROJECT_SOURCE_DIR
}
)
target_link_libraries
(
${
bin_name
}
${
PYTHON_LIBRARIES
}
${
PADDLE_LINK_FLAGS
}
)
endforeach
()
\ No newline at end of file
speechx/speechx/decoder/common.h
浏览文件 @
86eb7189
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -12,10 +13,36 @@
...
@@ -12,10 +13,36 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "base/basic_types.h"
#pragma once
#include "base/common.h"
struct
DecoderResult
{
struct
DecoderResult
{
BaseFloat
acoustic_score
;
BaseFloat
acoustic_score
;
std
::
vector
<
int32
>
words_idx
;
std
::
vector
<
int32
>
words_idx
;
std
::
vector
<
pair
<
int32
,
int32
>>
time_stamp
;
std
::
vector
<
std
::
pair
<
int32
,
int32
>>
time_stamp
;
};
namespace
ppspeech
{
struct
WordPiece
{
std
::
string
word
;
int
start
=
-
1
;
int
end
=
-
1
;
WordPiece
(
std
::
string
word
,
int
start
,
int
end
)
:
word
(
std
::
move
(
word
)),
start
(
start
),
end
(
end
)
{}
};
};
struct
DecodeResult
{
float
score
=
-
kBaseFloatMax
;
std
::
string
sentence
;
std
::
vector
<
WordPiece
>
word_pieces
;
static
bool
CompareFunc
(
const
DecodeResult
&
a
,
const
DecodeResult
&
b
)
{
return
a
.
score
>
b
.
score
;
}
};
}
// namespace ppspeech
speechx/speechx/decoder/ctc_beam_search_opt.h
浏览文件 @
86eb7189
...
@@ -76,68 +76,4 @@ struct CTCBeamSearchOptions {
...
@@ -76,68 +76,4 @@ struct CTCBeamSearchOptions {
}
}
};
};
// used by u2 model
struct
CTCBeamSearchDecoderOptions
{
// chunk_size is the frame number of one chunk after subsampling.
// e.g. if subsample rate is 4 and chunk_size = 16, the frames in
// one chunk are 67=16*4 + 3, stride is 64=16*4
int
chunk_size
;
int
num_left_chunks
;
// final_score = rescoring_weight * rescoring_score + ctc_weight *
// ctc_score;
// rescoring_score = left_to_right_score * (1 - reverse_weight) +
// right_to_left_score * reverse_weight
// Please note the concept of ctc_scores
// in the following two search methods are different. For
// CtcPrefixBeamSerch,
// it's a sum(prefix) score + context score For CtcWfstBeamSerch, it's a
// max(viterbi) path score + context score So we should carefully set
// ctc_weight accroding to the search methods.
float
ctc_weight
;
float
rescoring_weight
;
float
reverse_weight
;
// CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions
ctc_prefix_search_opts
;
CTCBeamSearchDecoderOptions
()
:
chunk_size
(
16
),
num_left_chunks
(
-
1
),
ctc_weight
(
0.5
),
rescoring_weight
(
1.0
),
reverse_weight
(
0.0
)
{}
void
Register
(
kaldi
::
OptionsItf
*
opts
)
{
std
::
string
module
=
"DecoderConfig: "
;
opts
->
Register
(
"chunk-size"
,
&
chunk_size
,
module
+
"the frame number of one chunk after subsampling."
);
opts
->
Register
(
"num-left-chunks"
,
&
num_left_chunks
,
module
+
"the left history chunks number."
);
opts
->
Register
(
"ctc-weight"
,
&
ctc_weight
,
module
+
"ctc weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score."
);
opts
->
Register
(
"rescoring-weight"
,
&
rescoring_weight
,
module
+
"attention score weight for rescore. final_score = "
"rescoring_weight * rescoring_score + ctc_weight * "
"ctc_score."
);
opts
->
Register
(
"reverse-weight"
,
&
reverse_weight
,
module
+
"reverse decoder weight. rescoring_score = "
"left_to_right_score * (1 - reverse_weight) + "
"right_to_left_score * reverse_weight."
);
}
};
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
浏览文件 @
86eb7189
...
@@ -30,8 +30,14 @@ using paddle::platform::TracerEventType;
...
@@ -30,8 +30,14 @@ using paddle::platform::TracerEventType;
namespace
ppspeech
{
namespace
ppspeech
{
CTCPrefixBeamSearch
::
CTCPrefixBeamSearch
(
const
CTCBeamSearchOptions
&
opts
)
CTCPrefixBeamSearch
::
CTCPrefixBeamSearch
(
const
std
::
string
vocab_path
,
const
CTCBeamSearchOptions
&
opts
)
:
opts_
(
opts
)
{
:
opts_
(
opts
)
{
unit_table_
=
std
::
shared_ptr
<
fst
::
SymbolTable
>
(
fst
::
SymbolTable
::
ReadText
(
vocab_path
));
CHECK
(
unit_table_
!=
nullptr
);
Reset
();
Reset
();
}
}
...
@@ -322,7 +328,11 @@ void CTCPrefixBeamSearch::UpdateFinalContext() {
...
@@ -322,7 +328,11 @@ void CTCPrefixBeamSearch::UpdateFinalContext() {
CHECK
(
n_hyps
>
0
);
CHECK
(
n_hyps
>
0
);
CHECK
(
index
<
n_hyps
);
CHECK
(
index
<
n_hyps
);
std
::
vector
<
int
>
one
=
Outputs
()[
index
];
std
::
vector
<
int
>
one
=
Outputs
()[
index
];
return
std
::
string
(
absl
::
StrJoin
(
one
,
kSpaceSymbol
));
std
::
string
sentence
;
for
(
int
i
=
0
;
i
<
one
.
size
();
i
++
){
sentence
+=
unit_table_
->
Find
(
one
[
i
]);
}
return
sentence
;
}
}
std
::
string
CTCPrefixBeamSearch
::
GetBestPath
()
{
std
::
string
CTCPrefixBeamSearch
::
GetBestPath
()
{
...
...
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
浏览文件 @
86eb7189
...
@@ -15,17 +15,21 @@
...
@@ -15,17 +15,21 @@
#pragma once
#pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_result.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/decoder_itf.h"
#include "decoder/decoder_itf.h"
#include "fst/symbol-table.h"
namespace
ppspeech
{
namespace
ppspeech
{
class
ContextGraph
;
class
ContextGraph
;
class
CTCPrefixBeamSearch
:
public
DecoderInterface
{
class
CTCPrefixBeamSearch
:
public
DecoderInterface
{
public:
public:
explicit
CTCPrefixBeamSearch
(
const
CTCBeamSearchOptions
&
opts
);
explicit
CTCPrefixBeamSearch
(
const
std
::
string
vocab_path
,
const
CTCBeamSearchOptions
&
opts
);
~
CTCPrefixBeamSearch
()
{}
~
CTCPrefixBeamSearch
()
{}
SearchType
Type
()
const
{
return
SearchType
::
kPrefixBeamSearch
;
}
void
InitDecoder
()
override
;
void
InitDecoder
()
override
;
void
Reset
()
override
;
void
Reset
()
override
;
...
@@ -38,10 +42,9 @@ class CTCPrefixBeamSearch : public DecoderInterface {
...
@@ -38,10 +42,9 @@ class CTCPrefixBeamSearch : public DecoderInterface {
void
FinalizeSearch
();
void
FinalizeSearch
();
protected:
const
std
::
shared_ptr
<
fst
::
SymbolTable
>
VocabTable
()
const
{
std
::
string
GetBestPath
()
override
;
return
unit_table_
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
override
;
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
(
int
n
)
override
;
const
std
::
vector
<
std
::
vector
<
int
>>&
Inputs
()
const
{
return
hypotheses_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Inputs
()
const
{
return
hypotheses_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Outputs
()
const
{
return
outputs_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Outputs
()
const
{
return
outputs_
;
}
...
@@ -52,6 +55,11 @@ class CTCPrefixBeamSearch : public DecoderInterface {
...
@@ -52,6 +55,11 @@ class CTCPrefixBeamSearch : public DecoderInterface {
const
std
::
vector
<
std
::
vector
<
int
>>&
Times
()
const
{
return
times_
;
}
const
std
::
vector
<
std
::
vector
<
int
>>&
Times
()
const
{
return
times_
;
}
protected:
std
::
string
GetBestPath
()
override
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
()
override
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
GetNBestPath
(
int
n
)
override
;
private:
private:
std
::
string
GetBestPath
(
int
index
);
std
::
string
GetBestPath
(
int
index
);
...
@@ -66,6 +74,7 @@ class CTCPrefixBeamSearch : public DecoderInterface {
...
@@ -66,6 +74,7 @@ class CTCPrefixBeamSearch : public DecoderInterface {
private:
private:
CTCBeamSearchOptions
opts_
;
CTCBeamSearchOptions
opts_
;
std
::
shared_ptr
<
fst
::
SymbolTable
>
unit_table_
;
std
::
unordered_map
<
std
::
vector
<
int
>
,
PrefixScore
,
PrefixScoreHash
>
std
::
unordered_map
<
std
::
vector
<
int
>
,
PrefixScore
,
PrefixScoreHash
>
cur_hyps_
;
cur_hyps_
;
...
@@ -86,28 +95,4 @@ class CTCPrefixBeamSearch : public DecoderInterface {
...
@@ -86,28 +95,4 @@ class CTCPrefixBeamSearch : public DecoderInterface {
};
};
class
CTCPrefixBeamSearchDecoder
:
public
CTCPrefixBeamSearch
{
public:
explicit
CTCPrefixBeamSearchDecoder
(
const
CTCBeamSearchDecoderOptions
&
opts
)
:
CTCPrefixBeamSearch
(
opts
.
ctc_prefix_search_opts
),
opts_
(
opts
)
{}
~
CTCPrefixBeamSearchDecoder
()
{}
private:
CTCBeamSearchDecoderOptions
opts_
;
// cache feature
bool
start_
=
false
;
// false, this is first frame.
// for continues decoding
int
num_frames_
=
0
;
int
global_frame_offset_
=
0
;
const
int
time_stamp_gap_
=
100
;
// timestamp gap between words in a sentence
// std::unique_ptr<CtcEndpoint> ctc_endpointer_;
int
num_frames_in_current_chunk_
=
0
;
std
::
vector
<
DecodeResult
>
result_
;
};
}
// namespace ppspeech
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
浏览文件 @
86eb7189
...
@@ -55,14 +55,12 @@ int main(int argc, char* argv[]) {
...
@@ -55,14 +55,12 @@ int main(int argc, char* argv[]) {
CHECK
(
FLAGS_vocab_path
!=
""
);
CHECK
(
FLAGS_vocab_path
!=
""
);
CHECK
(
FLAGS_model_path
!=
""
);
CHECK
(
FLAGS_model_path
!=
""
);
LOG
(
INFO
)
<<
"model path: "
<<
FLAGS_model_path
;
LOG
(
INFO
)
<<
"model path: "
<<
FLAGS_model_path
;
LOG
(
INFO
)
<<
"Reading vocab table "
<<
FLAGS_vocab_path
;
kaldi
::
SequentialBaseFloatMatrixReader
feature_reader
(
kaldi
::
SequentialBaseFloatMatrixReader
feature_reader
(
FLAGS_feature_rspecifier
);
FLAGS_feature_rspecifier
);
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
kaldi
::
TokenWriter
result_writer
(
FLAGS_result_wspecifier
);
LOG
(
INFO
)
<<
"Reading vocab table "
<<
FLAGS_vocab_path
;
fst
::
SymbolTable
*
unit_table
=
fst
::
SymbolTable
::
ReadText
(
FLAGS_vocab_path
);
// nnet
// nnet
ppspeech
::
ModelOptions
model_opts
;
ppspeech
::
ModelOptions
model_opts
;
model_opts
.
model_path
=
FLAGS_model_path
;
model_opts
.
model_path
=
FLAGS_model_path
;
...
@@ -75,16 +73,11 @@ int main(int argc, char* argv[]) {
...
@@ -75,16 +73,11 @@ int main(int argc, char* argv[]) {
new
ppspeech
::
Decodable
(
nnet
,
raw_data
));
new
ppspeech
::
Decodable
(
nnet
,
raw_data
));
// decoder
// decoder
ppspeech
::
CTCBeamSearchDecoderOptions
opts
;
ppspeech
::
CTCBeamSearchOptions
opts
;
opts
.
chunk_size
=
16
;
opts
.
blank
=
0
;
opts
.
num_left_chunks
=
-
1
;
opts
.
first_beam_size
=
10
;
opts
.
ctc_weight
=
0.5
;
opts
.
second_beam_size
=
10
;
opts
.
rescoring_weight
=
1.0
;
ppspeech
::
CTCPrefixBeamSearch
decoder
(
FLAGS_vocab_path
,
opts
);
opts
.
reverse_weight
=
0.3
;
opts
.
ctc_prefix_search_opts
.
blank
=
0
;
opts
.
ctc_prefix_search_opts
.
first_beam_size
=
10
;
opts
.
ctc_prefix_search_opts
.
second_beam_size
=
10
;
ppspeech
::
CTCPrefixBeamSearchDecoder
decoder
(
opts
);
int32
chunk_size
=
FLAGS_receptive_field_length
+
int32
chunk_size
=
FLAGS_receptive_field_length
+
...
@@ -150,17 +143,14 @@ int main(int argc, char* argv[]) {
...
@@ -150,17 +143,14 @@ int main(int argc, char* argv[]) {
// forward nnet
// forward nnet
decoder
.
AdvanceDecode
(
decodable
);
decoder
.
AdvanceDecode
(
decodable
);
LOG
(
INFO
)
<<
"Partial result: "
<<
decoder
.
GetPartialResult
();
}
}
decoder
.
FinalizeSearch
();
decoder
.
FinalizeSearch
();
// get 1-best result
// get 1-best result
std
::
string
result_ints
=
decoder
.
GetFinalBestPath
();
std
::
string
result
=
decoder
.
GetFinalBestPath
();
std
::
vector
<
std
::
string
>
tokenids
=
absl
::
StrSplit
(
result_ints
,
ppspeech
::
kSpaceSymbol
);
std
::
string
result
;
for
(
int
i
=
0
;
i
<
tokenids
.
size
();
i
++
){
result
+=
unit_table
->
Find
(
std
::
stoi
(
tokenids
[
i
]));
}
// after process one utt, then reset state.
// after process one utt, then reset state.
decodable
->
Reset
();
decodable
->
Reset
();
...
...
speechx/speechx/decoder/ctc_prefix_beam_search_result.h
已删除
100644 → 0
浏览文件 @
7dc9cba3
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
namespace
ppspeech
{
struct
WordPiece
{
std
::
string
word
;
int
start
=
-
1
;
int
end
=
-
1
;
WordPiece
(
std
::
string
word
,
int
start
,
int
end
)
:
word
(
std
::
move
(
word
)),
start
(
start
),
end
(
end
)
{}
};
struct
DecodeResult
{
float
score
=
-
kBaseFloatMax
;
std
::
string
sentence
;
std
::
vector
<
WordPiece
>
word_pieces
;
static
bool
CompareFunc
(
const
DecodeResult
&
a
,
const
DecodeResult
&
b
)
{
return
a
.
score
>
b
.
score
;
}
};
}
// namespace ppspeech
speechx/speechx/decoder/decoder_itf.h
浏览文件 @
86eb7189
...
@@ -20,6 +20,10 @@
...
@@ -20,6 +20,10 @@
namespace
ppspeech
{
namespace
ppspeech
{
enum
SearchType
{
kPrefixBeamSearch
=
0
,
kWfstBeamSearch
=
1
,
};
class
DecoderInterface
{
class
DecoderInterface
{
public:
public:
virtual
~
DecoderInterface
()
{}
virtual
~
DecoderInterface
()
{}
...
...
speechx/speechx/decoder/param.h
浏览文件 @
86eb7189
...
@@ -19,12 +19,15 @@
...
@@ -19,12 +19,15 @@
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
#include "frontend/audio/feature_pipeline.h"
// feature
// feature
DEFINE_bool
(
use_fbank
,
false
,
"False for fbank; or linear feature"
);
DEFINE_bool
(
use_fbank
,
false
,
"False for fbank; or linear feature"
);
// DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear
// DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear
// feature, or fbank");
// feature, or fbank");
DEFINE_int32
(
num_bins
,
161
,
"num bins of mel"
);
DEFINE_int32
(
num_bins
,
161
,
"num bins of mel"
);
DEFINE_string
(
cmvn_file
,
""
,
"read cmvn"
);
DEFINE_string
(
cmvn_file
,
""
,
"read cmvn"
);
// feature sliding window
// feature sliding window
DEFINE_int32
(
receptive_field_length
,
DEFINE_int32
(
receptive_field_length
,
7
,
7
,
...
@@ -33,6 +36,8 @@ DEFINE_int32(downsampling_rate,
...
@@ -33,6 +36,8 @@ DEFINE_int32(downsampling_rate,
4
,
4
,
"two CNN(kernel=3) module downsampling rate."
);
"two CNN(kernel=3) module downsampling rate."
);
DEFINE_int32
(
nnet_decoder_chunk
,
1
,
"paddle nnet forward chunk"
);
DEFINE_int32
(
nnet_decoder_chunk
,
1
,
"paddle nnet forward chunk"
);
// nnet
// nnet
DEFINE_string
(
model_path
,
"avg_1.jit.pdmodel"
,
"paddle nnet model"
);
DEFINE_string
(
model_path
,
"avg_1.jit.pdmodel"
,
"paddle nnet model"
);
DEFINE_string
(
param_path
,
"avg_1.jit.pdiparams"
,
"paddle nnet model param"
);
DEFINE_string
(
param_path
,
"avg_1.jit.pdiparams"
,
"paddle nnet model param"
);
...
@@ -89,34 +94,4 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
...
@@ -89,34 +94,4 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
return
opts
;
return
opts
;
}
}
ModelOptions
InitModelOptions
()
{
ModelOptions
model_opts
;
model_opts
.
model_path
=
FLAGS_model_path
;
model_opts
.
param_path
=
FLAGS_param_path
;
model_opts
.
cache_names
=
FLAGS_model_cache_names
;
model_opts
.
cache_shape
=
FLAGS_model_cache_shapes
;
model_opts
.
input_names
=
FLAGS_model_input_names
;
model_opts
.
output_names
=
FLAGS_model_output_names
;
return
model_opts
;
}
TLGDecoderOptions
InitDecoderOptions
()
{
TLGDecoderOptions
decoder_opts
;
decoder_opts
.
word_symbol_table
=
FLAGS_word_symbol_table
;
decoder_opts
.
fst_path
=
FLAGS_graph_path
;
decoder_opts
.
opts
.
max_active
=
FLAGS_max_active
;
decoder_opts
.
opts
.
beam
=
FLAGS_beam
;
decoder_opts
.
opts
.
lattice_beam
=
FLAGS_lattice_beam
;
return
decoder_opts
;
}
RecognizerResource
InitRecognizerResoure
()
{
RecognizerResource
resource
;
resource
.
acoustic_scale
=
FLAGS_acoustic_scale
;
resource
.
feature_pipeline_opts
=
InitFeaturePipelineOptions
();
resource
.
model_opts
=
InitModelOptions
();
resource
.
tlg_opts
=
InitDecoderOptions
();
return
resource
;
}
}
// namespace ppspeech
}
// namespace ppspeech
speechx/speechx/decoder/recognizer.cc
浏览文件 @
86eb7189
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "decoder/recognizer.h"
#include "decoder/recognizer.h"
namespace
ppspeech
{
namespace
ppspeech
{
using
kaldi
::
Vector
;
using
kaldi
::
Vector
;
...
@@ -23,14 +24,19 @@ using std::vector;
...
@@ -23,14 +24,19 @@ using std::vector;
using
kaldi
::
SubVector
;
using
kaldi
::
SubVector
;
using
std
::
unique_ptr
;
using
std
::
unique_ptr
;
Recognizer
::
Recognizer
(
const
RecognizerResource
&
resource
)
{
Recognizer
::
Recognizer
(
const
RecognizerResource
&
resource
)
{
// resource_ = resource;
// resource_ = resource;
const
FeaturePipelineOptions
&
feature_opts
=
resource
.
feature_pipeline_opts
;
const
FeaturePipelineOptions
&
feature_opts
=
resource
.
feature_pipeline_opts
;
feature_pipeline_
.
reset
(
new
FeaturePipeline
(
feature_opts
));
feature_pipeline_
.
reset
(
new
FeaturePipeline
(
feature_opts
));
std
::
shared_ptr
<
PaddleNnet
>
nnet
(
new
PaddleNnet
(
resource
.
model_opts
));
std
::
shared_ptr
<
PaddleNnet
>
nnet
(
new
PaddleNnet
(
resource
.
model_opts
));
BaseFloat
ac_scale
=
resource
.
acoustic_scale
;
BaseFloat
ac_scale
=
resource
.
acoustic_scale
;
decodable_
.
reset
(
new
Decodable
(
nnet
,
feature_pipeline_
,
ac_scale
));
decodable_
.
reset
(
new
Decodable
(
nnet
,
feature_pipeline_
,
ac_scale
));
decoder_
.
reset
(
new
TLGDecoder
(
resource
.
tlg_opts
));
decoder_
.
reset
(
new
TLGDecoder
(
resource
.
tlg_opts
));
input_finished_
=
false
;
input_finished_
=
false
;
}
}
...
...
speechx/speechx/decoder/recognizer.h
浏览文件 @
86eb7189
...
@@ -25,16 +25,11 @@
...
@@ -25,16 +25,11 @@
namespace
ppspeech
{
namespace
ppspeech
{
struct
RecognizerResource
{
struct
RecognizerResource
{
FeaturePipelineOptions
feature_pipeline_opts
;
FeaturePipelineOptions
feature_pipeline_opts
{}
;
ModelOptions
model_opts
;
ModelOptions
model_opts
{}
;
TLGDecoderOptions
tlg_opts
;
TLGDecoderOptions
tlg_opts
{}
;
// CTCBeamSearchOptions beam_search_opts;
// CTCBeamSearchOptions beam_search_opts;
kaldi
::
BaseFloat
acoustic_scale
;
kaldi
::
BaseFloat
acoustic_scale
{
1.0
};
RecognizerResource
()
:
acoustic_scale
(
1.0
),
feature_pipeline_opts
(),
model_opts
(),
tlg_opts
()
{}
};
};
class
Recognizer
{
class
Recognizer
{
...
...
speechx/speechx/decoder/recognizer_main.cc
浏览文件 @
86eb7189
...
@@ -22,6 +22,33 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
...
@@ -22,6 +22,33 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double
(
streaming_chunk
,
0.36
,
"streaming feature chunk size"
);
DEFINE_double
(
streaming_chunk
,
0.36
,
"streaming feature chunk size"
);
DEFINE_int32
(
sample_rate
,
16000
,
"sample rate"
);
DEFINE_int32
(
sample_rate
,
16000
,
"sample rate"
);
ppspeech
::
RecognizerResource
InitRecognizerResoure
()
{
ppspeech
::
RecognizerResource
resource
;
resource
.
acoustic_scale
=
FLAGS_acoustic_scale
;
resource
.
feature_pipeline_opts
=
ppspeech
::
InitFeaturePipelineOptions
();
ppspeech
::
ModelOptions
model_opts
;
model_opts
.
model_path
=
FLAGS_model_path
;
model_opts
.
param_path
=
FLAGS_param_path
;
model_opts
.
cache_names
=
FLAGS_model_cache_names
;
model_opts
.
cache_shape
=
FLAGS_model_cache_shapes
;
model_opts
.
input_names
=
FLAGS_model_input_names
;
model_opts
.
output_names
=
FLAGS_model_output_names
;
model_opts
.
subsample_rate
=
FLAGS_downsampling_rate
;
resource
.
model_opts
=
model_opts
;
ppspeech
::
TLGDecoderOptions
decoder_opts
;
decoder_opts
.
word_symbol_table
=
FLAGS_word_symbol_table
;
decoder_opts
.
fst_path
=
FLAGS_graph_path
;
decoder_opts
.
opts
.
max_active
=
FLAGS_max_active
;
decoder_opts
.
opts
.
beam
=
FLAGS_beam
;
decoder_opts
.
opts
.
lattice_beam
=
FLAGS_lattice_beam
;
resource
.
tlg_opts
=
decoder_opts
;
return
resource
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
SetUsageMessage
(
"Usage:"
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
...
@@ -29,7 +56,7 @@ int main(int argc, char* argv[]) {
...
@@ -29,7 +56,7 @@ int main(int argc, char* argv[]) {
google
::
InstallFailureSignalHandler
();
google
::
InstallFailureSignalHandler
();
FLAGS_logtostderr
=
1
;
FLAGS_logtostderr
=
1
;
ppspeech
::
RecognizerResource
resource
=
ppspeech
::
InitRecognizerResoure
();
ppspeech
::
RecognizerResource
resource
=
InitRecognizerResoure
();
ppspeech
::
Recognizer
recognizer
(
resource
);
ppspeech
::
Recognizer
recognizer
(
resource
);
kaldi
::
SequentialTableReader
<
kaldi
::
WaveHolder
>
wav_reader
(
kaldi
::
SequentialTableReader
<
kaldi
::
WaveHolder
>
wav_reader
(
...
...
speechx/speechx/decoder/u2_recognizer.cc
0 → 100644
浏览文件 @
86eb7189
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/u2_recognizer.h"
#include "nnet/u2_nnet.h"
namespace
ppspeech
{
using
kaldi
::
Vector
;
using
kaldi
::
VectorBase
;
using
kaldi
::
BaseFloat
;
using
std
::
vector
;
using
kaldi
::
SubVector
;
using
std
::
unique_ptr
;
U2Recognizer
::
U2Recognizer
(
const
U2RecognizerResource
&
resource
)
:
opts_
(
resource
)
{
const
FeaturePipelineOptions
&
feature_opts
=
resource
.
feature_pipeline_opts
;
feature_pipeline_
.
reset
(
new
FeaturePipeline
(
feature_opts
));
std
::
shared_ptr
<
NnetInterface
>
nnet
(
new
U2Nnet
(
resource
.
model_opts
));
BaseFloat
am_scale
=
resource
.
acoustic_scale
;
decodable_
.
reset
(
new
Decodable
(
nnet
,
feature_pipeline_
,
am_scale
));
decoder_
.
reset
(
new
CTCPrefixBeamSearch
(
resource
.
vocab_path
,
resource
.
decoder_opts
.
ctc_prefix_search_opts
));
unit_table_
=
decoder_
->
VocabTable
();
symbol_table_
=
unit_table_
;
input_finished_
=
false
;
}
void
U2Recognizer
::
Reset
()
{
global_frame_offset_
=
0
;
num_frames_
=
0
;
result_
.
clear
();
feature_pipeline_
->
Reset
();
decodable_
->
Reset
();
decoder_
->
Reset
();
}
void
U2Recognizer
::
ResetContinuousDecoding
()
{
global_frame_offset_
=
num_frames_
;
num_frames_
=
0
;
result_
.
clear
();
feature_pipeline_
->
Reset
();
decodable_
->
Reset
();
decoder_
->
Reset
();
}
void
U2Recognizer
::
Accept
(
const
VectorBase
<
BaseFloat
>&
waves
)
{
feature_pipeline_
->
Accept
(
waves
);
}
void
U2Recognizer
::
Decode
()
{
decoder_
->
AdvanceDecode
(
decodable_
);
}
void
U2Recognizer
::
Rescoring
()
{
// Do attention Rescoring
kaldi
::
Timer
timer
;
AttentionRescoring
();
VLOG
(
1
)
<<
"Rescoring cost latency: "
<<
timer
.
Elapsed
()
<<
" sec."
;
}
void
U2Recognizer
::
UpdateResult
(
bool
finish
)
{
const
auto
&
hypotheses
=
decoder_
->
Outputs
();
const
auto
&
inputs
=
decoder_
->
Inputs
();
const
auto
&
likelihood
=
decoder_
->
Likelihood
();
const
auto
&
times
=
decoder_
->
Times
();
result_
.
clear
();
CHECK_EQ
(
hypotheses
.
size
(),
likelihood
.
size
());
for
(
size_t
i
=
0
;
i
<
hypotheses
.
size
();
i
++
)
{
const
std
::
vector
<
int
>&
hypothesis
=
hypotheses
[
i
];
DecodeResult
path
;
path
.
score
=
likelihood
[
i
];
for
(
size_t
j
=
0
;
j
<
hypothesis
.
size
();
j
++
)
{
std
::
string
word
=
symbol_table_
->
Find
(
hypothesis
[
j
]);
// A detailed explanation of this if-else branch can be found in
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
if
(
decoder_
->
Type
()
==
kWfstBeamSearch
)
{
path
.
sentence
+=
(
" "
+
word
);
}
else
{
path
.
sentence
+=
(
word
);
}
}
// TimeStamp is only supported in final result
// TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
// various FST operations when building the decoding graph. So here we use
// time stamp of the input(e2e model unit), which is more accurate, and it
// requires the symbol table of the e2e model used in training.
if
(
unit_table_
!=
nullptr
&&
finish
)
{
int
offset
=
global_frame_offset_
*
FrameShiftInMs
();
const
std
::
vector
<
int
>&
input
=
inputs
[
i
];
const
std
::
vector
<
int
>
time_stamp
=
times
[
i
];
CHECK_EQ
(
input
.
size
(),
time_stamp
.
size
());
for
(
size_t
j
=
0
;
j
<
input
.
size
();
j
++
)
{
std
::
string
word
=
unit_table_
->
Find
(
input
[
j
]);
int
start
=
time_stamp
[
j
]
*
FrameShiftInMs
()
-
time_stamp_gap_
>
0
?
time_stamp
[
j
]
*
FrameShiftInMs
()
-
time_stamp_gap_
:
0
;
if
(
j
>
0
)
{
start
=
(
time_stamp
[
j
]
-
time_stamp
[
j
-
1
])
*
FrameShiftInMs
()
<
time_stamp_gap_
?
(
time_stamp
[
j
-
1
]
+
time_stamp
[
j
])
/
2
*
FrameShiftInMs
()
:
start
;
}
int
end
=
time_stamp
[
j
]
*
FrameShiftInMs
();
if
(
j
<
input
.
size
()
-
1
)
{
end
=
(
time_stamp
[
j
+
1
]
-
time_stamp
[
j
])
*
FrameShiftInMs
()
<
time_stamp_gap_
?
(
time_stamp
[
j
+
1
]
+
time_stamp
[
j
])
/
2
*
FrameShiftInMs
()
:
end
;
}
WordPiece
word_piece
(
word
,
offset
+
start
,
offset
+
end
);
path
.
word_pieces
.
emplace_back
(
word_piece
);
}
}
// if (post_processor_ != nullptr) {
// path.sentence = post_processor_->Process(path.sentence, finish);
// }
result_
.
emplace_back
(
path
);
}
if
(
DecodedSomething
())
{
VLOG
(
1
)
<<
"Partial CTC result "
<<
result_
[
0
].
sentence
;
}
}
// Second-pass decoding: rescore the first-pass n-best CTC hypotheses with the
// attention decoder, combine the scores, and re-rank `result_`.
void U2Recognizer::AttentionRescoring() {
    decoder_->FinalizeSearch();
    UpdateResult(true);

    // No need to do rescoring
    if (0.0 == opts_.decoder_opts.rescoring_weight) {
        LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
        return;
    }
    LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";

    // Inputs() returns N-best input ids, which is the basic unit for rescoring
    // In CtcPrefixBeamSearch, inputs are the same to outputs
    const auto& hypotheses = decoder_->Inputs();
    // was: int num_hyps — use size_t to avoid a signed/unsigned comparison in
    // the loop below.
    const size_t num_hyps = hypotheses.size();
    if (num_hyps == 0) {
        return;
    }

    kaldi::Timer timer;
    std::vector<float> rescoring_score;
    decodable_->AttentionRescoring(
        hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
    VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec.";

    // combine ctc score and rescoring score
    // NOTE(review): assumes result_ and rescoring_score both have at least
    // num_hyps entries (UpdateResult builds result_ from the same n-best) —
    // TODO confirm against UpdateResult.
    for (size_t i = 0; i < num_hyps; ++i) {
        VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
                << " ctc_score: " << result_[i].score;
        // final_score = rescoring_weight * rescoring_score + ctc_weight *
        // ctc_score
        result_[i].score =
            opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
            opts_.decoder_opts.ctc_weight * result_[i].score;
    }

    // best (highest combined score) hypothesis first
    std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
    VLOG(1) << "result: " << result_[0].sentence
            << " score: " << result_[0].score;
}
// Return the best (rank-0) hypothesis after decoding/rescoring.
// Guarded: the original indexed result_[0] unconditionally, which is
// undefined behavior when called before anything was decoded.
std::string U2Recognizer::GetFinalResult() {
    return result_.empty() ? std::string() : result_[0].sentence;
}
// Return the current best partial hypothesis during streaming decoding.
// Guarded: the original indexed result_[0] unconditionally, which is
// undefined behavior when called before anything was decoded.
std::string U2Recognizer::GetPartialResult() {
    return result_.empty() ? std::string() : result_[0].sentence;
}
// Signal end-of-input: mark the recognizer finished and propagate the
// end-of-stream condition to the feature frontend.
void U2Recognizer::SetFinished() {
    input_finished_ = true;
    feature_pipeline_->SetFinished();
}
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/decoder/u2_recognizer.h
0 → 100644
浏览文件 @
86eb7189
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "decoder/decoder_itf.h"
#include "frontend/audio/feature_pipeline.h"
#include "nnet/decodable.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
namespace
ppspeech
{
// Options for two-pass U2 decoding: first-pass streaming CTC prefix beam
// search, second-pass attention rescoring.
struct DecodeOptions {
    // chunk_size is the frame number of one chunk after subsampling.
    // e.g. if subsample rate is 4 and chunk_size = 16, the frames in
    // one chunk are 67=16*4 + 3, stride is 64=16*4
    int chunk_size{16};
    int num_left_chunks{-1};

    // final_score = rescoring_weight * rescoring_score + ctc_weight *
    // ctc_score;
    // rescoring_score = left_to_right_score * (1 - reverse_weight) +
    // right_to_left_score * reverse_weight
    // Please note the concept of ctc_scores
    // in the following two search methods are different. For
    // CtcPrefixBeamSerch,
    // it's a sum(prefix) score + context score For CtcWfstBeamSerch, it's a
    // max(viterbi) path score + context score So we should carefully set
    // ctc_weight accroding to the search methods.
    float ctc_weight{0.5};
    float rescoring_weight{1.0};
    float reverse_weight{0.0};

    // CtcEndpointConfig ctc_endpoint_opts;
    CTCBeamSearchOptions ctc_prefix_search_opts{};

    // No user-declared constructor needed: defaults live in the in-class
    // member initializers above (Rule of Zero).

    // Hook the options into kaldi's command-line option parsing.
    void Register(kaldi::OptionsItf* opts) {
        const std::string module = "DecoderConfig: ";
        opts->Register(
            "chunk-size",
            &chunk_size,
            module + "the frame number of one chunk after subsampling.");
        opts->Register("num-left-chunks",
                       &num_left_chunks,
                       module + "the left history chunks number.");
        opts->Register("ctc-weight",
                       &ctc_weight,
                       module +
                           "ctc weight for rescore. final_score = "
                           "rescoring_weight * rescoring_score + ctc_weight * "
                           "ctc_score.");
        opts->Register("rescoring-weight",
                       &rescoring_weight,
                       module +
                           "attention score weight for rescore. final_score = "
                           "rescoring_weight * rescoring_score + ctc_weight * "
                           "ctc_score.");
        opts->Register("reverse-weight",
                       &reverse_weight,
                       module +
                           "reverse decoder weight. rescoring_score = "
                           "left_to_right_score * (1 - reverse_weight) + "
                           "right_to_left_score * reverse_weight.");
    }
};
// Aggregated configuration bundle used to construct a U2Recognizer.
struct U2RecognizerResource {
    // Frontend (feature pipeline) configuration.
    FeaturePipelineOptions feature_pipeline_opts{};
    // Nnet (inference model) configuration.
    ModelOptions model_opts{};
    // Two-pass decoding configuration (CTC prefix search + rescoring).
    DecodeOptions decoder_opts{};
    // CTCBeamSearchOptions beam_search_opts;
    // Scale applied to acoustic scores; 1.0 means no scaling.
    kaldi::BaseFloat acoustic_scale{1.0};
    // Path to the vocabulary / unit symbol table file
    // (presumably loaded into unit_table_ — confirm in the constructor).
    std::string vocab_path{};
};
// Streaming U2 (conformer/transformer) speech recognizer:
// audio -> feature pipeline -> nnet -> CTC prefix beam search (first pass)
// -> attention rescoring (second pass).
class U2Recognizer {
  public:
    explicit U2Recognizer(const U2RecognizerResource& resouce);

    void Reset();
    void ResetContinuousDecoding();

    // Feed a chunk of raw audio samples into the feature pipeline.
    void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);

    // First-pass streaming decoding over the features accepted so far.
    void Decode();
    // Second-pass attention rescoring of the n-best hypotheses.
    void Rescoring();

    std::string GetFinalResult();
    std::string GetPartialResult();

    void SetFinished();
    bool IsFinished() { return input_finished_; }

    // True once at least one non-empty hypothesis has been produced.
    bool DecodedSomething() const {
        return !result_.empty() && !result_[0].sentence.empty();
    }

    // one decoder frame length in ms
    // NOTE(review): FrameShift() is a BaseFloat (frame_shift_ms); the product
    // is truncated to int here — confirm the shift is a whole number of ms.
    int FrameShiftInMs() const {
        return decodable_->Nnet()->SubsamplingRate() *
               feature_pipeline_->FrameShift();
    }

    const std::vector<DecodeResult>& Result() const { return result_; }

  private:
    void AttentionRescoring();
    void UpdateResult(bool finish = false);

  private:
    U2RecognizerResource opts_;

    // std::shared_ptr<U2RecognizerResource> resource_;
    // U2RecognizerResource resource_;

    std::shared_ptr<FeaturePipeline> feature_pipeline_;
    std::shared_ptr<Decodable> decodable_;
    std::unique_ptr<CTCPrefixBeamSearch> decoder_;

    // e2e unit symbol table
    std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
    std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;

    std::vector<DecodeResult> result_;

    // global decoded frame offset (was uninitialized; default to 0 so the
    // object is well-defined even before Reset() runs)
    int global_frame_offset_ = 0;
    // cur decoded frame num (was uninitialized)
    int num_frames_ = 0;
    // timestamp gap between words in a sentence
    const int time_stamp_gap_ = 100;

    // was uninitialized; a fresh recognizer has not seen end-of-input
    bool input_finished_ = false;
};
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/decoder/u2_recognizer_main.cc
0 → 100644
浏览文件 @
86eb7189
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/u2_recognizer.h"
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
// Command-line flags for the U2 recognizer example binary.
// Kaldi rspecifier for input waveforms (e.g. "scp:wav.scp").
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
// Kaldi wspecifier for recognition results (e.g. "ark,t:result.txt").
DEFINE_string(result_wspecifier, "", "test result wspecifier");
// Size of each streaming audio chunk, in seconds.
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
// Expected audio sample rate in Hz.
DEFINE_int32(sample_rate, 16000, "sample rate");
// Assemble a U2RecognizerResource from command-line flags plus the
// hard-coded decoder settings used by this example binary.
ppspeech::U2RecognizerResource InitOpts() {
    ppspeech::U2RecognizerResource resource;

    resource.acoustic_scale = FLAGS_acoustic_scale;
    resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();

    // nnet model: fill the already-default-constructed member in place
    // instead of going through a temporary.
    resource.model_opts.model_path = FLAGS_model_path;

    // two-pass decoding options
    ppspeech::DecodeOptions& decoder_opts = resource.decoder_opts;
    decoder_opts.chunk_size = 16;
    decoder_opts.num_left_chunks = -1;
    decoder_opts.ctc_weight = 0.5;
    decoder_opts.rescoring_weight = 1.0;
    decoder_opts.reverse_weight = 0.3;
    decoder_opts.ctc_prefix_search_opts.blank = 0;
    decoder_opts.ctc_prefix_search_opts.first_beam_size = 10;
    decoder_opts.ctc_prefix_search_opts.second_beam_size = 10;

    return resource;
}
// Example driver: read waveforms from a Kaldi table, stream them chunk by
// chunk through the U2 recognizer, rescore, and write the best hypothesis.
int main(int argc, char* argv[]) {
    gflags::SetUsageMessage("Usage:");
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);
    google::InstallFailureSignalHandler();
    FLAGS_logtostderr = 1;

    int32 num_done = 0, num_err = 0;
    double tot_wav_duration = 0.0;

    ppspeech::U2RecognizerResource resource = InitOpts();
    ppspeech::U2Recognizer recognizer(resource);

    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
        FLAGS_wav_rspecifier);
    kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);

    int sample_rate = FLAGS_sample_rate;
    float streaming_chunk = FLAGS_streaming_chunk;
    int chunk_sample_size = streaming_chunk * sample_rate;
    LOG(INFO) << "sr: " << sample_rate;
    LOG(INFO) << "chunk size (s): " << streaming_chunk;
    LOG(INFO) << "chunk size (sample): " << chunk_sample_size;

    kaldi::Timer timer;

    for (; !wav_reader.Done(); wav_reader.Next()) {
        std::string utt = wav_reader.Key();
        const kaldi::WaveData& wave_data = wav_reader.Value();
        LOG(INFO) << "utt: " << utt;
        LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
        tot_wav_duration += wave_data.Duration();

        int32 this_channel = 0;
        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
                                                    this_channel);
        int tot_samples = waveform.Dim();
        LOG(INFO) << "wav len (sample): " << tot_samples;

        int sample_offset = 0;
        while (sample_offset < tot_samples) {
            int cur_chunk_size =
                std::min(chunk_sample_size, tot_samples - sample_offset);

            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
            for (int i = 0; i < cur_chunk_size; ++i) {
                wav_chunk(i) = waveform(sample_offset + i);
            }
            // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);

            recognizer.Accept(wav_chunk);

            // BUGFIX: the original condition was
            //   cur_chunk_size < chunk_sample_size
            // which never fires when the wav length is an exact multiple of
            // the chunk size, so SetFinished() was skipped and the last chunk
            // decoded as non-final. Detect the true last chunk instead.
            if (sample_offset + cur_chunk_size == tot_samples) {
                recognizer.SetFinished();
            }
            recognizer.Decode();
            // (typo fix: was "Pratial result")
            LOG(INFO) << "Partial result: " << recognizer.GetPartialResult();

            // no overlap
            sample_offset += cur_chunk_size;
        }

        // second pass decoding
        recognizer.Rescoring();
        std::string result = recognizer.GetFinalResult();
        recognizer.Reset();

        if (result.empty()) {
            // the TokenWriter can not write empty string.
            ++num_err;
            LOG(INFO) << " the result of " << utt << " is empty";
            continue;
        }

        LOG(INFO) << " the result of " << utt << " is " << result;
        result_writer.Write(utt, result);
        ++num_done;
    }

    double elapsed = timer.Elapsed();
    LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
    LOG(INFO) << "cost:" << elapsed << " sec";
    LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
    // NOTE(review): divides by zero if no wav was read — acceptable for an
    // example binary, but worth guarding if reused.
    LOG(INFO) << "the RTF is: " << elapsed / tot_wav_duration;
}
speechx/speechx/frontend/audio/feature_pipeline.cc
浏览文件 @
86eb7189
...
@@ -18,7 +18,7 @@ namespace ppspeech {
...
@@ -18,7 +18,7 @@ namespace ppspeech {
using
std
::
unique_ptr
;
using
std
::
unique_ptr
;
FeaturePipeline
::
FeaturePipeline
(
const
FeaturePipelineOptions
&
opts
)
{
FeaturePipeline
::
FeaturePipeline
(
const
FeaturePipelineOptions
&
opts
)
:
opts_
(
opts
)
{
unique_ptr
<
FrontendInterface
>
data_source
(
unique_ptr
<
FrontendInterface
>
data_source
(
new
ppspeech
::
AudioCache
(
1000
*
kint16max
,
opts
.
to_float32
));
new
ppspeech
::
AudioCache
(
1000
*
kint16max
,
opts
.
to_float32
));
...
...
speechx/speechx/frontend/audio/feature_pipeline.h
浏览文件 @
86eb7189
...
@@ -26,7 +26,6 @@
...
@@ -26,7 +26,6 @@
#include "frontend/audio/normalizer.h"
#include "frontend/audio/normalizer.h"
namespace
ppspeech
{
namespace
ppspeech
{
struct
FeaturePipelineOptions
{
struct
FeaturePipelineOptions
{
std
::
string
cmvn_file
;
std
::
string
cmvn_file
;
bool
to_float32
;
// true, only for linear feature
bool
to_float32
;
// true, only for linear feature
...
@@ -60,7 +59,21 @@ class FeaturePipeline : public FrontendInterface {
...
@@ -60,7 +59,21 @@ class FeaturePipeline : public FrontendInterface {
virtual
bool
IsFinished
()
const
{
return
base_extractor_
->
IsFinished
();
}
virtual
bool
IsFinished
()
const
{
return
base_extractor_
->
IsFinished
();
}
virtual
void
Reset
()
{
base_extractor_
->
Reset
();
}
virtual
void
Reset
()
{
base_extractor_
->
Reset
();
}
const
FeaturePipelineOptions
&
Config
()
{
return
opts_
;
}
const
BaseFloat
FrameShift
()
const
{
return
opts_
.
fbank_opts
.
frame_opts
.
frame_shift_ms
;
}
const
BaseFloat
FrameLength
()
const
{
return
opts_
.
fbank_opts
.
frame_opts
.
frame_length_ms
;
}
const
BaseFloat
SampleRate
()
const
{
return
opts_
.
fbank_opts
.
frame_opts
.
samp_freq
;
}
private:
private:
FeaturePipelineOptions
opts_
;
std
::
unique_ptr
<
FrontendInterface
>
base_extractor_
;
std
::
unique_ptr
<
FrontendInterface
>
base_extractor_
;
};
};
}
}
// namespace ppspeech
speechx/speechx/nnet/ds2_nnet.cc
浏览文件 @
86eb7189
...
@@ -48,6 +48,7 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
...
@@ -48,6 +48,7 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
}
}
PaddleNnet
::
PaddleNnet
(
const
ModelOptions
&
opts
)
:
opts_
(
opts
)
{
PaddleNnet
::
PaddleNnet
(
const
ModelOptions
&
opts
)
:
opts_
(
opts
)
{
subsampling_rate_
=
opts
.
subsample_rate
;
paddle_infer
::
Config
config
;
paddle_infer
::
Config
config
;
config
.
SetModel
(
opts
.
model_path
,
opts
.
param_path
);
config
.
SetModel
(
opts
.
model_path
,
opts
.
param_path
);
if
(
opts
.
use_gpu
)
{
if
(
opts
.
use_gpu
)
{
...
...
speechx/speechx/nnet/ds2_nnet.h
浏览文件 @
86eb7189
...
@@ -67,6 +67,7 @@ class PaddleNnet : public NnetInterface {
...
@@ -67,6 +67,7 @@ class PaddleNnet : public NnetInterface {
bool
IsLogProb
()
override
{
return
false
;
}
bool
IsLogProb
()
override
{
return
false
;
}
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>
GetCacheEncoder
(
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>
GetCacheEncoder
(
const
std
::
string
&
name
);
const
std
::
string
&
name
);
...
@@ -85,6 +86,7 @@ class PaddleNnet : public NnetInterface {
...
@@ -85,6 +86,7 @@ class PaddleNnet : public NnetInterface {
std
::
map
<
paddle_infer
::
Predictor
*
,
int
>
predictor_to_thread_id
;
std
::
map
<
paddle_infer
::
Predictor
*
,
int
>
predictor_to_thread_id
;
std
::
map
<
std
::
string
,
int
>
cache_names_idx_
;
std
::
map
<
std
::
string
,
int
>
cache_names_idx_
;
std
::
vector
<
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>>
cache_encouts_
;
std
::
vector
<
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>>
cache_encouts_
;
ModelOptions
opts_
;
ModelOptions
opts_
;
public:
public:
...
...
speechx/speechx/nnet/nnet_itf.h
浏览文件 @
86eb7189
...
@@ -35,6 +35,7 @@ struct ModelOptions {
...
@@ -35,6 +35,7 @@ struct ModelOptions {
std
::
string
cache_shape
;
std
::
string
cache_shape
;
bool
enable_fc_padding
;
bool
enable_fc_padding
;
bool
enable_profile
;
bool
enable_profile
;
int
subsample_rate
;
ModelOptions
()
ModelOptions
()
:
model_path
(
""
),
:
model_path
(
""
),
param_path
(
""
),
param_path
(
""
),
...
@@ -46,7 +47,8 @@ struct ModelOptions {
...
@@ -46,7 +47,8 @@ struct ModelOptions {
cache_shape
(
""
),
cache_shape
(
""
),
switch_ir_optim
(
false
),
switch_ir_optim
(
false
),
enable_fc_padding
(
false
),
enable_fc_padding
(
false
),
enable_profile
(
false
)
{}
enable_profile
(
false
),
subsample_rate
(
0
)
{}
void
Register
(
kaldi
::
OptionsItf
*
opts
)
{
void
Register
(
kaldi
::
OptionsItf
*
opts
)
{
opts
->
Register
(
"model-path"
,
&
model_path
,
"model file path"
);
opts
->
Register
(
"model-path"
,
&
model_path
,
"model file path"
);
...
@@ -102,9 +104,14 @@ class NnetInterface {
...
@@ -102,9 +104,14 @@ class NnetInterface {
// true, nnet output is logprob; otherwise is prob,
// true, nnet output is logprob; otherwise is prob,
virtual
bool
IsLogProb
()
=
0
;
virtual
bool
IsLogProb
()
=
0
;
int
SubsamplingRate
()
const
{
return
subsampling_rate_
;
}
// using to get encoder outs. e.g. seq2seq with Attention model.
// using to get encoder outs. e.g. seq2seq with Attention model.
virtual
void
EncoderOuts
(
virtual
void
EncoderOuts
(
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>*
encoder_out
)
const
=
0
;
std
::
vector
<
kaldi
::
Vector
<
kaldi
::
BaseFloat
>>*
encoder_out
)
const
=
0
;
protected:
int
subsampling_rate_
{
1
};
};
};
}
// namespace ppspeech
}
// namespace ppspeech
speechx/speechx/nnet/u2_nnet.h
浏览文件 @
86eb7189
...
@@ -30,7 +30,7 @@ class U2NnetBase : public NnetInterface {
...
@@ -30,7 +30,7 @@ class U2NnetBase : public NnetInterface {
public:
public:
virtual
int
context
()
const
{
return
right_context_
+
1
;
}
virtual
int
context
()
const
{
return
right_context_
+
1
;
}
virtual
int
right_context
()
const
{
return
right_context_
;
}
virtual
int
right_context
()
const
{
return
right_context_
;
}
virtual
int
subsampling_rate
()
const
{
return
subsampling_rate_
;
}
virtual
int
eos
()
const
{
return
eos_
;
}
virtual
int
eos
()
const
{
return
eos_
;
}
virtual
int
sos
()
const
{
return
sos_
;
}
virtual
int
sos
()
const
{
return
sos_
;
}
virtual
int
is_bidecoder
()
const
{
return
is_bidecoder_
;
}
virtual
int
is_bidecoder
()
const
{
return
is_bidecoder_
;
}
...
@@ -64,7 +64,6 @@ class U2NnetBase : public NnetInterface {
...
@@ -64,7 +64,6 @@ class U2NnetBase : public NnetInterface {
protected:
protected:
// model specification
// model specification
int
right_context_
{
0
};
int
right_context_
{
0
};
int
subsampling_rate_
{
1
};
int
sos_
{
0
};
int
sos_
{
0
};
int
eos_
{
0
};
int
eos_
{
0
};
...
...
speechx/speechx/protocol/websocket/CMakeLists.txt
浏览文件 @
86eb7189
# project(websocket)
add_library
(
websocket STATIC
add_library
(
websocket STATIC
websocket_server.cc
websocket_server.cc
websocket_client.cc
websocket_client.cc
...
...
speechx/speechx/protocol/websocket/websocket_server_main.cc
浏览文件 @
86eb7189
...
@@ -17,11 +17,38 @@
...
@@ -17,11 +17,38 @@
DEFINE_int32
(
port
,
8082
,
"websocket listening port"
);
DEFINE_int32
(
port
,
8082
,
"websocket listening port"
);
// Build the DS2 RecognizerResource for the websocket server from
// command-line flags: frontend, nnet model, and TLG (WFST) decoder options.
ppspeech::RecognizerResource InitRecognizerResoure() {
    ppspeech::RecognizerResource resource;
    resource.acoustic_scale = FLAGS_acoustic_scale;
    resource.feature_pipeline_opts = ppspeech::InitFeaturePipelineOptions();

    // Paddle-inference model configuration (paths, cache tensors, IO names).
    ppspeech::ModelOptions model_opts;
    model_opts.model_path = FLAGS_model_path;
    model_opts.param_path = FLAGS_param_path;
    model_opts.cache_names = FLAGS_model_cache_names;
    model_opts.cache_shape = FLAGS_model_cache_shapes;
    model_opts.input_names = FLAGS_model_input_names;
    model_opts.output_names = FLAGS_model_output_names;
    model_opts.subsample_rate = FLAGS_downsampling_rate;
    resource.model_opts = model_opts;

    // TLG WFST decoder configuration (graph, symbols, beam settings).
    ppspeech::TLGDecoderOptions decoder_opts;
    decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
    decoder_opts.fst_path = FLAGS_graph_path;
    decoder_opts.opts.max_active = FLAGS_max_active;
    decoder_opts.opts.beam = FLAGS_beam;
    decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
    resource.tlg_opts = decoder_opts;

    return resource;
}
int
main
(
int
argc
,
char
*
argv
[])
{
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
google
::
InitGoogleLogging
(
argv
[
0
]);
ppspeech
::
RecognizerResource
resource
=
ppspeech
::
InitRecognizerResoure
();
ppspeech
::
RecognizerResource
resource
=
InitRecognizerResoure
();
ppspeech
::
WebSocketServer
server
(
FLAGS_port
,
resource
);
ppspeech
::
WebSocketServer
server
(
FLAGS_port
,
resource
);
LOG
(
INFO
)
<<
"Listening at port "
<<
FLAGS_port
;
LOG
(
INFO
)
<<
"Listening at port "
<<
FLAGS_port
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录