Commit c3f762eb — author: Hui Zhang

format code

Parent commit: 3cf1f1f0
#!/usr/bin/env python3
"""Convert an ONNX model to a target opset version.

Models must be opset 10 or higher to be quantized, so this script
upgrades older models with onnx.version_converter.
"""
import argparse

import onnx
from onnx import version_converter

if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog=__doc__)
    parser.add_argument(
        "--model-file", type=str, required=True, help='path/to/the/model.onnx.')
    parser.add_argument(
        "--save-model",
        type=str,
        required=True,
        help='path/to/saved/model.onnx.')
    # Models must be opset10 or higher to be quantized.
    parser.add_argument(
        "--target-opset",
        type=int,
        default=11,
        help='target opset version to convert the model to (default: 11).')
    args = parser.parse_args()

    # NOTE(review): the diff hunk between argument parsing and the conversion
    # was collapsed in the scraped view; `original_model` must be loaded here.
    # Reconstructed from context — confirm against the original file.
    original_model = onnx.load(args.model_file)

    # A full list of supported adapters can be found here:
    # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
    # Apply the version conversion on the original model.
    converted_model = version_converter.convert_version(original_model,
                                                        args.target_opset)
    # print('The model after conversion:\n{}'.format(converted_model))
    onnx.save(converted_model, args.save_model)
...@@ -494,6 +494,8 @@ class SymbolicShapeInference: ...@@ -494,6 +494,8 @@ class SymbolicShapeInference:
# contrib ops # contrib ops
'Attention', 'BiasGelu', \ 'Attention', 'BiasGelu', \
'EmbedLayerNormalization', \ 'EmbedLayerNormalization', \
'FastGelu', 'Gelu', 'LayerNormalization', \ 'FastGelu', 'Gelu', 'LayerNormalization', \
......
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
def quantize_onnx_model(onnx_model_path, quantized_model_path, nodes_to_exclude=[]): from onnxruntime.quantization import quantize_dynamic
from onnxruntime.quantization import QuantType
def quantize_onnx_model(onnx_model_path,
                        quantized_model_path,
                        nodes_to_exclude=None):
    """Dynamically quantize an ONNX model's weights to int8.

    Args:
        onnx_model_path: path to the input float ONNX model.
        quantized_model_path: path where the quantized model is written.
        nodes_to_exclude: optional list of node names to leave un-quantized;
            defaults to no exclusions.
    """
    print("Starting quantization...")
    # Use None as the default instead of a mutable [] — a shared list default
    # would be reused (and could be mutated) across calls.
    if nodes_to_exclude is None:
        nodes_to_exclude = []
    quantize_dynamic(
        onnx_model_path,
        quantized_model_path,
        weight_type=QuantType.QInt8,
        nodes_to_exclude=nodes_to_exclude)
    print(f"Quantized model saved to: {quantized_model_path}")
def main():
    """Parse CLI arguments and run dynamic quantization on an ONNX model."""
    # NOTE(review): the start of main() (the ArgumentParser construction) fell
    # inside a collapsed diff hunk in the scraped view; reconstructed
    # minimally — confirm against the original file.
    parser = argparse.ArgumentParser(prog=__doc__)
    parser.add_argument(
        "--model-in",
        type=str,
        required=True,
        help="ONNX model", )
    # NOTE(review): required=True makes the default below dead code — either
    # drop required or drop the default; kept as-is to preserve behavior.
    parser.add_argument(
        "--model-out",
        type=str,
        required=True,
        default='model.quant.onnx',
        help="ONNX model", )
    parser.add_argument(
        "--nodes-to-exclude",
        type=str,
        required=True,
        help="nodes to exclude. e.g. conv,linear.", )
    args = parser.parse_args()

    # Comma-separated node names -> list for the quantizer.
    nodes_to_exclude = args.nodes_to_exclude.split(',')
    quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude)
# Script entry point.
if __name__ == "__main__":
    main()
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to comment