Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
6477b6f3
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6477b6f3
编写于
6月 08, 2022
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
onxx rename and prune
上级
28c1794b
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
250 addition
and
4 deletion
+250
-4
speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py
speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py
+127
-0
speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py
speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py
+110
-0
speechx/examples/ds2_ol/onnx/local/pd_infer_shape.py
speechx/examples/ds2_ol/onnx/local/pd_infer_shape.py
+0
-0
speechx/examples/ds2_ol/onnx/local/prune.sh
speechx/examples/ds2_ol/onnx/local/prune.sh
+1
-0
speechx/examples/ds2_ol/onnx/local/tonnx.sh
speechx/examples/ds2_ol/onnx/local/tonnx.sh
+2
-0
speechx/examples/ds2_ol/onnx/run.sh
speechx/examples/ds2_ol/onnx/run.sh
+10
-4
未找到文件。
speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py
0 → 100644
浏览文件 @
6477b6f3
#!/usr/bin/env python3 -W ignore::DeprecationWarning
import
argparse
import
copy
import
sys
import
onnx
def
parse_arguments
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model'
,
required
=
True
,
help
=
'Path of directory saved the input model.'
)
parser
.
add_argument
(
'--output_names'
,
required
=
True
,
nargs
=
'+'
,
help
=
'The outputs of pruned model.'
)
parser
.
add_argument
(
'--save_file'
,
required
=
True
,
help
=
'Path to save the new onnx model.'
)
return
parser
.
parse_args
()
if
__name__
==
'__main__'
:
args
=
parse_arguments
()
if
len
(
set
(
args
.
output_names
))
<
len
(
args
.
output_names
):
print
(
"[ERROR] There's dumplicate name in --output_names, which is not allowed."
)
sys
.
exit
(
-
1
)
model
=
onnx
.
load
(
args
.
model
)
# collect all node outputs and graph output
output_tensor_names
=
set
()
for
node
in
model
.
graph
.
node
:
for
out
in
node
.
output
:
# may contain model output
output_tensor_names
.
add
(
out
)
# for out in model.graph.output:
# output_tensor_names.add(out.name)
for
output_name
in
args
.
output_names
:
if
output_name
not
in
output_tensor_names
:
print
(
"[ERROR] Cannot find output tensor name '{}' in onnx model graph."
.
format
(
output_name
))
sys
.
exit
(
-
1
)
output_node_indices
=
set
()
# has output names
output_to_node
=
dict
()
# all node outputs
for
i
,
node
in
enumerate
(
model
.
graph
.
node
):
for
out
in
node
.
output
:
output_to_node
[
out
]
=
i
if
out
in
args
.
output_names
:
output_node_indices
.
add
(
i
)
# from outputs find all the ancestors
reserved_node_indices
=
copy
.
deepcopy
(
output_node_indices
)
# nodes need to keep
reserved_inputs
=
set
()
# model input to keep
new_output_node_indices
=
copy
.
deepcopy
(
output_node_indices
)
while
True
and
len
(
new_output_node_indices
)
>
0
:
output_node_indices
=
copy
.
deepcopy
(
new_output_node_indices
)
new_output_node_indices
=
set
()
for
out_node_idx
in
output_node_indices
:
# backtrace to parenet
for
ipt
in
model
.
graph
.
node
[
out_node_idx
].
input
:
if
ipt
in
output_to_node
:
reserved_node_indices
.
add
(
output_to_node
[
ipt
])
new_output_node_indices
.
add
(
output_to_node
[
ipt
])
else
:
reserved_inputs
.
add
(
ipt
)
num_inputs
=
len
(
model
.
graph
.
input
)
num_outputs
=
len
(
model
.
graph
.
output
)
num_nodes
=
len
(
model
.
graph
.
node
)
print
(
f
"old graph has
{
num_inputs
}
inputs,
{
num_outputs
}
outpus,
{
num_nodes
}
nodes"
)
print
(
f
"
{
len
(
reserved_node_indices
)
}
node to keep."
)
# del node not to keep
for
idx
in
range
(
num_nodes
-
1
,
-
1
,
-
1
):
if
idx
not
in
reserved_node_indices
:
del
model
.
graph
.
node
[
idx
]
# del graph input not to keep
for
idx
in
range
(
num_inputs
-
1
,
-
1
,
-
1
):
if
model
.
graph
.
input
[
idx
].
name
not
in
reserved_inputs
:
del
model
.
graph
.
input
[
idx
]
# del old graph outputs
for
i
in
range
(
num_outputs
):
del
model
.
graph
.
output
[
0
]
# new graph output as user input
for
out
in
args
.
output_names
:
model
.
graph
.
output
.
extend
([
onnx
.
ValueInfoProto
(
name
=
out
)])
# infer shape
try
:
from
onnx_infer_shape
import
SymbolicShapeInference
model
=
SymbolicShapeInference
.
infer_shapes
(
model
,
int_max
=
2
**
31
-
1
,
auto_merge
=
True
,
guess_output_rank
=
False
,
verbose
=
1
)
except
Exception
as
e
:
print
(
f
"skip infer shape step:
{
e
}
"
)
# check onnx model
onnx
.
checker
.
check_model
(
model
)
# save onnx model
onnx
.
save
(
model
,
args
.
save_file
)
print
(
"[Finished] The new model saved in {}."
.
format
(
args
.
save_file
))
print
(
"[DEBUG INFO] The inputs of new model: {}"
.
format
(
[
x
.
name
for
x
in
model
.
graph
.
input
]))
print
(
"[DEBUG INFO] The outputs of new model: {}"
.
format
(
[
x
.
name
for
x
in
model
.
graph
.
output
]))
speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py
0 → 100755
浏览文件 @
6477b6f3
#!/usr/bin/env python3 -W ignore::DeprecationWarning
import
argparse
import
sys
import
onnx
def
parse_arguments
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model'
,
required
=
True
,
help
=
'Path of directory saved the input model.'
)
parser
.
add_argument
(
'--origin_names'
,
required
=
True
,
nargs
=
'+'
,
help
=
'The original name you want to modify.'
)
parser
.
add_argument
(
'--new_names'
,
required
=
True
,
nargs
=
'+'
,
help
=
'The new name you want change to, the number of new_names should be same with the number of origin_names'
)
parser
.
add_argument
(
'--save_file'
,
required
=
True
,
help
=
'Path to save the new onnx model.'
)
return
parser
.
parse_args
()
if
__name__
==
'__main__'
:
args
=
parse_arguments
()
if
len
(
set
(
args
.
origin_names
))
<
len
(
args
.
origin_names
):
print
(
"[ERROR] There's dumplicate name in --origin_names, which is not allowed."
)
sys
.
exit
(
-
1
)
if
len
(
set
(
args
.
new_names
))
<
len
(
args
.
new_names
):
print
(
"[ERROR] There's dumplicate name in --new_names, which is not allowed."
)
sys
.
exit
(
-
1
)
if
len
(
args
.
new_names
)
!=
len
(
args
.
origin_names
):
print
(
"[ERROR] Number of --new_names must be same with the number of --origin_names."
)
sys
.
exit
(
-
1
)
model
=
onnx
.
load
(
args
.
model
)
# collect input and all node output
output_tensor_names
=
set
()
for
ipt
in
model
.
graph
.
input
:
output_tensor_names
.
add
(
ipt
.
name
)
for
node
in
model
.
graph
.
node
:
for
out
in
node
.
output
:
output_tensor_names
.
add
(
out
)
for
origin_name
in
args
.
origin_names
:
if
origin_name
not
in
output_tensor_names
:
print
(
f
"[ERROR] Cannot find tensor name '
{
origin_name
}
' in onnx model graph."
)
sys
.
exit
(
-
1
)
for
new_name
in
args
.
new_names
:
if
new_name
in
output_tensor_names
:
print
(
"[ERROR] The defined new_name '{}' is already exist in the onnx model, which is not allowed."
)
sys
.
exit
(
-
1
)
# rename graph input
for
i
,
ipt
in
enumerate
(
model
.
graph
.
input
):
if
ipt
.
name
in
args
.
origin_names
:
idx
=
args
.
origin_names
.
index
(
ipt
.
name
)
model
.
graph
.
input
[
i
].
name
=
args
.
new_names
[
idx
]
# rename node input and output
for
i
,
node
in
enumerate
(
model
.
graph
.
node
):
for
j
,
ipt
in
enumerate
(
node
.
input
):
if
ipt
in
args
.
origin_names
:
idx
=
args
.
origin_names
.
index
(
ipt
)
model
.
graph
.
node
[
i
].
input
[
j
]
=
args
.
new_names
[
idx
]
for
j
,
out
in
enumerate
(
node
.
output
):
if
out
in
args
.
origin_names
:
idx
=
args
.
origin_names
.
index
(
out
)
model
.
graph
.
node
[
i
].
output
[
j
]
=
args
.
new_names
[
idx
]
# rename graph output
for
i
,
out
in
enumerate
(
model
.
graph
.
output
):
if
out
.
name
in
args
.
origin_names
:
idx
=
args
.
origin_names
.
index
(
out
.
name
)
model
.
graph
.
output
[
i
].
name
=
args
.
new_names
[
idx
]
# check onnx model
onnx
.
checker
.
check_model
(
model
)
# save model
onnx
.
save
(
model
,
args
.
save_file
)
print
(
"[Finished] The new model saved in {}."
.
format
(
args
.
save_file
))
print
(
"[DEBUG INFO] The inputs of new model: {}"
.
format
(
[
x
.
name
for
x
in
model
.
graph
.
input
]))
print
(
"[DEBUG INFO] The outputs of new model: {}"
.
format
(
[
x
.
name
for
x
in
model
.
graph
.
output
]))
speechx/examples/ds2_ol/onnx/local/pd_infer_shape.py
100644 → 100755
浏览文件 @
6477b6f3
文件模式从 100644 更改为 100755
speechx/examples/ds2_ol/onnx/local/prune.sh
浏览文件 @
6477b6f3
...
...
@@ -3,6 +3,7 @@
set
-e
if
[
$#
!=
5
]
;
then
# local/prune.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 $PWD
echo
"usage:
$0
model_dir model_filename param_filename outputs_names save_dir"
exit
1
fi
...
...
speechx/examples/ds2_ol/onnx/local/tonnx.sh
浏览文件 @
6477b6f3
#!/bin/bash
if
[
$#
!=
4
]
;
then
# local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx
echo
"usage:
$0
model_dir model_name param_name onnx_output_name"
exit
1
fi
...
...
@@ -11,6 +12,7 @@ param=$3
output
=
$4
pip
install
paddle2onnx
pip
install
onnx
# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2
paddle2onnx
--model_dir
$dir
\
...
...
speechx/examples/ds2_ol/onnx/run.sh
浏览文件 @
6477b6f3
...
...
@@ -10,6 +10,9 @@ stop_stage=100
.
utils/parse_options.sh
data
=
data
exp
=
exp
mkdir
-p
$data
$exp
if
[
${
stage
}
-le
0
]
&&
[
${
stop_stage
}
-ge
0
]
;
then
test
-f
$data
/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
||
wget
-c
https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz
-P
$data
...
...
@@ -25,21 +28,24 @@ param=avg_1.jit.pdiparams
output_names
=
softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
if
[
${
stage
}
-le
1
]
&&
[
${
stop_stage
}
-ge
1
]
;
then
mkdir
-p
$
data
/prune
mkdir
-p
$
exp
/prune
# prune model deps on output_names.
./local/prune.sh
$dir
$model
$param
$output_names
$
data
/prune
./local/prune.sh
$dir
$model
$param
$output_names
$
exp
/prune
fi
input_shape_dict
=
"{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
if
[
${
stage
}
-le
2
]
&&
[
${
stop_stage
}
-ge
2
]
;
then
mkdir
-p
$
data
/shape
mkdir
-p
$
exp
/shape
python3
local
/pd_infer_shape.py
\
--model_dir
$dir
\
--model_filename
$model
\
--params_filename
$param
\
--save_dir
$
data
/shape
\
--save_dir
$
exp
/shape
\
--input_shape_dict
=
${
input_shape_dict
}
fi
if
[
${
stage
}
-le
3
]
&&
[
${
stop_stage
}
-ge
3
]
;
then
./local/tonnx.sh
$dir
$model
$param
$exp
/model.onnx
fi
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录