PaddlePaddle / DeepSpeech
Commit a6b2a0a6
Authored Oct 24, 2022 by Hui Zhang

cpplint

Parent: 8271fcfb
Showing 33 changed files with 118 additions and 103 deletions (+118 -103)
Changed files:
.pre-commit-config.yaml  +8 -1
speechx/speechx/base/basic_types.h  +21 -21
speechx/speechx/base/macros.h  +1 -1
speechx/speechx/base/thread_pool.h  +1 -1
speechx/speechx/codelab/nnet/ds2_model_test_main.cc  +2 -2
speechx/speechx/decoder/ctc_beam_search_decoder.cc  +3 -3
speechx/speechx/decoder/ctc_beam_search_decoder.h  +1 -1
speechx/speechx/decoder/ctc_beam_search_decoder_main.cc  +2 -2
speechx/speechx/decoder/ctc_beam_search_opt.h  +1 -1
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc  +3 -3
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h  +1 -1
speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc  +10 -7
speechx/speechx/decoder/ctc_tlg_decoder.h  +1 -1
speechx/speechx/frontend/audio/cmvn.cc  +1 -1
speechx/speechx/frontend/audio/compute_fbank_main.cc  +2 -2
speechx/speechx/frontend/audio/data_cache.h  +1 -1
speechx/speechx/frontend/audio/db_norm.cc  +4 -3
speechx/speechx/frontend/audio/fbank.cc  +4 -3
speechx/speechx/frontend/audio/feature_pipeline.cc  +1 -1
speechx/speechx/frontend/audio/linear_spectrogram.cc  +4 -3
speechx/speechx/frontend/audio/mfcc.cc  +4 -3
speechx/speechx/nnet/ds2_nnet.cc  +5 -4
speechx/speechx/nnet/ds2_nnet.h  +2 -2
speechx/speechx/nnet/ds2_nnet_main.cc  +2 -2
speechx/speechx/nnet/u2_nnet.cc  +20 -20
speechx/speechx/nnet/u2_nnet.h  +2 -2
speechx/speechx/nnet/u2_nnet_main.cc  +3 -3
speechx/speechx/protocol/websocket/websocket_client_main.cc  +1 -1
speechx/speechx/recognizer/recognizer.h  +2 -1
speechx/speechx/recognizer/recognizer_main.cc  +3 -2
speechx/speechx/recognizer/u2_recognizer.cc  +1 -1
speechx/speechx/recognizer/u2_recognizer.h  +0 -2
speechx/speechx/utils/file_utils.cc  +1 -1
.pre-commit-config.yaml
@@ -50,13 +50,20 @@ repos:
         entry: bash .pre-commit-hooks/clang-format.hook -i
         language: system
         files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
-        exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
+        exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$
     #-   id: copyright_checker
     #    name: copyright_checker
     #    entry: python .pre-commit-hooks/copyright-check.hook
     #    language: system
     #    files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
     #    exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
+    -   id: cpplint
+        name: cpplint
+        description: Static code analysis of C/C++ files
+        language: python
+        files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
+        exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$
+        entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent
 -   repo: https://github.com/asottile/reorder_python_imports
     rev: v2.4.0
     hooks:
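The `exclude` pattern in the hunk above combines a lookahead over directory names with a trailing extension group. The following is a minimal sketch of how such a pattern selects paths, assuming it is applied as an unanchored search; it uses `std::regex` (whose default ECMAScript grammar supports `(?=...)`) rather than the hook runner's own regex engine. `IsExcluded` and `DemoExclude` are hypothetical names, and the test paths are illustrative.

```cpp
#include <cassert>
#include <regex>
#include <string>

// Hypothetical helper: true when `path` matches the exclude pattern shown in
// the cpplint hook (unanchored search, ECMAScript grammar).
bool IsExcluded(const std::string& path) {
    static const std::regex kExclude(
        R"((?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$)");
    return std::regex_search(path, kExclude);
}

// Small self-check with illustrative paths.
void DemoExclude() {
    // Inside an excluded directory and ending in a listed extension.
    assert(IsExcluded("third_party/ctc_decoders/scorer.cpp"));
    // Listed extension, but not under any excluded directory.
    assert(!IsExcluded("speechx/speechx/nnet/u2_nnet.cc"));
}
```

Because the lookahead only asserts that a listed directory name starts at the match position, the directory may appear anywhere in the path; the extension group then has to match at the very end.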
speechx/speechx/base/basic_types.h
@@ -22,39 +22,39 @@ typedef float BaseFloat;
 typedef double double64;
 typedef signed char int8;
-typedef short int16;
-typedef int int32;
+typedef short int16;  // NOLINT
+typedef int int32;    // NOLINT

 #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
-typedef long int64;
+typedef long int64;  // NOLINT
 #else
-typedef long long int64;
+typedef long long int64;  // NOLINT
 #endif

-typedef unsigned char uint8;
-typedef unsigned short uint16;
-typedef unsigned int uint32;
+typedef unsigned char uint8;    // NOLINT
+typedef unsigned short uint16;  // NOLINT
+typedef unsigned int uint32;    // NOLINT

 #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
-typedef unsigned long uint64;
+typedef unsigned long uint64;  // NOLINT
 #else
-typedef unsigned long long uint64;
+typedef unsigned long long uint64;  // NOLINT
 #endif

 typedef signed int char32;

-const uint8 kuint8max = ((uint8)0xFF);
-const uint16 kuint16max = ((uint16)0xFFFF);
-const uint32 kuint32max = ((uint32)0xFFFFFFFF);
-const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
-const int8 kint8min = ((int8)0x80);
-const int8 kint8max = ((int8)0x7F);
-const int16 kint16min = ((int16)0x8000);
-const int16 kint16max = ((int16)0x7FFF);
-const int32 kint32min = ((int32)0x80000000);
-const int32 kint32max = ((int32)0x7FFFFFFF);
-const int64 kint64min = ((int64)(0x8000000000000000LL));
-const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));
+const uint8 kuint8max = (static_cast<uint8>0xFF);
+const uint16 kuint16max = (static_cast<uint16>0xFFFF);
+const uint32 kuint32max = (static_cast<uint32>0xFFFFFFFF);
+const uint64 kuint64max = (static_cast<uint64>(0xFFFFFFFFFFFFFFFFLL));
+const int8 kint8min = (static_cast<int8>0x80);
+const int8 kint8max = (static_cast<int8>0x7F);
+const int16 kint16min = (static_cast<int16>0x8000);
+const int16 kint16max = (static_cast<int16>0x7FFF);
+const int32 kint32min = (static_cast<int32>0x80000000);
+const int32 kint32max = (static_cast<int32>0x7FFFFFFF);
+const int64 kint64min = (static_cast<int64>(0x8000000000000000LL));
+const int64 kint64max = (static_cast<int64>(0x7FFFFFFFFFFFFFFFLL));

 const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
 const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
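The constants in this file move from C-style casts to `static_cast`. As a minimal sketch of that conversion, note that `static_cast` takes its operand in parentheses, `static_cast<T>(expr)`, so a cast written as `(T)expr` does not translate by swapping the type syntax alone. The aliases below use `<cstdint>` fixed-width types as stand-ins for the project's typedefs (an assumption), and `CheckLimits` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Stand-ins for the project's typedefs (assumption: fixed-width equivalents).
using uint8 = std::uint8_t;
using int16 = std::int16_t;
using int32 = std::int32_t;

// Each literal fits in the target type, so the conversions are value-preserving.
const uint8 kuint8max = static_cast<uint8>(0xFF);
const int16 kint16max = static_cast<int16>(0x7FFF);
const int32 kint32max = static_cast<int32>(0x7FFFFFFF);

// Hypothetical self-check: the hand-written constants agree with <limits>.
void CheckLimits() {
    assert(kuint8max == std::numeric_limits<uint8>::max());
    assert(kint16max == std::numeric_limits<int16>::max());
    assert(kint32max == std::numeric_limits<int32>::max());
}
```

The sketch covers only the maximum values; the minimum constants rely on converting an out-of-range literal, which is a separate question from the cast syntax.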
speechx/speechx/base/macros.h
@@ -26,6 +26,6 @@ namespace ppspeech {
 #endif

 // kSpaceSymbol in UTF-8 is: ▁
-const std::string kSpaceSymbol = "\xe2\x96\x81";
+const char[] kSpaceSymbol = "\xe2\x96\x81";

 }  // namespace ppspeech
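This hunk swaps a namespace-scope `std::string` constant for a character array, presumably in response to cpplint's advice to use C-style strings for global string constants. A minimal sketch of the array form follows; in C++ the brackets go after the variable name. `SpaceSymbolBytes` is a hypothetical helper added only to show the byte length.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// In a C++ array declaration the brackets follow the name.
// The three bytes are the UTF-8 encoding of U+2581.
const char kSpaceSymbol[] = "\xe2\x96\x81";

// Hypothetical helper: number of bytes in the symbol, excluding the
// terminating NUL.
std::size_t SpaceSymbolBytes() { return std::string(kSpaceSymbol).size(); }

// Small self-check: 3 payload bytes plus the terminating NUL.
void CheckSpaceSymbol() {
    assert(sizeof(kSpaceSymbol) == 4);
    assert(SpaceSymbolBytes() == 3);
}
```

Callers that previously relied on `std::string` members can still construct a `std::string` from the array where needed.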
speechx/speechx/base/thread_pool.h
@@ -35,7 +35,7 @@
 class ThreadPool {
   public:
-    ThreadPool(size_t);
+    explicit ThreadPool(size_t);
     template <class F, class... Args>
     auto enqueue(F&& f, Args&&... args)
         -> std::future<typename std::result_of<F(Args...)>::type>;
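Marking a single-argument constructor `explicit` stops the compiler from using it for implicit conversions. The sketch below illustrates the effect with a hypothetical `Pool` class standing in for a class like `ThreadPool`; it is not the project's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical stand-in for a class with a single-argument constructor.
class Pool {
  public:
    explicit Pool(std::size_t threads) : threads_(threads) {}
    std::size_t size() const { return threads_; }

  private:
    std::size_t threads_;
};

// explicit blocks implicit conversion from size_t to Pool...
static_assert(!std::is_convertible<std::size_t, Pool>::value,
              "no implicit conversion");
// ...while direct construction still works.
static_assert(std::is_constructible<Pool, std::size_t>::value,
              "direct initialization allowed");

// Small self-check: direct initialization behaves as before.
void CheckPool() { assert(Pool(4).size() == 4); }
```

With the `explicit` keyword in place, a call such as `Run(4)` to a hypothetical function taking a `Pool` would no longer compile; the caller has to write `Run(Pool(4))`.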
speechx/speechx/codelab/nnet/ds2_model_test_main.cc
@@ -64,8 +64,8 @@ void model_forward_test() {
     ;
     std::string model_graph = FLAGS_model_path;
     std::string model_params = FLAGS_param_path;
-    CHECK(model_graph != "");
-    CHECK(model_params != "");
+    CHECK_NE(model_graph, "");
+    CHECK_NE(model_params, "");
     cout << "model path: " << model_graph << endl;
     cout << "model param path : " << model_params << endl;
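The comparison macros (`CHECK_NE`, `CHECK_EQ`, `CHECK_GT`, `CHECK_LT`) take the two operands separately, which lets the failure message include both values instead of only the expression text. The sketch below is a simplified stand-in for that idea, not glog's implementation; `CheckNeMessage` is a hypothetical function and the message format is only an approximation.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Simplified stand-in (not glog's implementation): returns an empty string
// when a != b, otherwise a message that includes both operand values.
template <typename A, typename B>
std::string CheckNeMessage(const A& a, const B& b, const char* expr) {
    if (a != b) return "";
    std::ostringstream os;
    os << "Check failed: " << expr << " (" << a << " vs. " << b << ")";
    return os.str();
}

// Small self-check with an illustrative, made-up path value.
void DemoCheckNe() {
    const std::string model_graph = "model.pdmodel";
    assert(CheckNeMessage(model_graph, "", "model_graph != \"\"").empty());
}
```

A plain `CHECK(a != b)` only has the boolean result to work with, so it cannot report what the operands were when the check fails.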
speechx/speechx/decoder/ctc_beam_search_decoder.cc
@@ -39,12 +39,12 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
             opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
     }

-    CHECK(opts_.blank == 0);
+    CHECK_EQ(opts_.blank, 0);

     auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
     space_id_ = it - vocabulary_.begin();
     // if no space in vocabulary
-    if ((size_t)space_id_ >= vocabulary_.size()) {
+    if (static_cast<size_t>(space_id_) >= vocabulary_.size()) {
         space_id_ = -2;
     }
 }
@@ -104,7 +104,7 @@ void CTCBeamSearch::ResetPrefixes() {
 }

 int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
-                                     vector<string>& nbest_words) {
+                                     const vector<string>& nbest_words) {
     kaldi::Timer timer;
     AdvanceDecoding(probs);
     LOG(INFO) << "ctc decoding elapsed time(s) "
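The constructor hunk above looks up the space token with `std::find` and falls back to `-2` when the vocabulary has none. A minimal, self-contained sketch of that lookup follows; `FindSpaceId` is a hypothetical free function written for illustration, not a member of `CTCBeamSearch`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper mirroring the lookup in the constructor: returns the
// index of " " in the vocabulary, or -2 when no space token is present.
int FindSpaceId(const std::vector<std::string>& vocabulary) {
    auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
    int space_id = static_cast<int>(it - vocabulary.begin());
    // std::find returns end() on a miss, so the index equals size().
    if (static_cast<std::size_t>(space_id) >= vocabulary.size()) {
        space_id = -2;
    }
    return space_id;
}

// Small self-check with a made-up vocabulary.
void DemoFindSpaceId() {
    assert(FindSpaceId({"<blank>", "a", " "}) == 2);
    assert(FindSpaceId({"a", "b"}) == -2);
}
```

The cast to an unsigned type before the comparison avoids a signed/unsigned comparison between the `int` index and `size()`.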
speechx/speechx/decoder/ctc_beam_search_decoder.h
@@ -48,7 +48,7 @@ class CTCBeamSearch : public DecoderBase {
     }

     int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
-                          std::vector<std::string>& nbest_words);
+                          const std::vector<std::string>& nbest_words);

   private:
     void ResetPrefixes();
speechx/speechx/decoder/ctc_beam_search_decoder_main.cc
@@ -59,8 +59,8 @@ int main(int argc, char* argv[]) {
     google::InstallFailureSignalHandler();
     FLAGS_logtostderr = 1;

-    CHECK(FLAGS_result_wspecifier != "");
-    CHECK(FLAGS_feature_rspecifier != "");
+    CHECK_NE(FLAGS_result_wspecifier, "");
+    CHECK_NE(FLAGS_feature_rspecifier, "");

     kaldi::SequentialBaseFloatMatrixReader feature_reader(
         FLAGS_feature_rspecifier);
speechx/speechx/decoder/ctc_beam_search_opt.h
@@ -36,7 +36,7 @@ struct CTCBeamSearchOptions {
     // u2
     int first_beam_size;
     int second_beam_size;
-    explicit CTCBeamSearchOptions()
+    CTCBeamSearchOptions()
         : blank(0),
           dict_file("vocab.txt"),
           lm_path(""),
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
@@ -329,8 +329,8 @@ void CTCPrefixBeamSearch::UpdateFinalContext() {

 std::string CTCPrefixBeamSearch::GetBestPath(int index) {
     int n_hyps = Outputs().size();
-    CHECK(n_hyps > 0);
-    CHECK(index < n_hyps);
+    CHECK_GT(n_hyps, 0);
+    CHECK_LT(index, n_hyps);
     std::vector<int> one = Outputs()[index];
     std::string sentence;
     for (int i = 0; i < one.size(); i++) {
@@ -344,7 +344,7 @@ std::string CTCPrefixBeamSearch::GetBestPath() { return GetBestPath(0); }

 std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath(
     int n) {
     int hyps_size = hypotheses_.size();
-    CHECK(hyps_size > 0);
+    CHECK_GT(hyps_size, 0);

     int min_n = n == -1 ? hypotheses_.size() : std::min(n, hyps_size);
speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
@@ -28,7 +28,7 @@ class ContextGraph;
 class CTCPrefixBeamSearch : public DecoderBase {
   public:
     CTCPrefixBeamSearch(const std::string& vocab_path,
                         const CTCBeamSearchOptions& opts);
     ~CTCPrefixBeamSearch() {}

     SearchType Type() const { return SearchType::kPrefixBeamSearch; }
speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc
@@ -50,10 +50,10 @@ int main(int argc, char* argv[]) {

     int32 num_done = 0, num_err = 0;

-    CHECK(FLAGS_result_wspecifier != "");
-    CHECK(FLAGS_feature_rspecifier != "");
-    CHECK(FLAGS_vocab_path != "");
-    CHECK(FLAGS_model_path != "");
+    CHECK_NE(FLAGS_result_wspecifier, "");
+    CHECK_NE(FLAGS_feature_rspecifier, "");
+    CHECK_NE(FLAGS_vocab_path, "");
+    CHECK_NE(FLAGS_model_path, "");
     LOG(INFO) << "model path: " << FLAGS_model_path;
     LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
@@ -64,11 +64,14 @@ int main(int argc, char* argv[]) {
     // nnet
     ppspeech::ModelOptions model_opts;
     model_opts.model_path = FLAGS_model_path;
-    std::shared_ptr<ppspeech::U2Nnet> nnet = std::make_shared<ppspeech::U2Nnet>(model_opts);
+    std::shared_ptr<ppspeech::U2Nnet> nnet =
+        std::make_shared<ppspeech::U2Nnet>(model_opts);

     // decodeable
-    std::shared_ptr<ppspeech::DataCache> raw_data = std::make_shared<ppspeech::DataCache>();
-    std::shared_ptr<ppspeech::Decodable> decodable = std::make_shared<ppspeech::Decodable>(nnet, raw_data);
+    std::shared_ptr<ppspeech::DataCache> raw_data =
+        std::make_shared<ppspeech::DataCache>();
+    std::shared_ptr<ppspeech::Decodable> decodable =
+        std::make_shared<ppspeech::Decodable>(nnet, raw_data);

     // decoder
     ppspeech::CTCBeamSearchOptions opts;
speechx/speechx/decoder/ctc_tlg_decoder.h
@@ -71,7 +71,7 @@ class TLGDecoder : public DecoderBase {
     std::string GetPartialResult() override;

     int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
-                          std::vector<std::string>& nbest_words);
+                          const std::vector<std::string>& nbest_words);

   protected:
     std::string GetBestPath() override {
speechx/speechx/frontend/audio/cmvn.cc
@@ -30,7 +30,7 @@ using std::vector;

 CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)
     : var_norm_(true) {
-    CHECK(cmvn_file != "");
+    CHECK_NE(cmvn_file, "");
     base_extractor_ = std::move(base_extractor);

     bool binary;
speechx/speechx/frontend/audio/compute_fbank_main.cc
@@ -40,8 +40,8 @@ int main(int argc, char* argv[]) {
     google::InstallFailureSignalHandler();
     FLAGS_logtostderr = 1;

-    CHECK(FLAGS_wav_rspecifier.size() > 0);
-    CHECK(FLAGS_feature_wspecifier.size() > 0);
+    CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
+    CHECK_GT(FLAGS_feature_wspecifier.size(), 0);
     kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
         FLAGS_wav_rspecifier);
     kaldi::SequentialTableReader<kaldi::WaveInfoHolder> wav_info_reader(
speechx/speechx/frontend/audio/data_cache.h
@@ -27,7 +27,7 @@ namespace ppspeech {
 // pre-recorded audio/feature
 class DataCache : public FrontendInterface {
   public:
-    explicit DataCache() { finished_ = false; }
+    DataCache() { finished_ = false; }

     // accept waves/feats
     virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
speechx/speechx/frontend/audio/db_norm.cc
@@ -14,17 +14,18 @@
 #include "frontend/audio/db_norm.h"
+
 #include "kaldi/feat/cmvn.h"
 #include "kaldi/util/kaldi-io.h"

 namespace ppspeech {

-using kaldi::Vector;
-using kaldi::VectorBase;
 using kaldi::BaseFloat;
-using std::vector;
 using kaldi::SubVector;
+using kaldi::Vector;
+using kaldi::VectorBase;
 using std::unique_ptr;
+using std::vector;

 DecibelNormalizer::DecibelNormalizer(
     const DecibelNormalizerOptions& opts,
speechx/speechx/frontend/audio/fbank.cc
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "frontend/audio/fbank.h"
+
 #include "kaldi/base/kaldi-math.h"
 #include "kaldi/feat/feature-common.h"
 #include "kaldi/feat/feature-functions.h"
@@ -20,12 +21,12 @@
 namespace ppspeech {

-using kaldi::int32;
 using kaldi::BaseFloat;
-using kaldi::Vector;
+using kaldi::int32;
+using kaldi::Matrix;
 using kaldi::SubVector;
+using kaldi::Vector;
 using kaldi::VectorBase;
-using kaldi::Matrix;
 using std::vector;

 FbankComputer::FbankComputer(const Options& opts)
speechx/speechx/frontend/audio/feature_pipeline.cc
@@ -33,7 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
             opts.linear_spectrogram_opts, std::move(data_source)));
     }

-    CHECK(opts.cmvn_file != "");
+    CHECK_NE(opts.cmvn_file, "");
     unique_ptr<FrontendInterface> cmvn(
         new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature)));
speechx/speechx/frontend/audio/linear_spectrogram.cc
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "frontend/audio/linear_spectrogram.h"
+
 #include "kaldi/base/kaldi-math.h"
 #include "kaldi/feat/feature-common.h"
 #include "kaldi/feat/feature-functions.h"
@@ -20,12 +21,12 @@
 namespace ppspeech {

-using kaldi::int32;
 using kaldi::BaseFloat;
-using kaldi::Vector;
+using kaldi::int32;
+using kaldi::Matrix;
 using kaldi::SubVector;
+using kaldi::Vector;
 using kaldi::VectorBase;
-using kaldi::Matrix;
 using std::vector;

 LinearSpectrogramComputer::LinearSpectrogramComputer(const Options& opts)
speechx/speechx/frontend/audio/mfcc.cc
@@ -14,6 +14,7 @@

 #include "frontend/audio/mfcc.h"
+
 #include "kaldi/base/kaldi-math.h"
 #include "kaldi/feat/feature-common.h"
 #include "kaldi/feat/feature-functions.h"
@@ -21,12 +22,12 @@
 namespace ppspeech {

-using kaldi::int32;
 using kaldi::BaseFloat;
-using kaldi::Vector;
+using kaldi::int32;
+using kaldi::Matrix;
 using kaldi::SubVector;
+using kaldi::Vector;
 using kaldi::VectorBase;
-using kaldi::Matrix;
 using std::vector;

 Mfcc::Mfcc(const MfccOptions& opts,
speechx/speechx/nnet/ds2_nnet.cc
@@ -13,15 +13,16 @@
 // limitations under the License.

 #include "nnet/ds2_nnet.h"
+
 #include "absl/strings/str_split.h"

 namespace ppspeech {

-using std::vector;
-using std::string;
-using std::shared_ptr;
 using kaldi::Matrix;
 using kaldi::Vector;
+using std::shared_ptr;
+using std::string;
+using std::vector;

 void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
     std::vector<std::string> cache_names;
@@ -207,7 +208,7 @@ void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
     // inferences->Resize(row * col);
     // *inference_dim = col;
     out->logprobs.Resize(row * col);
     out->vocab_dim = col;
     output_tensor->CopyToCpu(out->logprobs.Data());
speechx/speechx/nnet/ds2_nnet.h
@@ -26,7 +26,7 @@ template <typename T>
 class Tensor {
   public:
     Tensor() {}
-    Tensor(const std::vector<int>& shape) : _shape(shape) {
+    explicit Tensor(const std::vector<int>& shape) : _shape(shape) {
         int neml = std::accumulate(
             _shape.begin(), _shape.end(), 1, std::multiplies<int>());
         LOG(INFO) << "Tensor neml: " << neml;
@@ -50,7 +50,7 @@ class Tensor {

 class PaddleNnet : public NnetBase {
   public:
-    PaddleNnet(const ModelOptions& opts);
+    explicit PaddleNnet(const ModelOptions& opts);

     void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
                      const int32& feature_dim,
speechx/speechx/nnet/ds2_nnet_main.cc
@@ -12,13 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "nnet/ds2_nnet.h"
 #include "base/common.h"
 #include "decoder/param.h"
 #include "frontend/audio/assembler.h"
 #include "frontend/audio/data_cache.h"
 #include "kaldi/util/table-types.h"
 #include "nnet/decodable.h"
+#include "nnet/ds2_nnet.h"

 DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
 DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
@@ -44,7 +44,7 @@ int main(int argc, char* argv[]) {
     int32 num_done = 0, num_err = 0;

     ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();

     std::shared_ptr<ppspeech::PaddleNnet> nnet(
         new ppspeech::PaddleNnet(model_opts));
speechx/speechx/nnet/u2_nnet.cc
浏览文件 @
a6b2a0a6
...
@@ -158,7 +158,7 @@ void U2Nnet::Reset() {
...
@@ -158,7 +158,7 @@ void U2Nnet::Reset() {
}
}
// Debug API
// Debug API
void
U2Nnet
::
FeedEncoderOuts
(
paddle
::
Tensor
&
encoder_out
)
{
void
U2Nnet
::
FeedEncoderOuts
(
const
paddle
::
Tensor
&
encoder_out
)
{
// encoder_out (T,D)
// encoder_out (T,D)
encoder_outs_
.
clear
();
encoder_outs_
.
clear
();
encoder_outs_
.
push_back
(
encoder_out
);
encoder_outs_
.
push_back
(
encoder_out
);
...
@@ -206,7 +206,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
...
@@ -206,7 +206,7 @@ void U2Nnet::ForwardEncoderChunkImpl(
float
*
feats_ptr
=
feats
.
mutable_data
<
float
>
();
float
*
feats_ptr
=
feats
.
mutable_data
<
float
>
();
// not cache feature in nnet
// not cache feature in nnet
CHECK
(
cached_feats_
.
size
()
==
0
);
CHECK
_EQ
(
cached_feats_
.
size
(),
0
);
// CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true);
// CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true);
std
::
memcpy
(
feats_ptr
,
std
::
memcpy
(
feats_ptr
,
chunk_feats
.
data
(),
chunk_feats
.
data
(),
...
@@ -247,9 +247,9 @@ void U2Nnet::ForwardEncoderChunkImpl(
...
@@ -247,9 +247,9 @@ void U2Nnet::ForwardEncoderChunkImpl(
// call.
// call.
std
::
vector
<
paddle
::
Tensor
>
inputs
=
{
std
::
vector
<
paddle
::
Tensor
>
inputs
=
{
feats
,
offset
,
/*required_cache_size, */
att_cache_
,
cnn_cache_
};
feats
,
offset
,
/*required_cache_size, */
att_cache_
,
cnn_cache_
};
CHECK
(
inputs
.
size
()
==
4
);
CHECK
_EQ
(
inputs
.
size
(),
4
);
std
::
vector
<
paddle
::
Tensor
>
outputs
=
forward_encoder_chunk_
(
inputs
);
std
::
vector
<
paddle
::
Tensor
>
outputs
=
forward_encoder_chunk_
(
inputs
);
CHECK
(
outputs
.
size
()
==
3
);
CHECK
_EQ
(
outputs
.
size
(),
3
);
#ifdef USE_GPU
#ifdef USE_GPU
paddle
::
Tensor
chunk_out
=
outputs
[
0
].
copy_to
(
paddle
::
CPUPlace
());
paddle
::
Tensor
chunk_out
=
outputs
[
0
].
copy_to
(
paddle
::
CPUPlace
());
...
@@ -319,9 +319,9 @@ void U2Nnet::ForwardEncoderChunkImpl(
...
@@ -319,9 +319,9 @@ void U2Nnet::ForwardEncoderChunkImpl(
inputs
.
clear
();
inputs
.
clear
();
outputs
.
clear
();
outputs
.
clear
();
inputs
.
push_back
(
chunk_out
);
inputs
.
push_back
(
chunk_out
);
CHECK
(
inputs
.
size
()
==
1
);
CHECK
_EQ
(
inputs
.
size
(),
1
);
outputs
=
ctc_activation_
(
inputs
);
outputs
=
ctc_activation_
(
inputs
);
CHECK
(
outputs
.
size
()
==
1
);
CHECK
_EQ
(
outputs
.
size
(),
1
);
paddle
::
Tensor
ctc_log_probs
=
outputs
[
0
];
paddle
::
Tensor
ctc_log_probs
=
outputs
[
0
];
#ifdef TEST_DEBUG
#ifdef TEST_DEBUG
...
@@ -350,9 +350,9 @@ void U2Nnet::ForwardEncoderChunkImpl(
...
@@ -350,9 +350,9 @@ void U2Nnet::ForwardEncoderChunkImpl(
// Copy to output, (B=1,T,D)
// Copy to output, (B=1,T,D)
std
::
vector
<
int64_t
>
ctc_log_probs_shape
=
ctc_log_probs
.
shape
();
std
::
vector
<
int64_t
>
ctc_log_probs_shape
=
ctc_log_probs
.
shape
();
CHECK
(
ctc_log_probs_shape
.
size
()
==
3
);
CHECK
_EQ
(
ctc_log_probs_shape
.
size
(),
3
);
int
B
=
ctc_log_probs_shape
[
0
];
int
B
=
ctc_log_probs_shape
[
0
];
CHECK
(
B
==
1
);
CHECK
_EQ
(
B
,
1
);
int
T
=
ctc_log_probs_shape
[
1
];
int
T
=
ctc_log_probs_shape
[
1
];
int
D
=
ctc_log_probs_shape
[
2
];
int
D
=
ctc_log_probs_shape
[
2
];
*
vocab_dim
=
D
;
*
vocab_dim
=
D
;
@@ -393,9 +393,9 @@ float U2Nnet::ComputePathScore(const paddle::Tensor& prob,
     // hyp (U,)
     float score = 0.0f;
     std::vector<int64_t> dims = prob.shape();
-    CHECK(dims.size() == 3);
+    CHECK_EQ(dims.size(), 3);
     VLOG(2) << "prob shape: " << dims[0] << ", " << dims[1] << ", " << dims[2];
-    CHECK(dims[0] == 1);
+    CHECK_EQ(dims[0], 1);
     int vocab_dim = static_cast<int>(dims[2]);
     const float* prob_ptr = prob.data<float>();
@@ -520,14 +520,14 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
     std::vector<paddle::experimental::Tensor> inputs{
         hyps_tensor, hyps_lens, encoder_out};
     std::vector<paddle::Tensor> outputs = forward_attention_decoder_(inputs);
-    CHECK(outputs.size() == 2);
+    CHECK_EQ(outputs.size(), 2);
     // (B, Umax, V)
     paddle::Tensor probs = outputs[0];
     std::vector<int64_t> probs_shape = probs.shape();
-    CHECK(probs_shape.size() == 3);
+    CHECK_EQ(probs_shape.size(), 3);
-    CHECK(probs_shape[0] == num_hyps);
+    CHECK_EQ(probs_shape[0], num_hyps);
-    CHECK(probs_shape[1] == max_hyps_len);
+    CHECK_EQ(probs_shape[1], max_hyps_len);
 #ifdef TEST_DEBUG
     {
@@ -582,13 +582,13 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
     paddle::Tensor r_probs = outputs[1];
     std::vector<int64_t> r_probs_shape = r_probs.shape();
     if (is_bidecoder_ && reverse_weight > 0) {
-        CHECK(r_probs_shape.size() == 3);
+        CHECK_EQ(r_probs_shape.size(), 3);
-        CHECK(r_probs_shape[0] == num_hyps);
+        CHECK_EQ(r_probs_shape[0], num_hyps);
-        CHECK(r_probs_shape[1] == max_hyps_len);
+        CHECK_EQ(r_probs_shape[1], max_hyps_len);
     } else {
         // dump r_probs
-        CHECK(r_probs_shape.size() == 1);
+        CHECK_EQ(r_probs_shape.size(), 1);
-        CHECK(r_probs_shape[0] == 1) << r_probs_shape[0];
+        CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0];
     }
     // compute rescoring score
@@ -644,7 +644,7 @@ void U2Nnet::EncoderOuts(
     for (int i = 0; i < size; i++) {
         const paddle::Tensor& item = encoder_outs_[i];
         const std::vector<int64_t> shape = item.shape();
-        CHECK(shape.size() == 3);
+        CHECK_EQ(shape.size(), 3);
         const int& B = shape[0];
         const int& T = shape[1];
         const int& D = shape[2];
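Reviewer note: the hunks in this file replace `CHECK(a == b)` with `CHECK_EQ(a, b)`. One practical benefit of a two-argument comparison check is that the failure message can include both operand values, which a check that only sees a boolean cannot do. The sketch below illustrates that idea using the standard library only. `DescribeCheckEq` is a hypothetical helper invented for this note, not glog's implementation, and the message format is an assumption chosen for illustration.

```cpp
#include <sstream>
#include <string>

// Hypothetical helper (not part of glog or speechx).
// Returns an empty string when lhs == rhs. Otherwise returns a message that
// contains the checked expression text and both operand values.
template <typename L, typename R>
std::string DescribeCheckEq(const L& lhs,
                            const R& rhs,
                            const std::string& expr_text) {
    if (lhs == rhs) {
        return std::string();
    }
    std::ostringstream msg;
    // A boolean-only check could print expr_text, but not the two values.
    msg << "Check failed: " << expr_text << " (" << lhs << " vs. " << rhs
        << ")";
    return msg.str();
}
```

For example, calling `DescribeCheckEq(2, 3, "outputs.size() == 3")` yields a message that names both 2 and 3, so a shape mismatch can be diagnosed from the log without re-running under a debugger.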
...
speechx/speechx/nnet/u2_nnet.h
View file @ a6b2a0a6
@@ -73,7 +73,7 @@ class U2NnetBase : public NnetBase {
 class U2Nnet : public U2NnetBase {
   public:
-    U2Nnet(const ModelOptions& opts);
+    explicit U2Nnet(const ModelOptions& opts);
     U2Nnet(const U2Nnet& other);
     void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
@@ -108,7 +108,7 @@ class U2Nnet : public U2NnetBase {
                             std::vector<float>* rescoring_score) override;
     // debug
-    void FeedEncoderOuts(paddle::Tensor& encoder_out);
+    void FeedEncoderOuts(const paddle::Tensor& encoder_out);
     void EncoderOuts(
         std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const;
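Reviewer note: this header marks the single-argument `U2Nnet` constructor `explicit`. The sketch below shows, with hypothetical stand-in types (`Options`, `ImplicitNet`, `ExplicitNet`, and `ChunkSizeOf` are invented for this note and are not speechx classes), what the keyword changes: an options object can no longer be converted to the class implicitly, while direct construction still works.

```cpp
#include <type_traits>

// Hypothetical stand-in for an options struct such as ModelOptions.
struct Options {
    int chunk_size = 16;
};

// Single-argument constructor without `explicit`: also acts as an implicit
// conversion from Options.
struct ImplicitNet {
    ImplicitNet(const Options& opts) : chunk_size(opts.chunk_size) {}
    int chunk_size;
};

// Same constructor marked `explicit`: construction must be spelled out.
struct ExplicitNet {
    explicit ExplicitNet(const Options& opts) : chunk_size(opts.chunk_size) {}
    int chunk_size;
};

// Compile-time checks of the difference.
static_assert(std::is_convertible<Options, ImplicitNet>::value,
              "non-explicit constructor allows implicit conversion");
static_assert(!std::is_convertible<Options, ExplicitNet>::value,
              "explicit constructor blocks implicit conversion");
static_assert(std::is_constructible<ExplicitNet, const Options&>::value,
              "direct construction is still allowed");

// Takes the net by const reference. With ExplicitNet, passing an Options
// object here directly would not compile.
int ChunkSizeOf(const ExplicitNet& net) { return net.chunk_size; }
```

With the explicit version, a call such as `ChunkSizeOf(opts)` is rejected at compile time, and the caller has to write `ChunkSizeOf(ExplicitNet(opts))`, which makes the construction visible at the call site.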
...
speechx/speechx/nnet/u2_nnet_main.cc
View file @ a6b2a0a6
@@ -39,9 +39,9 @@ int main(int argc, char* argv[]) {
     int32 num_done = 0, num_err = 0;
-    CHECK(FLAGS_feature_rspecifier.size() > 0);
+    CHECK_GT(FLAGS_feature_rspecifier.size(), 0);
-    CHECK(FLAGS_nnet_prob_wspecifier.size() > 0);
+    CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
-    CHECK(FLAGS_model_path.size() > 0);
+    CHECK_GT(FLAGS_model_path.size(), 0);
     LOG(INFO) << "input rspecifier: " << FLAGS_feature_rspecifier;
     LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
     LOG(INFO) << "model path: " << FLAGS_model_path;
...
speechx/speechx/protocol/websocket/websocket_client_main.cc
View file @ a6b2a0a6
@@ -12,10 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "websocket/websocket_client.h"
 #include "kaldi/feat/wave-reader.h"
 #include "kaldi/util/kaldi-io.h"
 #include "kaldi/util/table-types.h"
+#include "websocket/websocket_client.h"
 DEFINE_string(host, "127.0.0.1", "host of websocket server");
 DEFINE_int32(port, 8082, "port of websocket server");
...
speechx/speechx/recognizer/recognizer.h
View file @ a6b2a0a6
@@ -39,7 +39,8 @@ struct RecognizerResource {
         resource.feature_pipeline_opts =
             FeaturePipelineOptions::InitFromFlags();
         resource.feature_pipeline_opts.assembler_opts.fill_zero = true;
-        LOG(INFO) << "ds2 need fill zero be true: " << resource.feature_pipeline_opts.assembler_opts.fill_zero;
+        LOG(INFO) << "ds2 need fill zero be true: "
+                  << resource.feature_pipeline_opts.assembler_opts.fill_zero;
         resource.model_opts = ModelOptions::InitFromFlags();
         resource.tlg_opts = TLGDecoderOptions::InitFromFlags();
         return resource;
...
speechx/speechx/recognizer/recognizer_main.cc
View file @ a6b2a0a6
@@ -13,9 +13,9 @@
 // limitations under the License.
 #include "decoder/param.h"
-#include "recognizer/recognizer.h"
 #include "kaldi/feat/wave-reader.h"
 #include "kaldi/util/table-types.h"
+#include "recognizer/recognizer.h"
 DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
 DEFINE_string(result_wspecifier, "", "test result wspecifier");
@@ -30,7 +30,8 @@ int main(int argc, char* argv[]) {
     google::InstallFailureSignalHandler();
     FLAGS_logtostderr = 1;
-    ppspeech::RecognizerResource resource = ppspeech::RecognizerResource::InitFromFlags();
+    ppspeech::RecognizerResource resource =
+        ppspeech::RecognizerResource::InitFromFlags();
     ppspeech::Recognizer recognizer(resource);
     kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
...
speechx/speechx/recognizer/u2_recognizer.cc
View file @ a6b2a0a6
@@ -35,7 +35,7 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
     BaseFloat am_scale = resource.acoustic_scale;
     decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));
-    CHECK(resource.vocab_path != "");
+    CHECK_NE(resource.vocab_path, "");
     decoder_.reset(new CTCPrefixBeamSearch(
         resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
...
speechx/speechx/recognizer/u2_recognizer.h
View file @ a6b2a0a6
 // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
...
speechx/speechx/utils/file_utils.cc
View file @ a6b2a0a6
@@ -40,4 +40,4 @@ std::string ReadFile2String(const std::string& path) {
     return std::string((std::istreambuf_iterator<char>(input_file)),
                        std::istreambuf_iterator<char>());
 }
-}
+}  // namespace ppspeech