提交 b391eb2b 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2209 add function for quant_utils that convert float to int

Merge pull request !2209 from chenzhongming/master
......@@ -19,6 +19,7 @@ import numpy as np
def cal_quantization_params(input_min,
input_max,
data_type,
num_bits=8,
symmetric=False,
narrow_range=False):
......@@ -28,6 +29,7 @@ def cal_quantization_params(input_min,
Args:
input_min (int, list): The dimension of channel or 1.
input_max (int, list): The dimension of channel or 1.
data_type (numpy type) : Can ben numpy int8, numpy uint8.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
......@@ -52,7 +54,7 @@ def cal_quantization_params(input_min,
# scale = 1.0, zp = 0.0
return np.ones(input_min.shape), np.zeros(input_min.shape)
if symmetric:
if data_type == np.int8:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1)
else:
......@@ -84,3 +86,33 @@ def cal_quantization_params(input_min,
zp = np.floor(zp_double + 0.5)
return scale, zp
def weight2int(data,
scale,
zero_point):
r"""
calculate int8/uint8 weight from fp32. the formula is defined as:
.. math::
int8/uint8 = round(float/scale) + offset
Args:
data (int, list): The dimension of channel or 1. Should be NCHW.
scale (int, list): The dimension of channel or 1.
zero_point (int, list): The dimension of channel or 1.
Outputs:
weight (int, list): The dimension of channel or 1.
Examples:
>>> weight = weight2int([1, 2, 1], 1, 0)
"""
if scale.shape != zero_point.shape:
raise ValueError("scale and zero_point should have the same shape.")
if scale.shape[0] > 0:
scale = scale.reshape(1, -1, 1, 1)
zero_point = zero_point.reshape(1, -1, 1, 1)
return np.round((data/scale) + zero_point)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册