提交 b472a148 编写于 作者: H Hui Zhang

format

上级 5a4e35b5
...@@ -34,4 +34,4 @@ For more details please see `run.sh`. ...@@ -34,4 +34,4 @@ For more details please see `run.sh`.
## Outputs ## Outputs
The optimized onnx model is `exp/model.opt.onnx`. The optimized onnx model is `exp/model.opt.onnx`.
To show the graph, please using `local/netron.sh`. To show the graph, please using `local/netron.sh`.
\ No newline at end of file
...@@ -12,13 +12,14 @@ ...@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
import pickle
import numpy as np import numpy as np
import onnxruntime import onnxruntime
import paddle import paddle
import os
import pickle
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -26,26 +27,19 @@ def parse_args(): ...@@ -26,26 +27,19 @@ def parse_args():
'--input_file', '--input_file',
type=str, type=str,
default="static_ds2online_inputs.pickle", default="static_ds2online_inputs.pickle",
help="ds2 input pickle file.", help="ds2 input pickle file.", )
)
parser.add_argument( parser.add_argument(
'--model_dir', '--model_dir', type=str, default=".", help="paddle model dir.")
type=str,
default=".",
help="paddle model dir."
)
parser.add_argument( parser.add_argument(
'--model_prefix', '--model_prefix',
type=str, type=str,
default="avg_1.jit", default="avg_1.jit",
help="paddle model prefix." help="paddle model prefix.")
)
parser.add_argument( parser.add_argument(
'--onnx_model', '--onnx_model',
type=str, type=str,
default='./model.old.onnx', default='./model.old.onnx',
help="onnx model." help="onnx model.")
)
return parser.parse_args() return parser.parse_args()
...@@ -69,19 +63,19 @@ if __name__ == '__main__': ...@@ -69,19 +63,19 @@ if __name__ == '__main__':
paddle.to_tensor(audio_chunk), paddle.to_tensor(audio_chunk),
paddle.to_tensor(audio_chunk_lens), paddle.to_tensor(audio_chunk_lens),
paddle.to_tensor(chunk_state_h_box), paddle.to_tensor(chunk_state_h_box),
paddle.to_tensor(chunk_state_c_box), paddle.to_tensor(chunk_state_c_box), )
)
# onnxruntime # onnxruntime
options = onnxruntime.SessionOptions() options = onnxruntime.SessionOptions()
options.enable_profiling=True options.enable_profiling = True
sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options) sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run( ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], ['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
{"audio_chunk": audio_chunk, "audio_chunk": audio_chunk,
"audio_chunk_lens":audio_chunk_lens, "audio_chunk_lens": audio_chunk_lens,
"chunk_state_h_box": chunk_state_h_box, "chunk_state_h_box": chunk_state_h_box,
"chunk_state_c_box":chunk_state_c_box}) "chunk_state_c_box": chunk_state_c_box
})
print(sess.end_profiling()) print(sess.end_profiling())
...@@ -89,4 +83,4 @@ if __name__ == '__main__': ...@@ -89,4 +83,4 @@ if __name__ == '__main__':
print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6)) print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
print(np.allclose(ort_res_lens, res_lens, atol=1e-6)) print(np.allclose(ort_res_lens, res_lens, atol=1e-6))
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6)) print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6)) print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))
\ No newline at end of file
# Copyright (c) Microsoft Corporation. All rights reserved. # Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. # Licensed under the MIT License.
# flake8: noqa
import argparse import argparse
import logging import logging
...@@ -491,9 +492,6 @@ class SymbolicShapeInference: ...@@ -491,9 +492,6 @@ class SymbolicShapeInference:
skip_infer = node.op_type in [ skip_infer = node.op_type in [
'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \ 'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \
# contrib ops # contrib ops
'Attention', 'BiasGelu', \ 'Attention', 'BiasGelu', \
'EmbedLayerNormalization', \ 'EmbedLayerNormalization', \
'FastGelu', 'Gelu', 'LayerNormalization', \ 'FastGelu', 'Gelu', 'LayerNormalization', \
...@@ -1605,8 +1603,8 @@ class SymbolicShapeInference: ...@@ -1605,8 +1603,8 @@ class SymbolicShapeInference:
def _infer_Scan(self, node): def _infer_Scan(self, node):
subgraph = get_attribute(node, 'body') subgraph = get_attribute(node, 'body')
num_scan_inputs = get_attribute(node, 'num_scan_inputs') num_scan_inputs = get_attribute(node, 'num_scan_inputs')
scan_input_axes = get_attribute(node, 'scan_input_axes', [0] * scan_input_axes = get_attribute(node, 'scan_input_axes',
num_scan_inputs) [0] * num_scan_inputs)
num_scan_states = len(node.input) - num_scan_inputs num_scan_states = len(node.input) - num_scan_inputs
scan_input_axes = [ scan_input_axes = [
handle_negative_axis( handle_negative_axis(
...@@ -1627,8 +1625,8 @@ class SymbolicShapeInference: ...@@ -1627,8 +1625,8 @@ class SymbolicShapeInference:
si.name = subgraph_name si.name = subgraph_name
self._onnx_infer_subgraph(node, subgraph) self._onnx_infer_subgraph(node, subgraph)
num_scan_outputs = len(node.output) - num_scan_states num_scan_outputs = len(node.output) - num_scan_states
scan_output_axes = get_attribute(node, 'scan_output_axes', [0] * scan_output_axes = get_attribute(node, 'scan_output_axes',
num_scan_outputs) [0] * num_scan_outputs)
scan_input_dim = get_shape_from_type_proto( scan_input_dim = get_shape_from_type_proto(
self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
for i, o in enumerate(node.output): for i, o in enumerate(node.output):
...@@ -1821,8 +1819,8 @@ class SymbolicShapeInference: ...@@ -1821,8 +1819,8 @@ class SymbolicShapeInference:
split = get_attribute(node, 'split') split = get_attribute(node, 'split')
if not split: if not split:
num_outputs = len(node.output) num_outputs = len(node.output)
split = [input_sympy_shape[axis] / sympy.Integer(num_outputs) split = [input_sympy_shape[axis] /
] * num_outputs sympy.Integer(num_outputs)] * num_outputs
self._update_computed_dims(split) self._update_computed_dims(split)
else: else:
split = [sympy.Integer(s) for s in split] split = [sympy.Integer(s) for s in split]
...@@ -2174,8 +2172,8 @@ class SymbolicShapeInference: ...@@ -2174,8 +2172,8 @@ class SymbolicShapeInference:
subgraphs = [] subgraphs = []
if 'If' == node.op_type: if 'If' == node.op_type:
subgraphs = [ subgraphs = [
get_attribute(node, 'then_branch'), get_attribute( get_attribute(node, 'then_branch'),
node, 'else_branch') get_attribute(node, 'else_branch')
] ]
elif node.op_type in ['Loop', 'Scan']: elif node.op_type in ['Loop', 'Scan']:
subgraphs = [get_attribute(node, 'body')] subgraphs = [get_attribute(node, 'body')]
...@@ -2330,8 +2328,8 @@ class SymbolicShapeInference: ...@@ -2330,8 +2328,8 @@ class SymbolicShapeInference:
'LessOrEqual', 'GreaterOrEqual' 'LessOrEqual', 'GreaterOrEqual'
]: ]:
shapes = [ shapes = [
self._get_shape(node, i) for i in range( self._get_shape(node, i)
len(node.input)) for i in range(len(node.input))
] ]
if node.op_type in [ if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16' 'MatMul', 'MatMulInteger', 'MatMulInteger16'
......
#!/usr/bin/env python3 -W ignore::DeprecationWarning #!/usr/bin/env python3 -W ignore::DeprecationWarning
# prune model by output names # prune model by output names
import argparse import argparse
import copy import copy
import sys import sys
import onnx import onnx
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
......
#!/usr/bin/env python3 -W ignore::DeprecationWarning #!/usr/bin/env python3 -W ignore::DeprecationWarning
# rename node to new names # rename node to new names
import argparse import argparse
import sys import sys
......
...@@ -4,6 +4,7 @@ import argparse ...@@ -4,6 +4,7 @@ import argparse
# paddle inference shape # paddle inference shape
def process_old_ops_desc(program): def process_old_ops_desc(program):
"""set matmul op head_number attr to 1 is not exist. """set matmul op head_number attr to 1 is not exist.
......
...@@ -6,6 +6,7 @@ from typing import List ...@@ -6,6 +6,7 @@ from typing import List
# paddle prune model. # paddle prune model.
def prepend_feed_ops(program, def prepend_feed_ops(program,
feed_target_names: List[str], feed_target_names: List[str],
feed_holder_name='feed'): feed_holder_name='feed'):
......
...@@ -747,7 +747,7 @@ def num2chn(number_string, ...@@ -747,7 +747,7 @@ def num2chn(number_string,
previous_symbol, (CNU, type(None))): previous_symbol, (CNU, type(None))):
if next_symbol.power != 1 and ( if next_symbol.power != 1 and (
(previous_symbol is None) or (previous_symbol is None) or
(previous_symbol.power != 1)): (previous_symbol.power != 1)): # noqa: E129
result_symbols[i] = liang result_symbols[i] = liang
# if big is True, '两' will not be used and `alt_two` has no impact on output # if big is True, '两' will not be used and `alt_two` has no impact on output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册