Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
0a8ef58a
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a8ef58a
编写于
10月 18, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove useless code
上级
36af34b2
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
10 addition
and
100 deletion
+10
-100
speechx/speechx/nnet/u2_nnet.cc
speechx/speechx/nnet/u2_nnet.cc
+2
-81
speechx/speechx/nnet/u2_nnet.h
speechx/speechx/nnet/u2_nnet.h
+8
-19
未找到文件。
speechx/speechx/nnet/u2_nnet.cc
浏览文件 @
0a8ef58a
...
@@ -25,65 +25,6 @@ using paddle::platform::TracerEventType;
...
@@ -25,65 +25,6 @@ using paddle::platform::TracerEventType;
namespace
ppspeech
{
namespace
ppspeech
{
// Returns the number of feature frames needed for the next encoder chunk.
//
// start: false means this is the first chunk of a sentence; true means any
//        later chunk (note the inverted sense of the flag).
//
// Streaming mode (chunk_size_ > 0):
//   first chunk : (chunk_size_ - 1) * subsampling_rate_ + context()
//                 (one decoder frame needs `context()` feature frames)
//   later chunks: chunk_size_ * subsampling_rate_ (the stride between chunks)
// Non-streaming mode: std::numeric_limits<int>::max(), i.e. all features are
// fed at once.
int U2NnetBase::num_frames_for_chunk(bool start) const {
    if (chunk_size_ <= 0) {
        // Non-streaming: no chunking, take everything.
        return std::numeric_limits<int>::max();
    }

    const bool is_first_chunk = !start;
    if (!is_first_chunk) {
        // Past the first chunk we only advance by one chunk stride.
        return chunk_size_ * subsampling_rate_;
    }

    const int context_frames = this->context();
    return (chunk_size_ - 1) * subsampling_rate_ + context_frames;
}
// Caches the trailing feature frames of `chunk_feats` for the next chunk.
//
// chunk_feats: flat buffer holding nframes * feat_dim values, frame by frame.
// feat_dim:    number of values per frame.
//
// The last (context() - subsampling_rate_) frames of the chunk are copied into
// cached_feats_, one vector per frame. When the chunk holds fewer frames than
// that, the existing cache is left untouched.
void U2NnetBase::CacheFeature(const std::vector<kaldi::BaseFloat>& chunk_feats,
                              int32 feat_dim) {
    if (feat_dim <= 0) {
        // Nothing sensible to cache, and dividing by feat_dim below would be
        // undefined for zero.
        return;
    }

    const int chunk_size = chunk_feats.size() / feat_dim;
    const int cached_feat_size = this->context() - subsampling_rate_;
    if (cached_feat_size <= 0) {
        // No frames need to be carried over. A negative count must never reach
        // resize(), where it would convert to a huge unsigned size.
        cached_feats_.clear();
        return;
    }
    if (chunk_size < cached_feat_size) {
        return;
    }

    cached_feats_.resize(cached_feat_size);
    const int first_cached_frame = chunk_size - cached_feat_size;
    for (int i = 0; i < cached_feat_size; ++i) {
        // chunk_feats is flat, so a frame index has to be scaled by feat_dim
        // to obtain the offset of that frame's first value.
        const auto frame_begin =
            chunk_feats.begin() + (first_cached_frame + i) * feat_dim;
        const auto frame_end = frame_begin + feat_dim;
        cached_feats_[i] = std::vector<float>(frame_begin, frame_end);
    }
}
// Runs one chunk of features through the encoder.
//
// chunk_feats: flat buffer of num_frames * feat_dim feature values.
// feat_dim:    number of values per frame.
// ctc_probs:   output; always cleared first, then filled by
//              ForwardEncoderChunkImpl when the chunk is processed.
// vocab_dim:   output; forwarded to ForwardEncoderChunkImpl.
//
// The chunk is only processed when it holds at least context() frames.
// Otherwise ctc_probs stays empty and vocab_dim is not written. After a
// successful forward pass the chunk is handed to CacheFeature.
void U2NnetBase::ForwardEncoderChunk(
    const std::vector<kaldi::BaseFloat>& chunk_feats,
    const int32& feat_dim,
    std::vector<kaldi::BaseFloat>* ctc_probs,
    int32* vocab_dim) {
    ctc_probs->clear();

    const int num_frames = chunk_feats.size() / feat_dim;
    VLOG(3) << "foward encoder chunk: " << num_frames << " frames";
    VLOG(3) << "context: " << this->context() << " frames";

    if (num_frames < this->context()) {
        // Not enough frames to run the encoder on this chunk.
        return;
    }

    this->ForwardEncoderChunkImpl(chunk_feats, feat_dim, ctc_probs, vocab_dim);
    VLOG(3) << "after forward chunk";
    this->CacheFeature(chunk_feats, feat_dim);
}
void
U2Nnet
::
LoadModel
(
const
std
::
string
&
model_path_w_prefix
)
{
void
U2Nnet
::
LoadModel
(
const
std
::
string
&
model_path_w_prefix
)
{
paddle
::
jit
::
utils
::
InitKernelSignatureMap
();
paddle
::
jit
::
utils
::
InitKernelSignatureMap
();
...
@@ -188,7 +129,7 @@ U2Nnet::U2Nnet(const U2Nnet& other) {
...
@@ -188,7 +129,7 @@ U2Nnet::U2Nnet(const U2Nnet& other) {
forward_attention_decoder_
=
other
.
forward_attention_decoder_
;
forward_attention_decoder_
=
other
.
forward_attention_decoder_
;
ctc_activation_
=
other
.
ctc_activation_
;
ctc_activation_
=
other
.
ctc_activation_
;
// offset_ = other.offset_; // TODO: not used in nnets
offset_
=
other
.
offset_
;
// copy model ptr
// copy model ptr
model_
=
other
.
model_
;
model_
=
other
.
model_
;
...
@@ -204,8 +145,7 @@ std::shared_ptr<NnetBase> U2Nnet::Copy() const {
...
@@ -204,8 +145,7 @@ std::shared_ptr<NnetBase> U2Nnet::Copy() const {
}
}
void
U2Nnet
::
Reset
()
{
void
U2Nnet
::
Reset
()
{
// offset_ = 0;
offset_
=
0
;
// cached_feats_.clear(); // TODO: not used in nnets
att_cache_
=
att_cache_
=
std
::
move
(
paddle
::
zeros
({
0
,
0
,
0
,
0
},
paddle
::
DataType
::
FLOAT32
));
std
::
move
(
paddle
::
zeros
({
0
,
0
,
0
,
0
},
paddle
::
DataType
::
FLOAT32
));
...
@@ -263,16 +203,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
...
@@ -263,16 +203,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
paddle
::
zeros
({
1
,
num_frames
,
feat_dim
},
paddle
::
DataType
::
FLOAT32
);
paddle
::
zeros
({
1
,
num_frames
,
feat_dim
},
paddle
::
DataType
::
FLOAT32
);
float
*
feats_ptr
=
feats
.
mutable_data
<
float
>
();
float
*
feats_ptr
=
feats
.
mutable_data
<
float
>
();
// for (size_t i = 0; i < cached_feats_.size(); ++i) {
// float* row = feats_ptr + i * feat_dim;
// std::memcpy(row, cached_feats_[i].data(), feat_dim * sizeof(float));
// }
// for (size_t i = 0; i < chunk_feats.size(); ++i) {
// float* row = feats_ptr + (cached_feats_.size() + i) * feat_dim;
// std::memcpy(row, chunk_feats[i].data(), feat_dim * sizeof(float));
// }
// not cache feature in nnet
// not cache feature in nnet
CHECK
(
cached_feats_
.
size
()
==
0
);
CHECK
(
cached_feats_
.
size
()
==
0
);
// CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true);
// CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true);
...
@@ -427,15 +357,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
...
@@ -427,15 +357,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
float
*
ctc_log_probs_ptr
=
ctc_log_probs
.
data
<
float
>
();
float
*
ctc_log_probs_ptr
=
ctc_log_probs
.
data
<
float
>
();
// // vector<vector<float>>
// out_prob->resize(T);
// for (int i = 0; i < T; i++) {
// (*out_prob)[i].resize(D);
// float* dst_ptr = (*out_prob)[i].data();
// float* src_ptr = ctc_log_probs_ptr + (i * D);
// std::memcpy(dst_ptr, src_ptr, D * sizeof(float));
// }
// CHECK(std::is_same<float, kaldi::BaseFloat>::value);
out_prob
->
resize
(
T
*
D
);
out_prob
->
resize
(
T
*
D
);
std
::
memcpy
(
std
::
memcpy
(
out_prob
->
data
(),
ctc_log_probs_ptr
,
T
*
D
*
sizeof
(
kaldi
::
BaseFloat
));
out_prob
->
data
(),
ctc_log_probs_ptr
,
T
*
D
*
sizeof
(
kaldi
::
BaseFloat
));
...
...
speechx/speechx/nnet/u2_nnet.h
浏览文件 @
0a8ef58a
...
@@ -28,29 +28,21 @@ namespace ppspeech {
...
@@ -28,29 +28,21 @@ namespace ppspeech {
class
U2NnetBase
:
public
NnetBase
{
class
U2NnetBase
:
public
NnetBase
{
public:
public:
virtual
int
c
ontext
()
const
{
return
right_context_
+
1
;
}
virtual
int
C
ontext
()
const
{
return
right_context_
+
1
;
}
virtual
int
right_c
ontext
()
const
{
return
right_context_
;
}
virtual
int
RightC
ontext
()
const
{
return
right_context_
;
}
virtual
int
eos
()
const
{
return
eos_
;
}
virtual
int
EOS
()
const
{
return
eos_
;
}
virtual
int
sos
()
const
{
return
sos_
;
}
virtual
int
SOS
()
const
{
return
sos_
;
}
virtual
int
is_b
idecoder
()
const
{
return
is_bidecoder_
;
}
virtual
int
IsB
idecoder
()
const
{
return
is_bidecoder_
;
}
// current offset in decoder frame
// current offset in decoder frame
virtual
int
o
ffset
()
const
{
return
offset_
;
}
virtual
int
O
ffset
()
const
{
return
offset_
;
}
virtual
void
set_chunk_s
ize
(
int
chunk_size
)
{
chunk_size_
=
chunk_size
;
}
virtual
void
SetChunkS
ize
(
int
chunk_size
)
{
chunk_size_
=
chunk_size
;
}
virtual
void
set_num_left_c
hunks
(
int
num_left_chunks
)
{
virtual
void
SetNumLeftC
hunks
(
int
num_left_chunks
)
{
num_left_chunks_
=
num_left_chunks
;
num_left_chunks_
=
num_left_chunks
;
}
}
// start: false, it is the start chunk of one sentence, else true
virtual
int
num_frames_for_chunk
(
bool
start
)
const
;
virtual
std
::
shared_ptr
<
NnetBase
>
Copy
()
const
=
0
;
virtual
std
::
shared_ptr
<
NnetBase
>
Copy
()
const
=
0
;
virtual
void
ForwardEncoderChunk
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
chunk_feats
,
const
int32
&
feat_dim
,
std
::
vector
<
kaldi
::
BaseFloat
>*
ctc_probs
,
int32
*
vocab_dim
);
protected:
protected:
virtual
void
ForwardEncoderChunkImpl
(
virtual
void
ForwardEncoderChunkImpl
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
chunk_feats
,
const
std
::
vector
<
kaldi
::
BaseFloat
>&
chunk_feats
,
...
@@ -58,9 +50,6 @@ class U2NnetBase : public NnetBase {
...
@@ -58,9 +50,6 @@ class U2NnetBase : public NnetBase {
std
::
vector
<
kaldi
::
BaseFloat
>*
ctc_probs
,
std
::
vector
<
kaldi
::
BaseFloat
>*
ctc_probs
,
int32
*
vocab_dim
)
=
0
;
int32
*
vocab_dim
)
=
0
;
virtual
void
CacheFeature
(
const
std
::
vector
<
kaldi
::
BaseFloat
>&
chunk_feats
,
int32
feat_dim
);
protected:
protected:
// model specification
// model specification
int
right_context_
{
0
};
int
right_context_
{
0
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录