diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc index 04cb4f31290825049a230246ec763f3c55b62d42..bc6e2bfbd0747ab941a2443c5e156c2564a07671 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc @@ -29,7 +29,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraitscast(); - auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_NCHW); + auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_ND); return ge_tensor == nullptr ? GeTensor() : *ge_tensor; } diff --git a/mindspore/common/api.py b/mindspore/common/api.py index b827ffe3455e92584052db9399abbad962b3b502..128c8d87e5b7fc77e0f965a76be1faeea2b9feb1 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -388,7 +388,7 @@ class _Executor: dic = dict(zip(args_names, args_list)) key = generate_key(phase, dic) self.phase_prefix = str(key[1]) - if phase == 'export': + if 'export' in phase: phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time) else: phase = self.phase_prefix + phase + '.' + str(obj.create_time) diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 02b2319d6def1c36ed167734ac413dbee1a242ac..60d0f86be16adff16b7b6fd6dab5dbcd298041ca 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -332,6 +332,7 @@ class Quant(PrimitiveWithInfer): self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) self.round_mode = validator.check_string("round_mode", round_mode, ["Round", "Floor", "Ceil", "Trunc"], self.name) + self.add_prim_attr("io_format", "ND") def infer_shape(self, x_shape): return x_shape @@ -382,6 +383,7 @@ class Dequant(PrimitiveWithInfer): self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name) self.add_prim_attr("dtype", mstype.float16) + self.add_prim_attr("io_format", "ND") def infer_shape(self, x_shape, deq_scale_shape): return x_shape diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index fec60a5b00e19b6a2cd36ba6b07ef352537d5954..9fccff0622b13053f6d474aa41e266e20a7cc63d 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -258,6 +258,7 @@ class _Reduce(PrimitiveWithInfer): """init Reduce""" validator.check_value_type('keep_dims', keep_dims, [bool], self.name) self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) + self.add_prim_attr("io_format", "ND") def __call__(self, x, axis=()): args = [x, axis] @@ -626,6 +627,7 @@ class MatMul(PrimitiveWithInfer): cls_name = self.name validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) + self.add_prim_attr("io_format", "ND") def check_shape_size(self, x, y): if len(x) != 2 or len(y) != 2: diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index 767dc9bdadca1cf3180e3e8d96c17974b133f687..c112f1e7113a6bb3ef458fe108ad9cbc6bf043a1 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -314,8 +314,8 @@ class ExportToQuantInferNetwork: network = validator.check_isinstance('network', network, (nn.Cell,)) # quantize for inputs: q = f / scale + zero_point # dequantize for outputs: f = (q - zero_point) * scale - self.input_scale = round(mean) - self.input_zero_point = 1 / std_dev + self.input_scale = 1 / std_dev + self.input_zero_point = round(mean) self.data_type = mstype.int8 self.network = copy.deepcopy(network) self.all_parameters = {p.name: p for p in self.network.get_parameters()} @@ -351,20 +351,16 @@ class ExportToQuantInferNetwork: else: maxq = self.all_parameters[minq_name[:-4] + "maxq"] minq = self.all_parameters[minq_name] - scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type) + scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type) else: logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}") return None # Build the `Quant` `Dequant` op. # Quant only support perlayer version. Need check here. - quant_op = inner.Quant(float(scale_a_in), float(zp_a_in)) - sqrt_mode = False + quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in)) scale_deq = scale_a_out * scale_w - if (scale_deq < 2 ** -14).all(): - scale_deq = np.sqrt(scale_deq) - sqrt_mode = True - dequant_op = inner.Dequant(sqrt_mode) + dequant_op = inner.Dequant() if isinstance(activation, _AddFakeQuantAfterSubCell): activation = activation.subcell @@ -385,8 +381,19 @@ class ExportToQuantInferNetwork: # apply the quant weight = quant_utils.weight2int(weight, scale_w, zp_w) if bias is not None: - bias = Tensor(scale_a_in * scale_w * bias, mstype.int32) - scale_deq = Tensor(scale_deq, mstype.float16) + bias = Tensor(bias / scale_a_in / scale_w, mstype.int32) + + # fuse parameter + # |--------|47:40|--------|39:32|--------|31:0| + # offset_w [8] shift_N [8] deq_scale [32] + float32_deq_scale = scale_deq.astype(np.float32) + uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32) + scale_length = scale_deq.size # channel + dequant_param = np.zeros(scale_length, dtype=np.uint64) + for index in range(scale_length): + dequant_param[index] += uint32_deq_scale[index] + + scale_deq = Tensor(dequant_param, mstype.uint64) # get op if isinstance(cell_core, quant.DenseQuant): op_core = P.MatMul()