From 563691808fbe513f3b42278c9256a12d15a938f8 Mon Sep 17 00:00:00 2001 From: chenzomi Date: Wed, 17 Jun 2020 12:57:39 +0800 Subject: [PATCH] add quant utils --- mindspore/train/quant/quant_utils.py | 34 +++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py index 18f068b25..50927b0ca 100644 --- a/mindspore/train/quant/quant_utils.py +++ b/mindspore/train/quant/quant_utils.py @@ -19,6 +19,7 @@ import numpy as np def cal_quantization_params(input_min, input_max, + data_type, num_bits=8, symmetric=False, narrow_range=False): @@ -28,6 +29,7 @@ def cal_quantization_params(input_min, Args: input_min (int, list): The dimension of channel or 1. input_max (int, list): The dimension of channel or 1. + data_type (numpy type) : Can ben numpy int8, numpy uint8. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -52,7 +54,7 @@ def cal_quantization_params(input_min, # scale = 1.0, zp = 0.0 return np.ones(input_min.shape), np.zeros(input_min.shape) - if symmetric: + if data_type == np.int8: quant_min = 0 - 2 ** (num_bits - 1) quant_max = 2 ** (num_bits - 1) else: @@ -84,3 +86,33 @@ def cal_quantization_params(input_min, zp = np.floor(zp_double + 0.5) return scale, zp + + +def weight2int(data, + scale, + zero_point): + r""" + calculate int8/uint8 weight from fp32. the formula is defined as: + + .. math:: + + int8/uint8 = round(float/scale) + offset + + Args: + data (int, list): The dimension of channel or 1. Should be NCHW. + scale (int, list): The dimension of channel or 1. + zero_point (int, list): The dimension of channel or 1. + + Outputs: + weight (int, list): The dimension of channel or 1. + + Examples: + >>> weight = weight2int([1, 2, 1], 1, 0) + """ + if scale.shape != zero_point.shape: + raise ValueError("scale and zero_point should have the same shape.") + if scale.shape[0] > 0: + scale = scale.reshape(1, -1, 1, 1) + zero_point = zero_point.reshape(1, -1, 1, 1) + + return np.round((data/scale) + zero_point) -- GitLab