Commit c3f762eb authored by Hui Zhang

format code

Parent 3cf1f1f0
 #!/usr/bin/env python3
 import argparse
-import onnx
-from onnx import version_converter, helper
+
+import onnx
+from onnx import version_converter

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(prog=__doc__)
-    parser.add_argument("--model-file", type=str, required=True, help='path/to/the/model.onnx.')
-    parser.add_argument("--save-model", type=str, required=True, help='path/to/saved/model.onnx.')
+    parser.add_argument(
+        "--model-file", type=str, required=True, help='path/to/the/model.onnx.')
+    parser.add_argument(
+        "--save-model",
+        type=str,
+        required=True,
+        help='path/to/saved/model.onnx.')
     # Models must be opset10 or higher to be quantized.
-    parser.add_argument("--target-opset", type=int, default=11, help='path/to/the/model.onnx.')
+    parser.add_argument(
+        "--target-opset", type=int, default=11, help='target opset, e.g. 11.')
     args = parser.parse_args()
@@ -24,7 +30,8 @@ if __name__ == '__main__':
     # A full list of supported adapters can be found here:
     # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
     # Apply the version conversion on the original model
-    converted_model = version_converter.convert_version(original_model, args.target_opset)
+    converted_model = version_converter.convert_version(original_model,
+                                                        args.target_opset)
     # print('The model after conversion:\n{}'.format(converted_model))
     onnx.save(converted_model, args.save_model)
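Note that the visible hunks skip the line that defines original_model; it is loaded from --model-file in the elided region between the two hunks. For readers following along, here is a self-contained sketch of the same conversion flow (the file paths are placeholders, not taken from the commit):

    import onnx
    from onnx import version_converter

    # Load the source model (placeholder path).
    original_model = onnx.load('model.onnx')
    # Rewrite the graph to the requested opset; ops without a
    # registered adapter will raise an error.
    converted_model = version_converter.convert_version(original_model, 11)
    onnx.save(converted_model, 'model.opset11.onnx')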
@@ -494,6 +494,8 @@ class SymbolicShapeInference:
         # contrib ops
         'Attention', 'BiasGelu', \
         'EmbedLayerNormalization', \
         'FastGelu', 'Gelu', 'LayerNormalization', \
......
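The hunk above touches a local copy of ONNX Runtime's symbolic_shape_infer.py; the quoted names are ORT contrib ops whose output shapes the inference pass knows how to compute. A minimal sketch of running that pass, assuming the upstream onnxruntime.tools import path (the vendored copy in this repo may be imported differently):

    import onnx
    from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

    model = onnx.load('model.onnx')  # placeholder path
    # Propagate (possibly symbolic) shapes through the graph,
    # including the supported contrib ops listed above.
    inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
    onnx.save(inferred, 'model.shaped.onnx')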
 #!/usr/bin/env python3
 import argparse
 import onnx
-from onnxruntime.quantization import quantize_dynamic, QuantType
-def quantize_onnx_model(onnx_model_path, quantized_model_path, nodes_to_exclude=[]):
+from onnxruntime.quantization import quantize_dynamic
+from onnxruntime.quantization import QuantType
+def quantize_onnx_model(onnx_model_path,
+                        quantized_model_path,
+                        nodes_to_exclude=[]):
     print("Starting quantization...")
     from onnxruntime.quantization import QuantType, quantize_dynamic
-    quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=QuantType.QInt8, nodes_to_exclude=nodes_to_exclude)
+    quantize_dynamic(
+        onnx_model_path,
+        quantized_model_path,
+        weight_type=QuantType.QInt8,
+        nodes_to_exclude=nodes_to_exclude)
     print(f"Quantized model saved to: {quantized_model_path}")
@@ -18,26 +25,24 @@ def main():
         "--model-in",
         type=str,
         required=True,
-        help="ONNX model",
-    )
+        help="ONNX model", )
     parser.add_argument(
         "--model-out",
         type=str,
         required=True,
         default='model.quant.onnx',
-        help="ONNX model",
-    )
+        help="ONNX model", )
     parser.add_argument(
         "--nodes-to-exclude",
         type=str,
         required=True,
-        help="nodes to exclude. e.g. conv,linear.",
-    )
+        help="nodes to exclude. e.g. conv,linear.", )
     args = parser.parse_args()

     nodes_to_exclude = args.nodes_to_exclude.split(',')
     quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude)

 if __name__ == "__main__":
     main()
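For completeness, a sketch of exercising the helper above directly. Dynamic quantization converts only the weights to int8 (activations stay float), so no calibration data is needed; the excluded node names below are hypothetical, and real ones can be listed from the model graph first:

    import onnx

    # List node names to decide which ops to exclude (placeholder path).
    model = onnx.load('model.onnx')
    print([node.name for node in model.graph.node][:10])

    # Exclude two hypothetical accuracy-sensitive nodes from quantization.
    quantize_onnx_model(
        'model.onnx',
        'model.quant.onnx',
        nodes_to_exclude=['Conv_0', 'MatMul_3'])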