Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
1f4f98b1
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1f4f98b1
编写于
10月 08, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug
上级
e86337a4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
22 addition
and
19 deletion
+22
-19
paddlespeech/s2t/exps/u2/bin/quant.py
paddlespeech/s2t/exps/u2/bin/quant.py
+12
-6
paddlespeech/s2t/models/u2/u2.py
paddlespeech/s2t/models/u2/u2.py
+10
-12
paddlespeech/server/engine/asr/online/python/asr_engine.py
paddlespeech/server/engine/asr/online/python/asr_engine.py
+0
-1
未找到文件。
paddlespeech/s2t/exps/u2/bin/quant.py
浏览文件 @
1f4f98b1
...
...
@@ -18,6 +18,7 @@ from pathlib import Path
import
paddle
import
soundfile
from
paddleslim
import
PTQ
from
yacs.config
import
CfgNode
from
paddlespeech.audio.transform.transformation
import
Transformation
...
...
@@ -26,7 +27,6 @@ from paddlespeech.s2t.models.u2 import U2Model
from
paddlespeech.s2t.training.cli
import
default_argument_parser
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddleslim
import
PTQ
logger
=
Log
(
__name__
).
getlog
()
...
...
@@ -90,14 +90,14 @@ class U2Infer():
ctc_weight
=
decode_config
.
ctc_weight
,
decoding_chunk_size
=
decode_config
.
decoding_chunk_size
,
num_decoding_left_chunks
=
decode_config
.
num_decoding_left_chunks
,
simulate_streaming
=
decode_config
.
simulate_streaming
simulate_streaming
=
decode_config
.
simulate_streaming
,
reverse_weight
=
decode_config
.
reverse_weight
)
rsl
=
result_transcripts
[
0
][
0
]
utt
=
Path
(
self
.
audio_file
).
name
logger
.
info
(
f
"hyp:
{
utt
}
{
rsl
}
"
)
# print(self.model)
# print(self.model.forward_encoder_chunk)
logger
.
info
(
"-------------start quant ----------------------"
)
batch_size
=
1
feat_dim
=
80
...
...
@@ -161,7 +161,11 @@ class U2Infer():
# jit save
logger
.
info
(
f
"export save:
{
self
.
args
.
export_path
}
"
)
config
=
{
'is_static'
:
True
,
'combine_params'
:
True
,
'skip_forward'
:
True
}
config
=
{
'is_static'
:
True
,
'combine_params'
:
True
,
'skip_forward'
:
True
}
self
.
ptq
.
save_quantized_model
(
self
.
model
,
self
.
args
.
export_path
)
# paddle.jit.save(
# self.model,
...
...
@@ -169,7 +173,6 @@ class U2Infer():
# combine_params=True,
# skip_forward=True)
def
check
(
audio_file
):
if
not
os
.
path
.
isfile
(
audio_file
):
...
...
@@ -201,7 +204,10 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--audio_file"
,
type
=
str
,
help
=
"path of the input audio file"
)
parser
.
add_argument
(
"--export_path"
,
type
=
str
,
default
=
'export'
,
help
=
"path of the input audio file"
)
"--export_path"
,
type
=
str
,
default
=
'export'
,
help
=
"path of the input audio file"
)
args
=
parser
.
parse_args
()
config
=
CfgNode
(
new_allowed
=
True
)
...
...
paddlespeech/s2t/models/u2/u2.py
浏览文件 @
1f4f98b1
...
...
@@ -131,7 +131,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
if
self
.
ctc_weight
!=
1.0
:
start
=
time
.
time
()
loss_att
,
acc_att
=
self
.
_calc_att_loss
(
encoder_out
,
encoder_mask
,
text
,
text_lengths
,
self
.
reverse_weight
)
text
,
text_lengths
,
self
.
reverse_weight
)
decoder_time
=
time
.
time
()
-
start
#logger.debug(f"decoder time: {decoder_time}")
...
...
@@ -152,13 +153,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
loss
=
self
.
ctc_weight
*
loss_ctc
+
(
1
-
self
.
ctc_weight
)
*
loss_att
return
loss
,
loss_att
,
loss_ctc
def
_calc_att_loss
(
self
,
encoder_out
:
paddle
.
Tensor
,
encoder_mask
:
paddle
.
Tensor
,
ys_pad
:
paddle
.
Tensor
,
ys_pad_lens
:
paddle
.
Tensor
,
reverse_weight
:
float
)
->
Tuple
[
paddle
.
Tensor
,
float
]:
def
_calc_att_loss
(
self
,
encoder_out
:
paddle
.
Tensor
,
encoder_mask
:
paddle
.
Tensor
,
ys_pad
:
paddle
.
Tensor
,
ys_pad_lens
:
paddle
.
Tensor
,
reverse_weight
:
float
)
->
Tuple
[
paddle
.
Tensor
,
float
]:
"""Calc attention loss.
Args:
...
...
@@ -188,8 +188,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
r_loss_att
=
paddle
.
to_tensor
(
0.0
)
if
reverse_weight
>
0.0
:
r_loss_att
=
self
.
criterion_att
(
r_decoder_out
,
r_ys_out_pad
)
loss_att
=
loss_att
*
(
1
-
reverse_weight
)
+
r_loss_att
*
reverse_weight
loss_att
=
loss_att
*
(
1
-
reverse_weight
)
+
r_loss_att
*
reverse_weight
acc_att
=
th_accuracy
(
decoder_out
.
view
(
-
1
,
self
.
vocab_size
),
ys_out_pad
,
...
...
@@ -599,8 +598,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
f
"hyp
{
i
}
len
{
len
(
hyp
[
0
])
}
r2l score:
{
r_score
}
ctc_score:
{
hyp
[
1
]
}
reverse_weight:
{
reverse_weight
}
"
)
score
=
score
*
(
1
-
reverse_weight
)
+
r_score
*
reverse_weight
score
=
score
*
(
1
-
reverse_weight
)
+
r_score
*
reverse_weight
# add ctc score (which in ln domain)
score
+=
hyp
[
1
]
*
ctc_weight
if
score
>
best_score
:
...
...
paddlespeech/server/engine/asr/online/python/asr_engine.py
浏览文件 @
1f4f98b1
...
...
@@ -22,7 +22,6 @@ from numpy import float32
from
yacs.config
import
CfgNode
from
paddlespeech.audio.transform.transformation
import
Transformation
from
paddlespeech.audio.utils.tensor_utils
import
st_reverse_pad_list
from
paddlespeech.cli.asr.infer
import
ASRExecutor
from
paddlespeech.cli.log
import
logger
from
paddlespeech.resource
import
CommonTaskResource
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录