Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
57dcd0d1
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
57dcd0d1
编写于
9月 20, 2022
作者:
Z
Zhao Yuting
提交者:
GitHub
9月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update infer.py
change the infer in order to implement the new faster model for text
上级
b627666c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
82 addition
and
9 deletion
+82
-9
paddlespeech/cli/text/infer.py
paddlespeech/cli/text/infer.py
+82
-9
未找到文件。
paddlespeech/cli/text/infer.py
浏览文件 @
57dcd0d1
...
...
@@ -20,10 +20,13 @@ from typing import Optional
from
typing
import
Union
import
paddle
import
yaml
from
yacs.config
import
CfgNode
from
..executor
import
BaseExecutor
from
..log
import
logger
from
..utils
import
stats_wrapper
from
paddlespeech.text.models.ernie_linear
import
ErnieLinear
__all__
=
[
'TextExecutor'
]
...
...
@@ -139,6 +142,66 @@ class TextExecutor(BaseExecutor):
self
.
model
.
eval
()
#init new models
def
_init_from_path_new
(
self
,
task
:
str
=
'punc'
,
model_type
:
str
=
'ernie_linear_p7_wudao'
,
lang
:
str
=
'zh'
,
cfg_path
:
Optional
[
os
.
PathLike
]
=
None
,
ckpt_path
:
Optional
[
os
.
PathLike
]
=
None
,
vocab_file
:
Optional
[
os
.
PathLike
]
=
None
):
if
hasattr
(
self
,
'model'
):
logger
.
debug
(
'Model had been initialized.'
)
return
self
.
task
=
task
if
cfg_path
is
None
or
ckpt_path
is
None
or
vocab_file
is
None
:
tag
=
'-'
.
join
([
model_type
,
task
,
lang
])
self
.
task_resource
.
set_task_model
(
tag
,
version
=
None
)
self
.
cfg_path
=
os
.
path
.
join
(
self
.
task_resource
.
res_dir
,
self
.
task_resource
.
res_dict
[
'cfg_path'
])
self
.
ckpt_path
=
os
.
path
.
join
(
self
.
task_resource
.
res_dir
,
self
.
task_resource
.
res_dict
[
'ckpt_path'
])
self
.
vocab_file
=
os
.
path
.
join
(
self
.
task_resource
.
res_dir
,
self
.
task_resource
.
res_dict
[
'vocab_file'
])
else
:
self
.
cfg_path
=
os
.
path
.
abspath
(
cfg_path
)
self
.
ckpt_path
=
os
.
path
.
abspath
(
ckpt_path
)
self
.
vocab_file
=
os
.
path
.
abspath
(
vocab_file
)
model_name
=
model_type
[:
model_type
.
rindex
(
'_'
)]
if
self
.
task
==
'punc'
:
# punc list
self
.
_punc_list
=
[]
with
open
(
self
.
vocab_file
,
'r'
)
as
f
:
for
line
in
f
:
self
.
_punc_list
.
append
(
line
.
strip
())
# model
with
open
(
self
.
cfg_path
)
as
f
:
config
=
CfgNode
(
yaml
.
safe_load
(
f
))
self
.
model
=
ErnieLinear
(
**
config
[
"model"
])
_
,
tokenizer_class
=
self
.
task_resource
.
get_model_class
(
model_name
)
state_dict
=
paddle
.
load
(
self
.
ckpt_path
)
self
.
model
.
set_state_dict
(
state_dict
[
"main_params"
])
self
.
model
.
eval
()
#tokenizer: fast version: ernie-3.0-mini-zh slow version:ernie-1.0
if
'fast'
not
in
model_type
:
self
.
tokenizer
=
tokenizer_class
.
from_pretrained
(
'ernie-1.0'
)
else
:
self
.
tokenizer
=
tokenizer_class
.
from_pretrained
(
'ernie-3.0-mini-zh'
)
else
:
raise
NotImplementedError
def
_clean_text
(
self
,
text
):
text
=
text
.
lower
()
text
=
re
.
sub
(
'[^A-Za-z0-9
\u4e00
-
\u9fa5
]'
,
''
,
text
)
...
...
@@ -179,7 +242,7 @@ class TextExecutor(BaseExecutor):
else
:
raise
NotImplementedError
def
postprocess
(
self
)
->
Union
[
str
,
os
.
PathLike
]:
def
postprocess
(
self
,
isNewTrainer
:
bool
=
False
)
->
Union
[
str
,
os
.
PathLike
]:
"""
Output postprocess and return human-readable results such as texts and audio files.
"""
...
...
@@ -192,13 +255,13 @@ class TextExecutor(BaseExecutor):
input_ids
[
1
:
seq_len
-
1
])
labels
=
preds
[
1
:
seq_len
-
1
].
tolist
()
assert
len
(
tokens
)
==
len
(
labels
)
if
isNewTrainer
:
self
.
_punc_list
=
[
0
]
+
self
.
_punc_list
text
=
''
for
t
,
l
in
zip
(
tokens
,
labels
):
text
+=
t
if
l
!=
0
:
# Non punc.
text
+=
self
.
_punc_list
[
l
]
return
text
else
:
raise
NotImplementedError
...
...
@@ -255,10 +318,20 @@ class TextExecutor(BaseExecutor):
"""
Python API to call an executor.
"""
paddle
.
set_device
(
device
)
self
.
_init_from_path
(
task
,
model
,
lang
,
config
,
ckpt_path
,
punc_vocab
)
self
.
preprocess
(
text
)
self
.
infer
()
res
=
self
.
postprocess
()
# Retrieve result of text task.
#Here is old version models
if
model
in
[
'ernie_linear_p7_wudao'
,
'ernie_linear_p3_wudao'
]:
paddle
.
set_device
(
device
)
self
.
_init_from_path
(
task
,
model
,
lang
,
config
,
ckpt_path
,
punc_vocab
)
self
.
preprocess
(
text
)
self
.
infer
()
res
=
self
.
postprocess
()
# Retrieve result of text task.
#Add new way to infer
else
:
paddle
.
set_device
(
device
)
self
.
_init_from_path_new
(
task
,
model
,
lang
,
config
,
ckpt_path
,
punc_vocab
)
self
.
preprocess
(
text
)
self
.
infer
()
res
=
self
.
postprocess
(
isNewTrainer
=
True
)
return
res
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录