提交 b13343f8 编写于 作者: W Wei Luning

fix quant bug

上级 8878f448
...@@ -29,7 +29,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor ...@@ -29,7 +29,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor
// TODO: the format may be read from the ME tensor // TODO: the format may be read from the ME tensor
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);
auto me_tensor = value->cast<MeTensorPtr>(); auto me_tensor = value->cast<MeTensorPtr>();
auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_NCHW); auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_ND);
return ge_tensor == nullptr ? GeTensor() : *ge_tensor; return ge_tensor == nullptr ? GeTensor() : *ge_tensor;
} }
......
...@@ -388,7 +388,7 @@ class _Executor: ...@@ -388,7 +388,7 @@ class _Executor:
dic = dict(zip(args_names, args_list)) dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic) key = generate_key(phase, dic)
self.phase_prefix = str(key[1]) self.phase_prefix = str(key[1])
if phase == 'export': if 'export' in phase:
phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time) phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
else: else:
phase = self.phase_prefix + phase + '.' + str(obj.create_time) phase = self.phase_prefix + phase + '.' + str(obj.create_time)
......
...@@ -332,6 +332,7 @@ class Quant(PrimitiveWithInfer): ...@@ -332,6 +332,7 @@ class Quant(PrimitiveWithInfer):
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
self.round_mode = validator.check_string("round_mode", round_mode, self.round_mode = validator.check_string("round_mode", round_mode,
["Round", "Floor", "Ceil", "Trunc"], self.name) ["Round", "Floor", "Ceil", "Trunc"], self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
return x_shape return x_shape
...@@ -382,6 +383,7 @@ class Dequant(PrimitiveWithInfer): ...@@ -382,6 +383,7 @@ class Dequant(PrimitiveWithInfer):
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name) self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
self.add_prim_attr("dtype", mstype.float16) self.add_prim_attr("dtype", mstype.float16)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape, deq_scale_shape): def infer_shape(self, x_shape, deq_scale_shape):
return x_shape return x_shape
......
...@@ -258,6 +258,7 @@ class _Reduce(PrimitiveWithInfer): ...@@ -258,6 +258,7 @@ class _Reduce(PrimitiveWithInfer):
"""init Reduce""" """init Reduce"""
validator.check_value_type('keep_dims', keep_dims, [bool], self.name) validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
self.add_prim_attr("io_format", "ND")
def __call__(self, x, axis=()): def __call__(self, x, axis=()):
args = [x, axis] args = [x, axis]
...@@ -626,6 +627,7 @@ class MatMul(PrimitiveWithInfer): ...@@ -626,6 +627,7 @@ class MatMul(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
self.add_prim_attr("io_format", "ND")
def check_shape_size(self, x, y): def check_shape_size(self, x, y):
if len(x) != 2 or len(y) != 2: if len(x) != 2 or len(y) != 2:
......
...@@ -314,8 +314,8 @@ class ExportToQuantInferNetwork: ...@@ -314,8 +314,8 @@ class ExportToQuantInferNetwork:
network = validator.check_isinstance('network', network, (nn.Cell,)) network = validator.check_isinstance('network', network, (nn.Cell,))
# quantize for inputs: q = f / scale + zero_point # quantize for inputs: q = f / scale + zero_point
# dequantize for outputs: f = (q - zero_point) * scale # dequantize for outputs: f = (q - zero_point) * scale
self.input_scale = round(mean) self.input_scale = 1 / std_dev
self.input_zero_point = 1 / std_dev self.input_zero_point = round(mean)
self.data_type = mstype.int8 self.data_type = mstype.int8
self.network = copy.deepcopy(network) self.network = copy.deepcopy(network)
self.all_parameters = {p.name: p for p in self.network.get_parameters()} self.all_parameters = {p.name: p for p in self.network.get_parameters()}
...@@ -351,20 +351,16 @@ class ExportToQuantInferNetwork: ...@@ -351,20 +351,16 @@ class ExportToQuantInferNetwork:
else: else:
maxq = self.all_parameters[minq_name[:-4] + "maxq"] maxq = self.all_parameters[minq_name[:-4] + "maxq"]
minq = self.all_parameters[minq_name] minq = self.all_parameters[minq_name]
scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type) scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type)
else: else:
logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}") logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
return None return None
# Build the `Quant` `Dequant` op. # Build the `Quant` `Dequant` op.
# Quant only supports the per-layer version; this needs to be checked here. # Quant only supports the per-layer version; this needs to be checked here.
quant_op = inner.Quant(float(scale_a_in), float(zp_a_in)) quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
sqrt_mode = False
scale_deq = scale_a_out * scale_w scale_deq = scale_a_out * scale_w
if (scale_deq < 2 ** -14).all(): dequant_op = inner.Dequant()
scale_deq = np.sqrt(scale_deq)
sqrt_mode = True
dequant_op = inner.Dequant(sqrt_mode)
if isinstance(activation, _AddFakeQuantAfterSubCell): if isinstance(activation, _AddFakeQuantAfterSubCell):
activation = activation.subcell activation = activation.subcell
...@@ -385,8 +381,19 @@ class ExportToQuantInferNetwork: ...@@ -385,8 +381,19 @@ class ExportToQuantInferNetwork:
# apply the quant # apply the quant
weight = quant_utils.weight2int(weight, scale_w, zp_w) weight = quant_utils.weight2int(weight, scale_w, zp_w)
if bias is not None: if bias is not None:
bias = Tensor(scale_a_in * scale_w * bias, mstype.int32) bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
scale_deq = Tensor(scale_deq, mstype.float16)
# fuse parameter
# |--------|47:40|--------|39:32|--------|31:0|
# offset_w [8] shift_N [8] deq_scale [32]
float32_deq_scale = scale_deq.astype(np.float32)
uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32)
scale_length = scale_deq.size # channel
dequant_param = np.zeros(scale_length, dtype=np.uint64)
for index in range(scale_length):
dequant_param[index] += uint32_deq_scale[index]
scale_deq = Tensor(dequant_param, mstype.uint64)
# get op # get op
if isinstance(cell_core, quant.DenseQuant): if isinstance(cell_core, quant.DenseQuant):
op_core = P.MatMul() op_core = P.MatMul()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册