...
 
Commits (2)
    https://gitcode.net/pulltheflower/opencv_zoo/-/commit/c6ddb9f96e02cc1db1e3a73419b964bcdb75ae8a shorten int8-quantized naming (#149) 2023-03-21T15:39:45+08:00 Yuantao Feng yuantao.feng@opencv.org.cn https://gitcode.net/pulltheflower/opencv_zoo/-/commit/afc67194681f3dcca7421b6695fd73622a7f1488 more renamings (#150) 2023-03-21T23:04:09+08:00 Yuantao Feng yuantao.feng@opencv.org.cn
......@@ -59,29 +59,30 @@ class Quantize:
# data reader
self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms, data_dim)
def check_opset(self, convert=True):
def check_opset(self):
model = onnx.load(self.model_path)
if model.opset_import[0].version != 13:
print('\tmodel opset version: {}. Converting to opset 13'.format(model.opset_import[0].version))
# convert opset version to 13
model_opset13 = version_converter.convert_version(model, 13)
# save converted model
output_name = '{}-opset.onnx'.format(self.model_path[:-5])
output_name = '{}-opset13.onnx'.format(self.model_path[:-5])
onnx.save_model(model_opset13, output_name)
# update model_path for quantization
self.model_path = output_name
return output_name
return self.model_path
def run(self):
print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
self.check_opset()
output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(self.model_path[:-5], self.act_type, self.wt_type)
quantize_static(self.model_path, output_name, self.dr,
new_model_path = self.check_opset()
output_name = '{}_{}.onnx'.format(self.model_path[:-5], self.wt_type)
quantize_static(new_model_path, output_name, self.dr,
quant_format=QuantFormat.QOperator, # start from onnxruntime==1.11.0, quant_format is set to QuantFormat.QDQ by default, which performs fake quantization
per_channel=self.per_channel,
weight_type=self.type_dict[self.wt_type],
activation_type=self.type_dict[self.act_type])
os.remove('augmented_model.onnx')
os.remove('{}-opt.onnx'.format(self.model_path[:-5]))
if new_model_path != self.model_path:
os.remove(new_model_path)
print('\tQuantized model saved to {}'.format(output_name))
models=dict(
......@@ -132,4 +133,3 @@ if __name__ == '__main__':
for selected_model_name in selected_models:
q = models[selected_model_name]
q.run()