Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
406b4fc7
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
406b4fc7
编写于
3月 02, 2022
作者:
S
SmileGoat
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
align linear_feature & nnet
上级
b584b969
变更
26
隐藏空白更改
内联
并排
Showing
26 changed file
with
864 addition
and
172 deletion
+864
-172
speechx/CMakeLists.txt
speechx/CMakeLists.txt
+53
-16
speechx/patch/CPPLINT.cfg
speechx/patch/CPPLINT.cfg
+1
-0
speechx/patch/openfst/src/include/fst/flags.h
speechx/patch/openfst/src/include/fst/flags.h
+228
-0
speechx/patch/openfst/src/include/fst/log.h
speechx/patch/openfst/src/include/fst/log.h
+82
-0
speechx/patch/openfst/src/lib/flags.cc
speechx/patch/openfst/src/lib/flags.cc
+166
-0
speechx/speechx/CMakeLists.txt
speechx/speechx/CMakeLists.txt
+8
-7
speechx/speechx/base/flags.h
speechx/speechx/base/flags.h
+1
-1
speechx/speechx/base/log.h
speechx/speechx/base/log.h
+1
-1
speechx/speechx/codelab/decoder_test/offline_decoder_main.cc
speechx/speechx/codelab/decoder_test/offline_decoder_main.cc
+19
-20
speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc
speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc
+49
-3
speechx/speechx/codelab/nnet_test/model_test.cc
speechx/speechx/codelab/nnet_test/model_test.cc
+134
-0
speechx/speechx/decoder/CMakeLists.txt
speechx/speechx/decoder/CMakeLists.txt
+2
-2
speechx/speechx/frontend/CMakeLists.txt
speechx/speechx/frontend/CMakeLists.txt
+2
-2
speechx/speechx/frontend/linear_spectrogram.cc
speechx/speechx/frontend/linear_spectrogram.cc
+4
-6
speechx/speechx/frontend/normalizer.h
speechx/speechx/frontend/normalizer.h
+1
-1
speechx/speechx/kaldi/matrix/BUILD
speechx/speechx/kaldi/matrix/BUILD
+0
-39
speechx/speechx/kaldi/matrix/CMakeLists.txt
speechx/speechx/kaldi/matrix/CMakeLists.txt
+1
-1
speechx/speechx/kaldi/matrix/kaldi-blas.h
speechx/speechx/kaldi/matrix/kaldi-blas.h
+1
-1
speechx/speechx/nnet/CMakeLists.txt
speechx/speechx/nnet/CMakeLists.txt
+7
-2
speechx/speechx/nnet/decodable-itf.h
speechx/speechx/nnet/decodable-itf.h
+1
-1
speechx/speechx/nnet/decodable.cc
speechx/speechx/nnet/decodable.cc
+9
-0
speechx/speechx/nnet/decodable.h
speechx/speechx/nnet/decodable.h
+1
-1
speechx/speechx/nnet/nnet_interface.h
speechx/speechx/nnet/nnet_interface.h
+4
-4
speechx/speechx/nnet/paddle_nnet.cc
speechx/speechx/nnet/paddle_nnet.cc
+77
-58
speechx/speechx/nnet/paddle_nnet.h
speechx/speechx/nnet/paddle_nnet.h
+8
-6
speechx/speechx/utils/file_utils.cc
speechx/speechx/utils/file_utils.cc
+4
-0
未找到文件。
speechx/CMakeLists.txt
浏览文件 @
406b4fc7
...
...
@@ -26,7 +26,6 @@ option(TEST_DEBUG "option for debug" OFF)
# Include third party
###############################################################################
# #example for include third party
# FetchContent_Declare()
# # FetchContent_MakeAvailable was not added until CMake 3.14
# FetchContent_MakeAvailable()
# include_directories()
...
...
@@ -50,20 +49,25 @@ include_directories(${absl_SOURCE_DIR})
#)
#FetchContent_MakeAvailable(libsndfile)
# todo boost build
#include(FetchContent)
#FetchContent_Declare(
# Boost
# URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.zip
# URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
#)
#FetchContent_MakeAvailable(Boost)
#include_directories(${Boost_SOURCE_DIR})
#boost
set
(
boost_SOURCE_DIR
${
fc_patch
}
/boost-src
)
set
(
boost_PREFIX_DIR
${
fc_patch
}
/boost-subbuild/boost-prefix
)
include
(
ExternalProject
)
ExternalProject_Add
(
boost
URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
SOURCE_DIR
${
boost_SOURCE_DIR
}
PREFIX
${
boost_PREFIX_DIR
}
BUILD_IN_SOURCE 1
CONFIGURE_COMMAND ./bootstrap.sh
BUILD_COMMAND ./b2
INSTALL_COMMAND
""
)
link_directories
(
${
boost_SOURCE_DIR
}
/stage/lib
)
include_directories
(
${
boost_SOURCE_DIR
}
)
set
(
BOOST_ROOT
${
fc_patch
}
/boost-subbuild/boost-populate-prefix/src/boost_1_75_0
)
include_directories
(
${
fc_patch
}
/boost-subbuild/boost-populate-prefix/src/boost_1_75_0
)
link_directories
(
${
fc_patch
}
/boost-subbuild/boost-populate-prefix/src/boost_1_75_0/stage/lib
)
set
(
BOOST_ROOT
${
boost_SOURCE_DIR
}
)
include
(
FetchContent
)
FetchContent_Declare
(
kenlm
...
...
@@ -71,9 +75,10 @@ FetchContent_Declare(
GIT_TAG
"df2d717e95183f79a90b2fa6e4307083a351ca6a"
)
FetchContent_MakeAvailable
(
kenlm
)
add_dependencies
(
kenlm
B
oost
)
add_dependencies
(
kenlm
b
oost
)
include_directories
(
${
kenlm_SOURCE_DIR
}
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb"
)
# gflags
FetchContent_Declare
(
gflags
...
...
@@ -106,19 +111,34 @@ set(openfst_PREFIX_DIR ${fc_patch}/openfst-subbuild/openfst-populate-prefix)
ExternalProject_Add
(
openfst
URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip
URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
# PREFIX ${openfst_PREFIX_DIR}
SOURCE_DIR
${
openfst_SOURCE_DIR
}
BINARY_DIR
${
openfst_BINARY_DIR
}
CONFIGURE_COMMAND
${
openfst_SOURCE_DIR
}
/configure --prefix=
${
openfst_PREFIX_DIR
}
"CPPFLAGS=-I
${
gflags_BINARY_DIR
}
/include -I
${
glog_SOURCE_DIR
}
/src -I
${
glog_BINARY_DIR
}
"
"LDFLAGS=-L
${
gflags_BINARY_DIR
}
-L
${
glog_BINARY_DIR
}
"
"LIBS=-lgflags_nothreads -lglog -lpthread"
COMMAND
${
CMAKE_COMMAND
}
-E copy_directory
${
CMAKE_CURRENT_SOURCE_DIR
}
/patch/openfst
${
openfst_SOURCE_DIR
}
BUILD_COMMAND make -j 4
)
add_dependencies
(
openfst gflags glog
)
link_directories
(
${
openfst_PREFIX_DIR
}
/lib
)
include_directories
(
${
openfst_PREFIX_DIR
}
/include
)
set
(
PADDLE_LIB
${
fc_patch
}
/paddle-lib/paddle_inference
)
# paddle lib
set
(
paddle_SOURCE_DIR
${
fc_patch
}
/paddle-lib
)
set
(
paddle_PREFIX_DIR
${
fc_patch
}
/paddle-lib-prefix
)
ExternalProject_Add
(
paddle
URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz
URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873
PREFIX
${
paddle_PREFIX_DIR
}
SOURCE_DIR
${
paddle_SOURCE_DIR
}
CONFIGURE_COMMAND
""
BUILD_COMMAND
""
INSTALL_COMMAND
""
)
set
(
PADDLE_LIB
${
fc_patch
}
/paddle-lib
)
include_directories
(
"
${
PADDLE_LIB
}
/paddle/include"
)
set
(
PADDLE_LIB_THIRD_PARTY_PATH
"
${
PADDLE_LIB
}
/third_party/install/"
)
include_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
protobuf/include"
)
...
...
@@ -133,6 +153,23 @@ link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
xxhash/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
cryptopp/lib"
)
link_directories
(
"
${
PADDLE_LIB
}
/paddle/lib"
)
link_directories
(
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
mklml/lib"
)
##paddle with mkl
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-fopenmp"
)
set
(
MATH_LIB_PATH
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
mklml"
)
include_directories
(
"
${
MATH_LIB_PATH
}
/include"
)
set
(
MATH_LIB
${
MATH_LIB_PATH
}
/lib/libmklml_intel
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
MATH_LIB_PATH
}
/lib/libiomp5
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
MKLDNN_PATH
"
${
PADDLE_LIB_THIRD_PARTY_PATH
}
mkldnn"
)
include_directories
(
"
${
MKLDNN_PATH
}
/include"
)
set
(
MKLDNN_LIB
${
MKLDNN_PATH
}
/lib/libmkldnn.so.0
)
set
(
EXTERNAL_LIB
"-lrt -ldl -lpthread"
)
set
(
DEPS
${
PADDLE_LIB
}
/paddle/lib/libpaddle_inference
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
MATH_LIB
}
${
MKLDNN_LIB
}
glog gflags protobuf xxhash cryptopp
${
EXTERNAL_LIB
}
)
add_subdirectory
(
speechx
)
...
...
speechx/patch/CPPLINT.cfg
0 → 100644
浏览文件 @
406b4fc7
exclude_files=.*
speechx/patch/openfst/src/include/fst/flags.h
0 → 100644
浏览文件 @
406b4fc7
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Google-style flag handling declarations and inline definitions.
#ifndef FST_LIB_FLAGS_H_
#define FST_LIB_FLAGS_H_
#include <cstdlib>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <fst/types.h>
#include <fst/lock.h>
#include "gflags/gflags.h"
#include "glog/logging.h"
using
std
::
string
;
// FLAGS USAGE:
//
// Definition example:
//
// DEFINE_int32(length, 0, "length");
//
// This defines variable FLAGS_length, initialized to 0.
//
// Declaration example:
//
// DECLARE_int32(length);
//
// SET_FLAGS() can be used to set flags from the command line
// using, for example, '--length=2'.
//
// ShowUsage() can be used to print out command and flag usage.
// #define DECLARE_bool(name) extern bool FLAGS_ ## name
// #define DECLARE_string(name) extern string FLAGS_ ## name
// #define DECLARE_int32(name) extern int32 FLAGS_ ## name
// #define DECLARE_int64(name) extern int64 FLAGS_ ## name
// #define DECLARE_double(name) extern double FLAGS_ ## name
template
<
typename
T
>
struct
FlagDescription
{
FlagDescription
(
T
*
addr
,
const
char
*
doc
,
const
char
*
type
,
const
char
*
file
,
const
T
val
)
:
address
(
addr
),
doc_string
(
doc
),
type_name
(
type
),
file_name
(
file
),
default_value
(
val
)
{}
T
*
address
;
const
char
*
doc_string
;
const
char
*
type_name
;
const
char
*
file_name
;
const
T
default_value
;
};
template
<
typename
T
>
class
FlagRegister
{
public:
static
FlagRegister
<
T
>
*
GetRegister
()
{
static
auto
reg
=
new
FlagRegister
<
T
>
;
return
reg
;
}
const
FlagDescription
<
T
>
&
GetFlagDescription
(
const
string
&
name
)
const
{
fst
::
MutexLock
l
(
&
flag_lock_
);
auto
it
=
flag_table_
.
find
(
name
);
return
it
!=
flag_table_
.
end
()
?
it
->
second
:
0
;
}
void
SetDescription
(
const
string
&
name
,
const
FlagDescription
<
T
>
&
desc
)
{
fst
::
MutexLock
l
(
&
flag_lock_
);
flag_table_
.
insert
(
make_pair
(
name
,
desc
));
}
bool
SetFlag
(
const
string
&
val
,
bool
*
address
)
const
{
if
(
val
==
"true"
||
val
==
"1"
||
val
.
empty
())
{
*
address
=
true
;
return
true
;
}
else
if
(
val
==
"false"
||
val
==
"0"
)
{
*
address
=
false
;
return
true
;
}
else
{
return
false
;
}
}
bool
SetFlag
(
const
string
&
val
,
string
*
address
)
const
{
*
address
=
val
;
return
true
;
}
bool
SetFlag
(
const
string
&
val
,
int32
*
address
)
const
{
char
*
p
=
0
;
*
address
=
strtol
(
val
.
c_str
(),
&
p
,
0
);
return
!
val
.
empty
()
&&
*
p
==
'\0'
;
}
bool
SetFlag
(
const
string
&
val
,
int64
*
address
)
const
{
char
*
p
=
0
;
*
address
=
strtoll
(
val
.
c_str
(),
&
p
,
0
);
return
!
val
.
empty
()
&&
*
p
==
'\0'
;
}
bool
SetFlag
(
const
string
&
val
,
double
*
address
)
const
{
char
*
p
=
0
;
*
address
=
strtod
(
val
.
c_str
(),
&
p
);
return
!
val
.
empty
()
&&
*
p
==
'\0'
;
}
bool
SetFlag
(
const
string
&
arg
,
const
string
&
val
)
const
{
for
(
typename
std
::
map
<
string
,
FlagDescription
<
T
>
>::
const_iterator
it
=
flag_table_
.
begin
();
it
!=
flag_table_
.
end
();
++
it
)
{
const
string
&
name
=
it
->
first
;
const
FlagDescription
<
T
>
&
desc
=
it
->
second
;
if
(
arg
==
name
)
return
SetFlag
(
val
,
desc
.
address
);
}
return
false
;
}
void
GetUsage
(
std
::
set
<
std
::
pair
<
string
,
string
>>
*
usage_set
)
const
{
for
(
auto
it
=
flag_table_
.
begin
();
it
!=
flag_table_
.
end
();
++
it
)
{
const
string
&
name
=
it
->
first
;
const
FlagDescription
<
T
>
&
desc
=
it
->
second
;
string
usage
=
" --"
+
name
;
usage
+=
": type = "
;
usage
+=
desc
.
type_name
;
usage
+=
", default = "
;
usage
+=
GetDefault
(
desc
.
default_value
)
+
"
\n
"
;
usage
+=
desc
.
doc_string
;
usage_set
->
insert
(
make_pair
(
desc
.
file_name
,
usage
));
}
}
private:
string
GetDefault
(
bool
default_value
)
const
{
return
default_value
?
"true"
:
"false"
;
}
string
GetDefault
(
const
string
&
default_value
)
const
{
return
"
\"
"
+
default_value
+
"
\"
"
;
}
template
<
class
V
>
string
GetDefault
(
const
V
&
default_value
)
const
{
std
::
ostringstream
strm
;
strm
<<
default_value
;
return
strm
.
str
();
}
mutable
fst
::
Mutex
flag_lock_
;
// Multithreading lock.
std
::
map
<
string
,
FlagDescription
<
T
>>
flag_table_
;
};
template
<
typename
T
>
class
FlagRegisterer
{
public:
FlagRegisterer
(
const
string
&
name
,
const
FlagDescription
<
T
>
&
desc
)
{
auto
registr
=
FlagRegister
<
T
>::
GetRegister
();
registr
->
SetDescription
(
name
,
desc
);
}
private:
FlagRegisterer
(
const
FlagRegisterer
&
)
=
delete
;
FlagRegisterer
&
operator
=
(
const
FlagRegisterer
&
)
=
delete
;
};
#define DEFINE_VAR(type, name, value, doc) \
type FLAGS_ ## name = value; \
static FlagRegisterer<type> \
name ## _flags_registerer(#name, FlagDescription<type>(&FLAGS_ ## name, \
doc, \
#type, \
__FILE__, \
value))
// #define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc)
// #define DEFINE_string(name, value, doc) \
// DEFINE_VAR(string, name, value, doc)
// #define DEFINE_int32(name, value, doc) DEFINE_VAR(int32, name, value, doc)
// #define DEFINE_int64(name, value, doc) DEFINE_VAR(int64, name, value, doc)
// #define DEFINE_double(name, value, doc) DEFINE_VAR(double, name, value, doc)
// Temporary directory.
DECLARE_string
(
tmpdir
);
void
SetFlags
(
const
char
*
usage
,
int
*
argc
,
char
***
argv
,
bool
remove_flags
,
const
char
*
src
=
""
);
#define SET_FLAGS(usage, argc, argv, rmflags) \
gflags::ParseCommandLineFlags(argc, argv, true)
// SetFlags(usage, argc, argv, rmflags, __FILE__)
// Deprecated; for backward compatibility.
inline
void
InitFst
(
const
char
*
usage
,
int
*
argc
,
char
***
argv
,
bool
rmflags
)
{
return
SetFlags
(
usage
,
argc
,
argv
,
rmflags
);
}
void
ShowUsage
(
bool
long_usage
=
true
);
#endif // FST_LIB_FLAGS_H_
speechx/patch/openfst/src/include/fst/log.h
0 → 100644
浏览文件 @
406b4fc7
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Google-style logging declarations and inline definitions.
#ifndef FST_LIB_LOG_H_
#define FST_LIB_LOG_H_
#include <cassert>
#include <iostream>
#include <string>
#include <fst/types.h>
#include <fst/flags.h>
using
std
::
string
;
DECLARE_int32
(
v
);
class
LogMessage
{
public:
LogMessage
(
const
string
&
type
)
:
fatal_
(
type
==
"FATAL"
)
{
std
::
cerr
<<
type
<<
": "
;
}
~
LogMessage
()
{
std
::
cerr
<<
std
::
endl
;
if
(
fatal_
)
exit
(
1
);
}
std
::
ostream
&
stream
()
{
return
std
::
cerr
;
}
private:
bool
fatal_
;
};
// #define LOG(type) LogMessage(#type).stream()
// #define VLOG(level) if ((level) <= FLAGS_v) LOG(INFO)
// Checks
inline
void
FstCheck
(
bool
x
,
const
char
*
expr
,
const
char
*
file
,
int
line
)
{
if
(
!
x
)
{
LOG
(
FATAL
)
<<
"Check failed:
\"
"
<<
expr
<<
"
\"
file: "
<<
file
<<
" line: "
<<
line
;
}
}
// #define CHECK(x) FstCheck(static_cast<bool>(x), #x, __FILE__, __LINE__)
// #define CHECK_EQ(x, y) CHECK((x) == (y))
// #define CHECK_LT(x, y) CHECK((x) < (y))
// #define CHECK_GT(x, y) CHECK((x) > (y))
// #define CHECK_LE(x, y) CHECK((x) <= (y))
// #define CHECK_GE(x, y) CHECK((x) >= (y))
// #define CHECK_NE(x, y) CHECK((x) != (y))
// Debug checks
// #define DCHECK(x) assert(x)
// #define DCHECK_EQ(x, y) DCHECK((x) == (y))
// #define DCHECK_LT(x, y) DCHECK((x) < (y))
// #define DCHECK_GT(x, y) DCHECK((x) > (y))
// #define DCHECK_LE(x, y) DCHECK((x) <= (y))
// #define DCHECK_GE(x, y) DCHECK((x) >= (y))
// #define DCHECK_NE(x, y) DCHECK((x) != (y))
// Ports
#define ATTRIBUTE_DEPRECATED __attribute__((deprecated))
#endif // FST_LIB_LOG_H_
speechx/patch/openfst/src/lib/flags.cc
0 → 100644
浏览文件 @
406b4fc7
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Google-style flag handling definitions.
#include <cstring>
#if _MSC_VER
#include <io.h>
#include <fcntl.h>
#endif
#include <fst/compat.h>
#include <fst/flags.h>
static
const
char
*
private_tmpdir
=
getenv
(
"TMPDIR"
);
// DEFINE_int32(v, 0, "verbosity level");
// DEFINE_bool(help, false, "show usage information");
// DEFINE_bool(helpshort, false, "show brief usage information");
#ifndef _MSC_VER
DEFINE_string
(
tmpdir
,
private_tmpdir
?
private_tmpdir
:
"/tmp"
,
"temporary directory"
);
#else
DEFINE_string
(
tmpdir
,
private_tmpdir
?
private_tmpdir
:
getenv
(
"TEMP"
),
"temporary directory"
);
#endif // !_MSC_VER
using
namespace
std
;
static
string
flag_usage
;
static
string
prog_src
;
// Sets prog_src to src.
static
void
SetProgSrc
(
const
char
*
src
)
{
prog_src
=
src
;
#if _MSC_VER
// This common code is invoked by all FST binaries, and only by them. Switch
// stdin and stdout into "binary" mode, so that 0x0A won't be translated into
// a 0x0D 0x0A byte pair in a pipe or a shell redirect. Other streams are
// already using ios::binary where binary files are read or written.
// Kudos to @daanzu for the suggested fix.
// https://github.com/kkm000/openfst/issues/20
// https://github.com/kkm000/openfst/pull/23
// https://github.com/kkm000/openfst/pull/32
_setmode
(
_fileno
(
stdin
),
O_BINARY
);
_setmode
(
_fileno
(
stdout
),
O_BINARY
);
#endif
// Remove "-main" in src filename. Flags are defined in fstx.cc but SetFlags()
// is called in fstx-main.cc, which results in a filename mismatch in
// ShowUsageRestrict() below.
static
constexpr
char
kMainSuffix
[]
=
"-main.cc"
;
const
int
prefix_length
=
prog_src
.
size
()
-
strlen
(
kMainSuffix
);
if
(
prefix_length
>
0
&&
prog_src
.
substr
(
prefix_length
)
==
kMainSuffix
)
{
prog_src
.
erase
(
prefix_length
,
strlen
(
"-main"
));
}
}
void
SetFlags
(
const
char
*
usage
,
int
*
argc
,
char
***
argv
,
bool
remove_flags
,
const
char
*
src
)
{
flag_usage
=
usage
;
SetProgSrc
(
src
);
int
index
=
1
;
for
(;
index
<
*
argc
;
++
index
)
{
string
argval
=
(
*
argv
)[
index
];
if
(
argval
[
0
]
!=
'-'
||
argval
==
"-"
)
break
;
while
(
argval
[
0
]
==
'-'
)
argval
=
argval
.
substr
(
1
);
// Removes initial '-'.
string
arg
=
argval
;
string
val
=
""
;
// Splits argval (arg=val) into arg and val.
auto
pos
=
argval
.
find
(
"="
);
if
(
pos
!=
string
::
npos
)
{
arg
=
argval
.
substr
(
0
,
pos
);
val
=
argval
.
substr
(
pos
+
1
);
}
auto
bool_register
=
FlagRegister
<
bool
>::
GetRegister
();
if
(
bool_register
->
SetFlag
(
arg
,
val
))
continue
;
auto
string_register
=
FlagRegister
<
string
>::
GetRegister
();
if
(
string_register
->
SetFlag
(
arg
,
val
))
continue
;
auto
int32_register
=
FlagRegister
<
int32
>::
GetRegister
();
if
(
int32_register
->
SetFlag
(
arg
,
val
))
continue
;
auto
int64_register
=
FlagRegister
<
int64
>::
GetRegister
();
if
(
int64_register
->
SetFlag
(
arg
,
val
))
continue
;
auto
double_register
=
FlagRegister
<
double
>::
GetRegister
();
if
(
double_register
->
SetFlag
(
arg
,
val
))
continue
;
LOG
(
FATAL
)
<<
"SetFlags: Bad option: "
<<
(
*
argv
)[
index
];
}
if
(
remove_flags
)
{
for
(
auto
i
=
0
;
i
<
*
argc
-
index
;
++
i
)
{
(
*
argv
)[
i
+
1
]
=
(
*
argv
)[
i
+
index
];
}
*
argc
-=
index
-
1
;
}
// if (FLAGS_help) {
// ShowUsage(true);
// exit(1);
// }
// if (FLAGS_helpshort) {
// ShowUsage(false);
// exit(1);
// }
}
// If flag is defined in file 'src' and 'in_src' true or is not
// defined in file 'src' and 'in_src' is false, then print usage.
static
void
ShowUsageRestrict
(
const
std
::
set
<
pair
<
string
,
string
>>
&
usage_set
,
const
string
&
src
,
bool
in_src
,
bool
show_file
)
{
string
old_file
;
bool
file_out
=
false
;
bool
usage_out
=
false
;
for
(
const
auto
&
pair
:
usage_set
)
{
const
auto
&
file
=
pair
.
first
;
const
auto
&
usage
=
pair
.
second
;
bool
match
=
file
==
src
;
if
((
match
&&
!
in_src
)
||
(
!
match
&&
in_src
))
continue
;
if
(
file
!=
old_file
)
{
if
(
show_file
)
{
if
(
file_out
)
cout
<<
"
\n
"
;
cout
<<
"Flags from: "
<<
file
<<
"
\n
"
;
file_out
=
true
;
}
old_file
=
file
;
}
cout
<<
usage
<<
"
\n
"
;
usage_out
=
true
;
}
if
(
usage_out
)
cout
<<
"
\n
"
;
}
void
ShowUsage
(
bool
long_usage
)
{
std
::
set
<
pair
<
string
,
string
>>
usage_set
;
cout
<<
flag_usage
<<
"
\n
"
;
auto
bool_register
=
FlagRegister
<
bool
>::
GetRegister
();
bool_register
->
GetUsage
(
&
usage_set
);
auto
string_register
=
FlagRegister
<
string
>::
GetRegister
();
string_register
->
GetUsage
(
&
usage_set
);
auto
int32_register
=
FlagRegister
<
int32
>::
GetRegister
();
int32_register
->
GetUsage
(
&
usage_set
);
auto
int64_register
=
FlagRegister
<
int64
>::
GetRegister
();
int64_register
->
GetUsage
(
&
usage_set
);
auto
double_register
=
FlagRegister
<
double
>::
GetRegister
();
double_register
->
GetUsage
(
&
usage_set
);
if
(
!
prog_src
.
empty
())
{
cout
<<
"PROGRAM FLAGS:
\n\n
"
;
ShowUsageRestrict
(
usage_set
,
prog_src
,
true
,
false
);
}
if
(
!
long_usage
)
return
;
if
(
!
prog_src
.
empty
())
cout
<<
"LIBRARY FLAGS:
\n\n
"
;
ShowUsageRestrict
(
usage_set
,
prog_src
,
false
,
true
);
}
speechx/speechx/CMakeLists.txt
浏览文件 @
406b4fc7
...
...
@@ -2,10 +2,6 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project
(
speechx LANGUAGES CXX
)
link_directories
(
${
CMAKE_CURRENT_SOURCE_DIR
}
/third_party/openblas
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++14"
)
include_directories
(
${
CMAKE_CURRENT_SOURCE_DIR
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/kaldi
...
...
@@ -36,11 +32,16 @@ ${CMAKE_CURRENT_SOURCE_DIR}/decoder
)
add_subdirectory
(
decoder
)
link_libraries
(
dl
)
add_executable
(
mfcc-test codelab/feat_test/feature-mfcc-test.cc
)
target_link_libraries
(
mfcc-test kaldi-mfcc
)
target_link_libraries
(
mfcc-test kaldi-mfcc
${
MATH_LIB
}
)
add_executable
(
linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc
)
target_link_libraries
(
linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog
)
#add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc)
#target_link_libraries(offline_decoder_main nnet decoder gflags glog)
add_executable
(
offline_decoder_main codelab/decoder_test/offline_decoder_main.cc
)
target_link_libraries
(
offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util
${
DEPS
}
)
add_executable
(
model_test codelab/nnet_test/model_test.cc
)
target_link_libraries
(
model_test PUBLIC nnet gflags
${
DEPS
}
)
speechx/speechx/base/flags.h
浏览文件 @
406b4fc7
...
...
@@ -14,4 +14,4 @@
#pragma once
#include "
gflags/g
flags.h"
#include "
fst/
flags.h"
speechx/speechx/base/log.h
浏览文件 @
406b4fc7
...
...
@@ -14,4 +14,4 @@
#pragma once
#include "
glog/loggin
g.h"
#include "
fst/lo
g.h"
speechx/speechx/codelab/decoder_test/offline_decoder_main.cc
浏览文件 @
406b4fc7
...
...
@@ -4,16 +4,20 @@
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "nnet/paddle_nnet.h"
#include "nnet/decodable.h"
DEFINE_string
(
feature_respecifier
,
""
,
"test nnet prob"
);
using
kaldi
::
BaseFloat
;
using
kaldi
::
Matrix
;
using
std
::
vector
;
void
SplitFeature
(
kaldi
::
Matrix
<
BaseFloat
>
feature
,
int32
chunk_size
,
std
::
vector
<
kaldi
::
Matrix
<
BaseFloat
>>
feature_chunks
)
{
//
void SplitFeature(kaldi::Matrix<BaseFloat> feature,
//
int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>*
feature_chunks) {
}
//
}
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
...
...
@@ -24,31 +28,26 @@ int main(int argc, char* argv[]) {
// test nnet_output --> decoder result
int32
num_done
=
0
,
num_err
=
0
;
CTCBeamSearchOptions
opts
;
CTCBeamSearch
decoder
(
opts
);
ppspeech
::
CTCBeamSearchOptions
opts
;
ppspeech
::
CTCBeamSearch
decoder
(
opts
);
ModelOptions
model_opts
;
std
::
shared_ptr
<
PaddleNnet
>
nnet
(
new
PaddleNnet
(
model_opts
));
ppspeech
::
ModelOptions
model_opts
;
std
::
shared_ptr
<
ppspeech
::
PaddleNnet
>
nnet
(
new
ppspeech
::
PaddleNnet
(
model_opts
));
Decodable
decodable
();
decodable
.
SetNnet
(
nnet
);
std
::
shared_ptr
<
ppspeech
::
Decodable
>
decodable
(
new
ppspeech
::
Decodable
(
nnet
));
int32
chunk_size
=
0
;
//int32 chunk_size = 35;
decoder
.
InitDecoder
();
for
(;
!
feature_reader
.
Done
();
feature_reader
.
Next
())
{
string
utt
=
feature_reader
.
Key
();
const
kaldi
::
Matrix
<
BaseFloat
>
feature
=
feature_reader
.
Value
();
vector
<
Matrix
<
BaseFloat
>>
feature_chunks
;
SplitFeature
(
feature
,
chunk_size
,
&
feature_chunks
);
for
(
auto
feature_chunk
:
feature_chunks
)
{
decodable
.
FeedFeatures
(
feature_chunk
);
decoder
.
InitDecoder
();
decoder
.
AdvanceDecode
(
decodable
,
chunk_size
);
}
decodable
.
InputFinished
();
decodable
->
FeedFeatures
(
feature
);
decoder
.
AdvanceDecode
(
decodable
,
8
);
decodable
->
InputFinished
();
std
::
string
result
;
result
=
decoder
.
GetFinalBestPath
();
KALDI_LOG
<<
" the result of "
<<
utt
<<
" is "
<<
result
;
decodable
.
Reset
();
decodable
->
Reset
();
++
num_done
;
}
...
...
speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc
浏览文件 @
406b4fc7
...
...
@@ -11,6 +11,10 @@
DEFINE_string
(
wav_rspecifier
,
""
,
"test wav path"
);
DEFINE_string
(
feature_wspecifier
,
""
,
"test wav ark"
);
std
::
vector
<
float
>
mean_
{
-
13730251.531853663
,
-
12982852.199316509
,
-
13673844.299583456
,
-
13089406.559646806
,
-
12673095.524938712
,
-
12823859.223276224
,
-
13590267.158903603
,
-
14257618.467152044
,
-
14374605.116185192
,
-
14490009.21822485
,
-
14849827.158924166
,
-
15354435.470563512
,
-
15834149.206532761
,
-
16172971.985514281
,
-
16348740.496746974
,
-
16423536.699409386
,
-
16556246.263649225
,
-
16744088.772748645
,
-
16916184.08510357
,
-
17054034.840031497
,
-
17165612.509455364
,
-
17255955.470915023
,
-
17322572.527648456
,
-
17408943.862033736
,
-
17521554.799865916
,
-
17620623.254924215
,
-
17699792.395918526
,
-
17723364.411134344
,
-
17741483.4433254
,
-
17747426.888704527
,
-
17733315.928209435
,
-
17748780.160905756
,
-
17808336.883775543
,
-
17895918.671983004
,
-
18009812.59173023
,
-
18098188.66548325
,
-
18195798.958462656
,
-
18293617.62980999
,
-
18397432.92077201
,
-
18505834.787318766
,
-
18585451.8100908
,
-
18652438.235649142
,
-
18700960.306275308
,
-
18734944.58792185
,
-
18737426.313365128
,
-
18735347.165987637
,
-
18738813.444170244
,
-
18737086.848890636
,
-
18731576.2474336
,
-
18717405.44095871
,
-
18703089.25545657
,
-
18691014.546456724
,
-
18692460.568905357
,
-
18702119.628629155
,
-
18727710.621126678
,
-
18761582.72034647
,
-
18806745.835547544
,
-
18850674.8692112
,
-
18884431.510951452
,
-
18919999.992506847
,
-
18939303.799078144
,
-
18952946.273760635
,
-
18980289.22996379
,
-
19011610.17803294
,
-
19040948.61805145
,
-
19061021.429847397
,
-
19112055.53768819
,
-
19149667.414264943
,
-
19201127.05091321
,
-
19270250.82564605
,
-
19334606.883057203
,
-
19390513.336589377
,
-
19444176.259208687
,
-
19502755.000038862
,
-
19544333.014549147
,
-
19612668.183176614
,
-
19681902.19006569
,
-
19771969.951249883
,
-
19873329.723376893
,
-
19996752.59235844
,
-
20110031.131400537
,
-
20231658.612529557
,
-
20319378.894054495
,
-
20378534.45718066
,
-
20413332.089584175
,
-
20438147.844177883
,
-
20443710.248040095
,
-
20465457.02238927
,
-
20488610.969337028
,
-
20516295.16424432
,
-
20541423.795738827
,
-
20553192.874953747
,
-
20573605.50701977
,
-
20577871.61936797
,
-
20571807.008916274
,
-
20556242.38912231
,
-
20542199.30819195
,
-
20521239.063551214
,
-
20519150.80004532
,
-
20527204.80248933
,
-
20536933.769257784
,
-
20543470.522332076
,
-
20549700.089992985
,
-
20551525.24958494
,
-
20554873.406493705
,
-
20564277.65794227
,
-
20572211.740052115
,
-
20574305.69550465
,
-
20575494.450104576
,
-
20567092.577932164
,
-
20549302.929608088
,
-
20545445.11878376
,
-
20546625.326603737
,
-
20549190.03499401
,
-
20554824.947828256
,
-
20568341.378989458
,
-
20577582.331383612
,
-
20577980.519402675
,
-
20566603.03458152
,
-
20560131.592262644
,
-
20552166.469060015
,
-
20549063.06763577
,
-
20544490.562339947
,
-
20539817.82346569
,
-
20528747.715731595
,
-
20518026.24576161
,
-
20510977.844974525
,
-
20506874.36087992
,
-
20506731.11977665
,
-
20510482.133420516
,
-
20507760.92101862
,
-
20494644.834457114
,
-
20480107.89304893
,
-
20461312.091867123
,
-
20442941.75080173
,
-
20426123.02834838
,
-
20424607.675283
,
-
20426810.369107097
,
-
20434024.50097819
,
-
20437404.75544205
,
-
20447688.63916367
,
-
20460893.335563846
,
-
20482922.735127095
,
-
20503610.119434915
,
-
20527062.76448319
,
-
20557830.035128627
,
-
20593274.72068722
,
-
20632528.452965066
,
-
20673637.471334763
,
-
20733106.97143075
,
-
20842921.0447562
,
-
21054357.83621519
,
-
21416569.534189366
,
-
21978460.272811692
,
-
22753170.052172784
,
-
23671344.10563395
,
-
24613499.293358143
,
-
25406477.12230188
,
-
25884377.82156489
,
-
26049040.62791664
,
-
26996879.104431007
};
std
::
vector
<
float
>
variance_
{
213747175.10846674
,
188395815.34302503
,
212706429.10966414
,
199109025.81461075
,
189235901.23864496
,
194901336.53253657
,
217481594.29306737
,
238689869.12327808
,
243977501.24115244
,
248479623.6431067
,
259766741.47116545
,
275516766.7790273
,
291271202.3691234
,
302693239.8220509
,
308627358.3997694
,
311143911.38788426
,
315446105.07731867
,
321705430.9341829
,
327458907.4659941
,
332245072.43223983
,
336251717.5935284
,
339694069.7639722
,
342188204.4322228
,
345587110.31313115
,
349903086.2875232
,
353660214.20643026
,
356700344.5270885
,
357665362.3529641
,
358493352.05658793
,
358857951.620328
,
358375239.52774596
,
358899733.6342954
,
361051818.3511561
,
364361716.05025816
,
368750322.3771452
,
372047800.6462831
,
375655861.1349018
,
379358519.1980013
,
383327605.3935181
,
387458599.282341
,
390434692.3406868
,
392994486.35057056
,
394874418.04603153
,
396230525.79763395
,
396365592.0414835
,
396334819.8242737
,
396488353.19250053
,
396438877.00744957
,
396197980.4459586
,
395590921.6672991
,
395001107.62072515
,
394528291.7318225
,
394593110.424006
,
395018405.59353715
,
396110577.5415993
,
397506704.0371068
,
399400197.4657644
,
401243568.2468382
,
402687134.7805103
,
404136047.2872507
,
404883170.001883
,
405522253.219517
,
406660365.3626476
,
407919346.0991902
,
409045348.5384909
,
409759588.7889818
,
411974821.8564483
,
413489718.78201455
,
415535392.56684107
,
418466481.97674364
,
421104678.35678065
,
423405392.5200779
,
425550570.40798235
,
427929423.9579701
,
429585274.253478
,
432368493.55181056
,
435193587.13513297
,
438886855.20476013
,
443058876.8633751
,
448181232.5093362
,
452883835.6332396
,
458056721.77926534
,
461816531.22735566
,
464363620.1970998
,
465886343.5057493
,
466928872.0651
,
467180536.42647296
,
468111848.70714295
,
469138695.3071312
,
470378429.6930793
,
471517958.7132626
,
472109050.4262365
,
473087417.0177867
,
473381322.04648733
,
473220195.85483915
,
472666071.8998819
,
472124669.87879956
,
471298571.411737
,
471251033.2902761
,
471672676.43128747
,
472177147.2193172
,
472572361.7711908
,
472968783.7751127
,
473156295.4164052
,
473398034.82676554
,
473897703.5203811
,
474328271.33112127
,
474452670.98002136
,
474549003.99284613
,
474252887.13567275
,
473557462.909069
,
473483385.85193115
,
473609738.04855174
,
473746944.82085115
,
474016729.91696435
,
474617321.94138587
,
475045097.237122
,
475125402.586558
,
474664112.9824912
,
474426247.5800283
,
474104075.42796475
,
473978219.7273978
,
473773171.7798875
,
473578534.69508696
,
473102924.16904145
,
472651240.5232615
,
472374383.1810912
,
472209479.6956096
,
472202298.8921673
,
472370090.76781124
,
472220933.99374026
,
471625467.37106377
,
470994646.51883453
,
470182428.9637543
,
469348211.5939578
,
468570387.4467277
,
468540442.7225135
,
468672018.90414184
,
468994346.9533251
,
469138757.58201426
,
469553915.95710236
,
470134523.38582784
,
471082421.62055486
,
471962316.51804745
,
472939745.1708408
,
474250621.5944825
,
475773933.43199486
,
477465399.71087736
,
479218782.61382693
,
481752299.7930922
,
486608947.8984568
,
496119403.2067917
,
512730085.5704984
,
539048915.2641417
,
576285298.3548826
,
621610270.2240586
,
669308196.4436442
,
710656993.5957186
,
736344437.3725077
,
745481288.0241544
,
801121432.9925804
};
int
count_
=
912592
;
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
google
::
InitGoogleLogging
(
argv
[
0
]);
...
...
@@ -21,20 +25,62 @@ int main(int argc, char* argv[]) {
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn
int32
num_done
=
0
,
num_err
=
0
;
ppspeech
::
LinearSpectrogramOptions
opt
;
opt
.
frame_opts
.
frame_length_ms
=
20
;
opt
.
frame_opts
.
frame_shift_ms
=
10
;
ppspeech
::
DecibelNormalizerOptions
db_norm_opt
;
std
::
unique_ptr
<
ppspeech
::
FeatureExtractorInterface
>
base_feature_extractor
(
new
ppspeech
::
DecibelNormalizer
(
db_norm_opt
));
ppspeech
::
LinearSpectrogram
linear_spectrogram
(
opt
,
std
::
move
(
base_feature_extractor
));
float
streaming_chunk
=
0.36
;
int
sample_rate
=
16000
;
int
chunk_sample_size
=
streaming_chunk
*
sample_rate
;
LOG
(
INFO
)
<<
mean_
.
size
();
for
(
size_t
i
=
0
;
i
<
mean_
.
size
();
i
++
)
{
mean_
[
i
]
/=
count_
;
variance_
[
i
]
=
variance_
[
i
]
/
count_
-
mean_
[
i
]
*
mean_
[
i
];
if
(
variance_
[
i
]
<
1.0e-20
)
{
variance_
[
i
]
=
1.0e-20
;
}
variance_
[
i
]
=
1.0
/
std
::
sqrt
(
variance_
[
i
]);
}
for
(;
!
wav_reader
.
Done
();
wav_reader
.
Next
())
{
std
::
string
utt
=
wav_reader
.
Key
();
const
kaldi
::
WaveData
&
wave_data
=
wav_reader
.
Value
();
int32
this_channel
=
0
;
kaldi
::
SubVector
<
kaldi
::
BaseFloat
>
waveform
(
wave_data
.
Data
(),
this_channel
);
kaldi
::
Matrix
<
BaseFloat
>
features
;
linear_spectrogram
.
AcceptWaveform
(
waveform
);
linear_spectrogram
.
ReadFeats
(
&
features
);
int
tot_samples
=
waveform
.
Dim
();
int
sample_offset
=
0
;
std
::
vector
<
kaldi
::
Matrix
<
BaseFloat
>>
feats
;
int
feature_rows
=
0
;
while
(
sample_offset
<
tot_samples
)
{
int
cur_chunk_size
=
std
::
min
(
chunk_sample_size
,
tot_samples
-
sample_offset
);
kaldi
::
Vector
<
kaldi
::
BaseFloat
>
wav_chunk
(
cur_chunk_size
);
for
(
int
i
=
0
;
i
<
cur_chunk_size
;
++
i
)
{
wav_chunk
(
i
)
=
waveform
(
sample_offset
+
i
);
}
kaldi
::
Matrix
<
BaseFloat
>
features
;
linear_spectrogram
.
AcceptWaveform
(
wav_chunk
);
linear_spectrogram
.
ReadFeats
(
&
features
);
feats
.
push_back
(
features
);
sample_offset
+=
cur_chunk_size
;
feature_rows
+=
features
.
NumRows
();
}
int
cur_idx
=
0
;
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>
features
(
feature_rows
,
feats
[
0
].
NumCols
());
for
(
auto
feat
:
feats
)
{
for
(
int
row_idx
=
0
;
row_idx
<
feat
.
NumRows
();
++
row_idx
)
{
for
(
int
col_idx
=
0
;
col_idx
<
feat
.
NumCols
();
++
col_idx
)
{
features
(
cur_idx
,
col_idx
)
=
(
feat
(
row_idx
,
col_idx
)
-
mean_
[
col_idx
])
*
variance_
[
col_idx
];
}
++
cur_idx
;
}
}
feat_writer
.
Write
(
utt
,
features
);
if
(
num_done
%
50
==
0
&&
num_done
!=
0
)
...
...
speechx/speechx/codelab/nnet_test/model_test.cc
0 → 100644
浏览文件 @
406b4fc7
#include "paddle_inference_api.h"
#include <gflags/gflags.h>
#include <iostream>
#include <thread>
#include <fstream>
#include <iterator>
#include <algorithm>
#include <numeric>
#include <functional>
void
produce_data
(
std
::
vector
<
std
::
vector
<
float
>>*
data
);
void
model_forward_test
();
// Entry point for the standalone nnet forward-pass smoke test.
// Parses command-line flags and runs a single model forward pass on
// synthetic data (see model_forward_test / produce_data).
int main(int argc, char* argv[]) {
    // `true` removes recognized flags from argc/argv after parsing.
    gflags::ParseCommandLineFlags(&argc, &argv, true);
    model_forward_test();
    return 0;
}
// Runs one forward pass of the DeepSpeech streaming model via the Paddle
// Inference API on a synthetic feature chunk, then prints the resulting
// posterior matrix to stdout. Purely a manual smoke test: model paths are
// hard-coded relative to the build directory and nothing is asserted.
void model_forward_test() {
    // Step 1: build a deterministic dummy feature matrix (frames x bins).
    std::cout << "1. read the data" << std::endl;
    std::vector<std::vector<float>> feats;
    produce_data(&feats);

    // Step 2: load the exported inference model (graph + parameters).
    // NOTE(review): paths assume a fixed checkout/build layout — confirm.
    std::cout << "2. load the model" << std::endl;;
    std::string model_graph = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel";
    std::string model_params = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams";
    paddle_infer::Config config;
    config.SetModel(model_graph, model_params);
    config.SwitchIrOptim(false);   // keep the graph un-optimized
    config.DisableFCPadding();
    auto predictor = paddle_infer::CreatePredictor(config);

    std::cout << "3. feat shape, row=" << feats.size() << ",col="
              << feats[0].size() << std::endl;

    // Flatten the 2-D feature matrix row-major into one contiguous buffer,
    // which is what Tensor::CopyFromCpu expects.
    std::vector<float> paddle_input_feature_matrix;
    for (const auto& item : feats) {
        paddle_input_feature_matrix.insert(paddle_input_feature_matrix.end(),
                                           item.begin(),
                                           item.end());
    }

    std::cout << "4. fead the data to model" << std::endl;
    int row = feats.size();
    int col = feats[0].size();
    std::vector<std::string> input_names = predictor->GetInputNames();
    std::vector<std::string> output_names = predictor->GetOutputNames();

    // Input 0: audio_chunk, shape [batch=1, frames, feature_dim].
    std::unique_ptr<paddle_infer::Tensor> input_tensor =
        predictor->GetInputHandle(input_names[0]);
    std::vector<int> INPUT_SHAPE = {1, row, col};
    input_tensor->Reshape(INPUT_SHAPE);
    input_tensor->CopyFromCpu(paddle_input_feature_matrix.data());

    // Input 1: audio_chunk_lens — number of frames in this chunk.
    std::unique_ptr<paddle_infer::Tensor> input_len =
        predictor->GetInputHandle(input_names[1]);
    std::vector<int> input_len_size = {1};
    input_len->Reshape(input_len_size);
    std::vector<int64_t> audio_len;
    audio_len.push_back(row);
    input_len->CopyFromCpu(audio_len.data());

    // Input 2: chunk_state_h_box — RNN hidden state, zero-initialized
    // ({3, 1, 1024} here; presumably layers x batch x hidden — confirm).
    std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
        predictor->GetInputHandle(input_names[2]);
    std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
    chunk_state_h_box->Reshape(chunk_state_h_box_shape);
    int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(),
                                                 chunk_state_h_box_shape.end(),
                                                 1,
                                                 std::multiplies<int>());
    std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
    chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());

    // Input 3: chunk_state_c_box — RNN cell state, zero-initialized.
    std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
        predictor->GetInputHandle(input_names[3]);
    std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
    chunk_state_c_box->Reshape(chunk_state_c_box_shape);
    int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(),
                                                 chunk_state_c_box_shape.end(),
                                                 1,
                                                 std::multiplies<int>());
    std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
    chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());

    // Run the forward pass. (Return value is currently unused.)
    bool success = predictor->Run();

    // Output 2: updated hidden state (copied out but not inspected).
    std::unique_ptr<paddle_infer::Tensor> h_out =
        predictor->GetOutputHandle(output_names[2]);
    std::vector<int> h_out_shape = h_out->shape();
    int h_out_size = std::accumulate(h_out_shape.begin(),
                                     h_out_shape.end(),
                                     1,
                                     std::multiplies<int>());
    std::vector<float> h_out_data(h_out_size);
    h_out->CopyToCpu(h_out_data.data());

    // Output 3: updated cell state (copied out but not inspected).
    std::unique_ptr<paddle_infer::Tensor> c_out =
        predictor->GetOutputHandle(output_names[3]);
    std::vector<int> c_out_shape = c_out->shape();
    int c_out_size = std::accumulate(c_out_shape.begin(),
                                     c_out_shape.end(),
                                     1,
                                     std::multiplies<int>());
    std::vector<float> c_out_data(c_out_size);
    c_out->CopyToCpu(c_out_data.data());

    // Output 0: the posterior matrix, shape [1, frames, num_classes].
    std::unique_ptr<paddle_infer::Tensor> output_tensor =
        predictor->GetOutputHandle(output_names[0]);
    std::vector<int> output_shape = output_tensor->shape();
    std::vector<float> output_probs;
    int output_size = std::accumulate(output_shape.begin(),
                                      output_shape.end(),
                                      1,
                                      std::multiplies<int>());
    output_probs.resize(output_size);
    output_tensor->CopyToCpu(output_probs.data());
    row = output_shape[1];
    col = output_shape[2];

    // Un-flatten the row-major output buffer back into a 2-D matrix.
    std::vector<std::vector<float>> probs;
    probs.reserve(row);
    for (int i = 0; i < row; i++) {
        probs.push_back(std::vector<float>());
        probs.back().reserve(col);
        for (int j = 0; j < col; j++) {
            probs.back().push_back(output_probs[i * col + j]);
        }
    }

    // Dump the full matrix to stdout for eyeballing.
    std::vector<std::vector<float>> log_feat = probs;
    std::cout << "probs, row: " << log_feat.size() << " col: "
              << log_feat[0].size() << std::endl;
    for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
        for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); ++col_idx) {
            std::cout << log_feat[row_idx][col_idx] << " ";
        }
        std::cout << std::endl;
    }
}
// Fill |data| with a synthetic feature chunk for the forward-pass smoke
// test: 35 frames x 161 feature bins, every value 0.201.
//
// @param data  Output matrix; any previous contents are discarded.
void produce_data(std::vector<std::vector<float>>* data) {
    const int chunk_size = 35;   // number of frames in the dummy chunk
    const int col_size = 161;    // feature dimension per frame
    // BUG FIX: the original called data->back().reserve(col_size) while
    // *data was still empty — back() on an empty vector is undefined
    // behavior. Reserve the outer vector and construct each row at its
    // final size instead.
    data->clear();
    data->reserve(chunk_size);
    for (int row = 0; row < chunk_size; ++row) {
        data->emplace_back(col_size, 0.201f);
    }
}
speechx/speechx/decoder/CMakeLists.txt
浏览文件 @
406b4fc7
project
(
decoder
)
include_directories
(
${
CMAKE_CURRENT_SOURCE_DIR/ctc_decoders
}
)
add_library
(
decoder
add_library
(
decoder
STATIC
ctc_beam_search_decoder.cc
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
)
target_link_libraries
(
decoder kenlm
)
\ No newline at end of file
target_link_libraries
(
decoder PUBLIC kenlm utils fst
)
\ No newline at end of file
speechx/speechx/frontend/CMakeLists.txt
浏览文件 @
406b4fc7
project
(
frontend
)
add_library
(
frontend
add_library
(
frontend
STATIC
normalizer.cc
linear_spectrogram.cc
)
target_link_libraries
(
frontend kaldi-matrix
)
\ No newline at end of file
target_link_libraries
(
frontend PUBLIC kaldi-matrix
)
\ No newline at end of file
speechx/speechx/frontend/linear_spectrogram.cc
浏览文件 @
406b4fc7
...
...
@@ -47,6 +47,7 @@ void CopyStdVector2Vector_(const vector<BaseFloat>& input,
LinearSpectrogram
::
LinearSpectrogram
(
const
LinearSpectrogramOptions
&
opts
,
std
::
unique_ptr
<
FeatureExtractorInterface
>
base_extractor
)
{
opts_
=
opts
;
base_extractor_
=
std
::
move
(
base_extractor
);
int32
window_size
=
opts
.
frame_opts
.
WindowSize
();
int32
window_shift
=
opts
.
frame_opts
.
WindowShift
();
...
...
@@ -105,7 +106,7 @@ void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) {
Compute
(
feats_vec
,
result
);
feats
->
Resize
(
result
.
size
(),
result
[
0
].
size
());
for
(
int
row_idx
=
0
;
row_idx
<
result
.
size
();
++
row_idx
)
{
for
(
int
col_idx
=
0
;
col_idx
<
result
.
size
();
++
col_idx
)
{
for
(
int
col_idx
=
0
;
col_idx
<
result
[
0
]
.
size
();
++
col_idx
)
{
(
*
feats
)(
row_idx
,
col_idx
)
=
result
[
row_idx
][
col_idx
];
}
}
...
...
@@ -133,7 +134,7 @@ bool LinearSpectrogram::Compute(const vector<float>& wave,
const
int
&
sample_rate
=
opts_
.
frame_opts
.
samp_freq
;
const
int
&
frame_shift
=
opts_
.
frame_opts
.
WindowShift
();
const
int
&
fft_points
=
fft_points_
;
const
float
scale
=
hanning_window_energy_
*
frame_shift
;
const
float
scale
=
hanning_window_energy_
*
sample_rate
;
if
(
num_samples
<
frame_length
)
{
return
true
;
...
...
@@ -153,10 +154,7 @@ bool LinearSpectrogram::Compute(const vector<float>& wave,
fft_img
.
clear
();
fft_real
.
clear
();
v
.
assign
(
data
.
begin
(),
data
.
end
());
if
(
NumpyFft
(
&
v
,
&
fft_real
,
&
fft_img
))
{
LOG
(
ERROR
)
<<
i
<<
" fft compute occurs error, please checkout the input data"
;
return
false
;
}
NumpyFft
(
&
v
,
&
fft_real
,
&
fft_img
);
feat
[
i
].
resize
(
fft_points
/
2
+
1
);
// the last dimension is Fs/2 Hz
for
(
int
j
=
0
;
j
<
(
fft_points
/
2
+
1
);
++
j
)
{
...
...
speechx/speechx/frontend/normalizer.h
浏览文件 @
406b4fc7
...
...
@@ -29,7 +29,7 @@ class DecibelNormalizer : public FeatureExtractorInterface {
explicit
DecibelNormalizer
(
const
DecibelNormalizerOptions
&
opts
);
virtual
void
AcceptWaveform
(
const
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>&
input
);
virtual
void
Read
(
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>*
feat
);
virtual
size_t
Dim
()
const
{
return
0
;
}
virtual
size_t
Dim
()
const
{
return
dim_
;
}
bool
Compute
(
const
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>&
input
,
kaldi
::
VectorBase
<
kaldi
::
BaseFloat
>*
feat
)
const
;
private:
...
...
speechx/speechx/kaldi/matrix/BUILD
已删除
100644 → 0
浏览文件 @
b584b969
# Copyright (c) 2020 PeachLab. All Rights Reserved.
# Author : goat.zhou@qq.com (Yang Zhou)
package
(
default_visibility
=
[
"//visibility:public"
])
cc_library
(
name
=
'kaldi-matrix'
,
srcs
=
[
'compressed-matrix.cc'
,
'kaldi-matrix.cc'
,
'kaldi-vector.cc'
,
'matrix-functions.cc'
,
'optimization.cc'
,
'packed-matrix.cc'
,
'qr.cc'
,
'sparse-matrix.cc'
,
'sp-matrix.cc'
,
'srfft.cc'
,
'tp-matrix.cc'
,
],
hdrs
=
glob
([
"*.h"
]),
deps
=
[
'//base:kaldi-base'
,
'//common/third_party/openblas:openblas'
,
],
linkopts
=
[
'-lgfortran'
],
)
cc_binary
(
name
=
'matrix-lib-test'
,
srcs
=
[
'matrix-lib-test.cc'
,
],
deps
=
[
':kaldi-matrix'
,
'//util:kaldi-util'
,
],
)
speechx/speechx/kaldi/matrix/CMakeLists.txt
浏览文件 @
406b4fc7
...
...
@@ -13,4 +13,4 @@ srfft.cc
tp-matrix.cc
)
target_link_libraries
(
kaldi-matrix gfortran kaldi-base
libopenblas.a
)
target_link_libraries
(
kaldi-matrix gfortran kaldi-base
${
MATH_LIB
}
)
speechx/speechx/kaldi/matrix/kaldi-blas.h
浏览文件 @
406b4fc7
...
...
@@ -42,7 +42,7 @@
#define HAVE_
OPENBLAS
#define HAVE_
MKL
#if (defined(HAVE_CLAPACK) && (defined(HAVE_ATLAS) || defined(HAVE_MKL))) \
|| (defined(HAVE_ATLAS) && defined(HAVE_MKL))
...
...
speechx/speechx/nnet/CMakeLists.txt
浏览文件 @
406b4fc7
aux_source_directory
(
. DIR_LIB_SRCS
)
add_library
(
nnet STATIC
${
DIR_LIB_SRCS
}
)
project
(
nnet
)
add_library
(
nnet STATIC
decodable.cc
paddle_nnet.cc
)
target_link_libraries
(
nnet absl::strings
)
\ No newline at end of file
speechx/speechx/nnet/decodable-itf.h
浏览文件 @
406b4fc7
...
...
@@ -114,7 +114,7 @@ class DecodableInterface {
/// this is for compatibility with OpenFst).
virtual
int32
NumIndices
()
const
=
0
;
virtual
std
::
vector
<
BaseFloat
>
FrameLogLikelihood
(
int32
frame
);
virtual
std
::
vector
<
BaseFloat
>
FrameLogLikelihood
(
int32
frame
)
=
0
;
virtual
~
DecodableInterface
()
{}
};
...
...
speechx/speechx/nnet/decodable.cc
浏览文件 @
406b4fc7
...
...
@@ -38,6 +38,15 @@ void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
return
;
}
// Return one frame's scores from the cached nnet output as a flat vector
// (one value per output column of nnet_cache_).
//
// @param frame  Row index into nnet_cache_; caller must ensure it is valid.
// @return       Vector of nnet_cache_.NumCols() scores for that frame.
std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {
    // BUG FIX: the original only reserve()d the vector and then wrote
    // through operator[] — that is out-of-bounds (undefined behavior,
    // size stays 0) and the function returned an empty vector. Construct
    // the vector at its final size instead.
    std::vector<BaseFloat> result(nnet_cache_.NumCols());
    for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
        result[idx] = nnet_cache_(frame, idx);
    }
    return result;
}
void
Decodable
::
Reset
()
{
// frontend_.Reset();
nnet_
->
Reset
();
...
...
speechx/speechx/nnet/decodable.h
浏览文件 @
406b4fc7
...
...
@@ -15,9 +15,9 @@ class Decodable : public kaldi::DecodableInterface {
virtual
kaldi
::
BaseFloat
LogLikelihood
(
int32
frame
,
int32
index
);
virtual
bool
IsLastFrame
(
int32
frame
)
const
;
virtual
int32
NumIndices
()
const
;
virtual
std
::
vector
<
BaseFloat
>
FrameLogLikelihood
(
int32
frame
);
void
Acceptlikelihood
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
likelihood
);
// remove later
void
FeedFeatures
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
feature
);
// only for test, todo remove later
std
::
vector
<
BaseFloat
>
FrameLogLikelihood
(
int32
frame
);
void
Reset
();
void
InputFinished
()
{
finished_
=
true
;
}
private:
...
...
speechx/speechx/nnet/nnet_interface.h
浏览文件 @
406b4fc7
...
...
@@ -9,11 +9,11 @@ namespace ppspeech {
class
NnetInterface
{
public:
virtual
~
NnetInterface
()
{}
virtual
void
FeedForward
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
features
,
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>*
inferences
);
virtual
void
Reset
();
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>*
inferences
)
=
0
;
virtual
void
Reset
()
=
0
;
virtual
~
NnetInterface
()
{}
};
}
// namespace ppspeech
\ No newline at end of file
}
// namespace ppspeech
speechx/speechx/nnet/paddle_nnet.cc
浏览文件 @
406b4fc7
...
...
@@ -10,14 +10,16 @@ using kaldi::Matrix;
void
PaddleNnet
::
InitCacheEncouts
(
const
ModelOptions
&
opts
)
{
std
::
vector
<
std
::
string
>
cache_names
;
cache_names
=
absl
::
StrSplit
(
opts
.
cache_names
,
",
"
);
cache_names
=
absl
::
StrSplit
(
opts
.
cache_names
,
","
);
std
::
vector
<
std
::
string
>
cache_shapes
;
cache_shapes
=
absl
::
StrSplit
(
opts
.
cache_shape
,
",
"
);
cache_shapes
=
absl
::
StrSplit
(
opts
.
cache_shape
,
","
);
assert
(
cache_shapes
.
size
()
==
cache_names
.
size
());
cache_encouts_
.
clear
();
cache_names_idx_
.
clear
();
for
(
size_t
i
=
0
;
i
<
cache_shapes
.
size
();
i
++
)
{
std
::
vector
<
std
::
string
>
tmp_shape
;
tmp_shape
=
absl
::
StrSplit
(
cache_shapes
[
i
],
"-
"
);
tmp_shape
=
absl
::
StrSplit
(
cache_shapes
[
i
],
"-"
);
std
::
vector
<
int
>
cur_shape
;
std
::
transform
(
tmp_shape
.
begin
(),
tmp_shape
.
end
(),
std
::
back_inserter
(
cur_shape
),
...
...
@@ -30,14 +32,14 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
}
}
PaddleNnet
::
PaddleNnet
(
const
ModelOptions
&
opts
)
{
PaddleNnet
::
PaddleNnet
(
const
ModelOptions
&
opts
)
:
opts_
(
opts
)
{
paddle_infer
::
Config
config
;
config
.
SetModel
(
opts
.
model_path
,
opts
.
params_path
);
if
(
opts
.
use_gpu
)
{
config
.
EnableUseGpu
(
500
,
0
);
}
config
.
SwitchIrOptim
(
opts
.
switch_ir_optim
);
if
(
opts
.
enable_fc_padding
)
{
if
(
opts
.
enable_fc_padding
==
false
)
{
config
.
DisableFCPadding
();
}
if
(
opts
.
enable_profile
)
{
...
...
@@ -54,8 +56,8 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts) {
LOG
(
INFO
)
<<
"start to check the predictor input and output names"
;
LOG
(
INFO
)
<<
"input names: "
<<
opts
.
input_names
;
LOG
(
INFO
)
<<
"output names: "
<<
opts
.
output_names
;
vector
<
string
>
input_names_vec
=
absl
::
StrSplit
(
opts
.
input_names
,
",
"
);
vector
<
string
>
output_names_vec
=
absl
::
StrSplit
(
opts
.
output_names
,
",
"
);
vector
<
string
>
input_names_vec
=
absl
::
StrSplit
(
opts
.
input_names
,
","
);
vector
<
string
>
output_names_vec
=
absl
::
StrSplit
(
opts
.
output_names
,
","
);
paddle_infer
::
Predictor
*
predictor
=
GetPredictor
();
std
::
vector
<
std
::
string
>
model_input_names
=
predictor
->
GetInputNames
();
...
...
@@ -70,10 +72,13 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts) {
assert
(
output_names_vec
[
i
]
==
model_output_names
[
i
]);
}
ReleasePredictor
(
predictor
);
InitCacheEncouts
(
opts
);
}
void
PaddleNnet
::
Reset
()
{
InitCacheEncouts
(
opts_
);
}
paddle_infer
::
Predictor
*
PaddleNnet
::
GetPredictor
()
{
LOG
(
INFO
)
<<
"attempt to get a new predictor instance "
<<
std
::
endl
;
paddle_infer
::
Predictor
*
predictor
=
nullptr
;
...
...
@@ -126,57 +131,71 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
}
void
PaddleNnet
::
FeedForward
(
const
Matrix
<
BaseFloat
>&
features
,
Matrix
<
BaseFloat
>*
inferences
)
{
paddle_infer
::
Predictor
*
predictor
=
GetPredictor
();
// 1. 得到所有的 input tensor 的名称
int
row
=
features
.
NumRows
();
int
col
=
features
.
NumCols
();
std
::
vector
<
std
::
string
>
input_names
=
predictor
->
GetInputNames
();
std
::
vector
<
std
::
string
>
output_names
=
predictor
->
GetOutputNames
();
LOG
(
INFO
)
<<
"feat info: row="
<<
row
<<
", col="
<<
col
;
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_tensor
=
predictor
->
GetInputHandle
(
input_names
[
0
]);
std
::
vector
<
int
>
INPUT_SHAPE
=
{
1
,
row
,
col
};
input_tensor
->
Reshape
(
INPUT_SHAPE
);
input_tensor
->
CopyFromCpu
(
features
.
Data
());
// 3. 输入每个音频帧数
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_len
=
predictor
->
GetInputHandle
(
input_names
[
1
]);
std
::
vector
<
int
>
input_len_size
=
{
1
};
input_len
->
Reshape
(
input_len_size
);
std
::
vector
<
int64_t
>
audio_len
;
audio_len
.
push_back
(
row
);
input_len
->
CopyFromCpu
(
audio_len
.
data
());
// 输入流式的缓存数据
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
h_box
=
predictor
->
GetInputHandle
(
input_names
[
2
]);
shared_ptr
<
Tensor
<
BaseFloat
>>
h_cache
=
GetCacheEncoder
(
input_names
[
2
]);
h_box
->
Reshape
(
h_cache
->
get_shape
());
h_box
->
CopyFromCpu
(
h_cache
->
get_data
().
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
c_box
=
predictor
->
GetInputHandle
(
input_names
[
3
]);
shared_ptr
<
Tensor
<
float
>>
c_cache
=
GetCacheEncoder
(
input_names
[
3
]);
c_box
->
Reshape
(
c_cache
->
get_shape
());
c_box
->
CopyFromCpu
(
c_cache
->
get_data
().
data
());
bool
success
=
predictor
->
Run
();
if
(
success
==
false
)
{
LOG
(
INFO
)
<<
"predictor run occurs error"
;
paddle_infer
::
Predictor
*
predictor
=
GetPredictor
();
int
row
=
features
.
NumRows
();
int
col
=
features
.
NumCols
();
std
::
vector
<
BaseFloat
>
feed_feature
;
// todo refactor feed feature: SmileGoat
feed_feature
.
reserve
(
row
*
col
);
for
(
size_t
row_idx
=
0
;
row_idx
<
features
.
NumRows
();
++
row_idx
)
{
for
(
size_t
col_idx
=
0
;
col_idx
<
features
.
NumCols
();
++
col_idx
)
{
feed_feature
.
push_back
(
features
(
row_idx
,
col_idx
));
}
}
std
::
vector
<
std
::
string
>
input_names
=
predictor
->
GetInputNames
();
std
::
vector
<
std
::
string
>
output_names
=
predictor
->
GetOutputNames
();
LOG
(
INFO
)
<<
"feat info: row="
<<
row
<<
", col= "
<<
col
;
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_tensor
=
predictor
->
GetInputHandle
(
input_names
[
0
]);
std
::
vector
<
int
>
INPUT_SHAPE
=
{
1
,
row
,
col
};
input_tensor
->
Reshape
(
INPUT_SHAPE
);
input_tensor
->
CopyFromCpu
(
feed_feature
.
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
input_len
=
predictor
->
GetInputHandle
(
input_names
[
1
]);
std
::
vector
<
int
>
input_len_size
=
{
1
};
input_len
->
Reshape
(
input_len_size
);
std
::
vector
<
int64_t
>
audio_len
;
audio_len
.
push_back
(
row
);
input_len
->
CopyFromCpu
(
audio_len
.
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
h_box
=
predictor
->
GetInputHandle
(
input_names
[
2
]);
shared_ptr
<
Tensor
<
BaseFloat
>>
h_cache
=
GetCacheEncoder
(
input_names
[
2
]);
h_box
->
Reshape
(
h_cache
->
get_shape
());
h_box
->
CopyFromCpu
(
h_cache
->
get_data
().
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
c_box
=
predictor
->
GetInputHandle
(
input_names
[
3
]);
shared_ptr
<
Tensor
<
float
>>
c_cache
=
GetCacheEncoder
(
input_names
[
3
]);
c_box
->
Reshape
(
c_cache
->
get_shape
());
c_box
->
CopyFromCpu
(
c_cache
->
get_data
().
data
());
bool
success
=
predictor
->
Run
();
if
(
success
==
false
)
{
LOG
(
INFO
)
<<
"predictor run occurs error"
;
}
LOG
(
INFO
)
<<
"get the model success"
;
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
h_out
=
predictor
->
GetOutputHandle
(
output_names
[
2
]);
assert
(
h_cache
->
get_shape
()
==
h_out
->
shape
());
h_out
->
CopyToCpu
(
h_cache
->
get_data
().
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
c_out
=
predictor
->
GetOutputHandle
(
output_names
[
3
]);
assert
(
c_cache
->
get_shape
()
==
c_out
->
shape
());
c_out
->
CopyToCpu
(
c_cache
->
get_data
().
data
());
// 5. 得到最后的输出结果
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
output_tensor
=
predictor
->
GetOutputHandle
(
output_names
[
0
]);
std
::
vector
<
int
>
output_shape
=
output_tensor
->
shape
();
row
=
output_shape
[
1
];
col
=
output_shape
[
2
];
inferences
->
Resize
(
row
,
col
);
output_tensor
->
CopyToCpu
(
inferences
->
Data
());
ReleasePredictor
(
predictor
);
LOG
(
INFO
)
<<
"get the model success"
;
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
h_out
=
predictor
->
GetOutputHandle
(
output_names
[
2
]);
assert
(
h_cache
->
get_shape
()
==
h_out
->
shape
());
h_out
->
CopyToCpu
(
h_cache
->
get_data
().
data
());
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
c_out
=
predictor
->
GetOutputHandle
(
output_names
[
3
]);
assert
(
c_cache
->
get_shape
()
==
c_out
->
shape
());
c_out
->
CopyToCpu
(
c_cache
->
get_data
().
data
());
// get result
std
::
unique_ptr
<
paddle_infer
::
Tensor
>
output_tensor
=
predictor
->
GetOutputHandle
(
output_names
[
0
]);
std
::
vector
<
int
>
output_shape
=
output_tensor
->
shape
();
row
=
output_shape
[
1
];
col
=
output_shape
[
2
];
vector
<
float
>
inferences_result
;
inferences
->
Resize
(
row
,
col
);
inferences_result
.
resize
(
row
*
col
);
output_tensor
->
CopyToCpu
(
inferences_result
.
data
());
ReleasePredictor
(
predictor
);
for
(
int
row_idx
=
0
;
row_idx
<
row
;
++
row_idx
)
{
for
(
int
col_idx
=
0
;
col_idx
<
col
;
++
col_idx
)
{
(
*
inferences
)(
row_idx
,
col_idx
)
=
inferences_result
[
col
*
row_idx
+
col_idx
];
}
}
}
}
// namespace ppspeech
\ No newline at end of file
speechx/speechx/nnet/paddle_nnet.h
浏览文件 @
406b4fc7
...
...
@@ -25,14 +25,14 @@ struct ModelOptions {
bool
enable_fc_padding
;
bool
enable_profile
;
ModelOptions
()
:
model_path
(
"
model/final.zip
"
),
params_path
(
"
model/avg_1.jit.pdmodel
"
),
model_path
(
"
../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel
"
),
params_path
(
"
../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams
"
),
thread_num
(
2
),
use_gpu
(
false
),
input_names
(
"audio"
),
output_names
(
"
probs
"
),
cache_names
(
"
enouts
"
),
cache_shape
(
"
1-1-1
"
),
input_names
(
"audio
_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box
"
),
output_names
(
"
save_infer_model/scale_0.tmp_1,save_infer_model/scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/scale_3.tmp_1
"
),
cache_names
(
"
chunk_state_h_box,chunk_state_c_box
"
),
cache_shape
(
"
3-1-1024,3-1-1024
"
),
switch_ir_optim
(
false
),
enable_fc_padding
(
false
),
enable_profile
(
false
)
{
...
...
@@ -87,6 +87,7 @@ class PaddleNnet : public NnetInterface {
PaddleNnet
(
const
ModelOptions
&
opts
);
virtual
void
FeedForward
(
const
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>&
features
,
kaldi
::
Matrix
<
kaldi
::
BaseFloat
>*
inferences
);
virtual
void
Reset
();
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>
GetCacheEncoder
(
const
std
::
string
&
name
);
void
InitCacheEncouts
(
const
ModelOptions
&
opts
);
...
...
@@ -100,6 +101,7 @@ class PaddleNnet : public NnetInterface {
std
::
map
<
paddle_infer
::
Predictor
*
,
int
>
predictor_to_thread_id
;
std
::
map
<
std
::
string
,
int
>
cache_names_idx_
;
std
::
vector
<
std
::
shared_ptr
<
Tensor
<
kaldi
::
BaseFloat
>>>
cache_encouts_
;
ModelOptions
opts_
;
public:
DISALLOW_COPY_AND_ASSIGN
(
PaddleNnet
);
...
...
speechx/speechx/utils/file_utils.cc
浏览文件 @
406b4fc7
#include "utils/file_utils.h"
namespace
ppspeech
{
bool
ReadFileToVector
(
const
std
::
string
&
filename
,
std
::
vector
<
std
::
string
>*
vocabulary
)
{
std
::
ifstream
file_in
(
filename
);
...
...
@@ -15,3 +17,5 @@ bool ReadFileToVector(const std::string& filename,
return
true
;
}
}
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录