Commit e2bb990e authored by Zihao Mu

add QDQ format ONNX models.

Parent bf616317
@@ -5,9 +5,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 import os
-import onnx
+import onnx  # version >= 1.12.0
 import onnxruntime as rt
-from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
+from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType, QuantFormat

 class DataReader(CalibrationDataReader):
     def __init__(self, model_path, batchsize=5):
@@ -20,16 +20,16 @@ class DataReader(CalibrationDataReader):
     def get_next(self):
         return next(self.enum_data_dicts, None)

-def quantize_and_save_model(name, input, model, act_type="uint8", wt_type="uint8", per_channel=False):
+def quantize_and_save_model(name, input, model, act_type="uint8", wt_type="uint8", per_channel=False, ops_version=13, quanFormat=QuantFormat.QOperator):
     float_model_path = os.path.join("models", "dummy.onnx")
     quantized_model_path = os.path.join("models", name + ".onnx")
     type_dict = {"uint8" : QuantType.QUInt8, "int8" : QuantType.QInt8}

     model.eval()
-    torch.onnx.export(model, input, float_model_path, export_params=True, opset_version=12)
+    torch.onnx.export(model, input, float_model_path, export_params=True, opset_version=ops_version)
     dr = DataReader(float_model_path)
-    quantize_static(float_model_path, quantized_model_path, dr, per_channel=per_channel,
+    quantize_static(float_model_path, quantized_model_path, dr, quant_format=quanFormat, per_channel=per_channel,
                     activation_type=type_dict[act_type], weight_type=type_dict[wt_type])
     os.remove(float_model_path)
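
For reference: with quant_format=QuantFormat.QOperator, quantize_static emits fused quantized operators such as QLinearConv, while QuantFormat.QDQ keeps the float operators and wraps their inputs and outputs in QuantizeLinear/DequantizeLinear pairs. A minimal sketch for checking which format a generated file ended up with (the path is an assumption based on the model names used below, not part of this commit):

    import onnx

    # Load a generated model and list its operator types. A QDQ model is
    # expected to contain QuantizeLinear/DequantizeLinear plus a plain Conv;
    # a QOperator model would contain QLinearConv instead.
    model = onnx.load("models/quantized_conv_uint8_weights_qdq.onnx")
    print(sorted({node.op_type for node in model.graph.node}))
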
@@ -53,10 +53,16 @@ np.random.seed(0)
 input = Variable(torch.randn(1, 3, 10, 10))
 conv = nn.Conv2d(3, 5, kernel_size=3, stride=2, padding=1)
+# generate QOperator quantized models
 quantize_and_save_model("quantized_conv_uint8_weights", input, conv)
 quantize_and_save_model("quantized_conv_int8_weights", input, conv, wt_type="int8")
 quantize_and_save_model("quantized_conv_per_channel_weights", input, conv, per_channel=True)
+
+# generate QDQ quantized models
+quantize_and_save_model("quantized_conv_uint8_weights_qdq", input, conv, quanFormat=QuantFormat.QDQ)
+quantize_and_save_model("quantized_conv_int8_weights_qdq", input, conv, wt_type="int8", quanFormat=QuantFormat.QDQ)
+quantize_and_save_model("quantized_conv_per_channel_weights_qdq", input, conv, per_channel=True, quanFormat=QuantFormat.QDQ)

 input = Variable(torch.randn(1, 3))
 linear = nn.Linear(3, 4, bias=True)
 quantize_and_save_model("quantized_matmul_uint8_weights", input, linear)